This commit is contained in:
Erick Friis 2025-07-28 16:41:39 -04:00 committed by GitHub
commit 61127f00d4
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 64 additions and 9 deletions

View File

@ -13,7 +13,7 @@ from langchain_core.messages import (
) )
from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field from pydantic import ConfigDict, Field
class Chat__ModuleName__(BaseChatModel): class Chat__ModuleName__(BaseChatModel):
@ -266,7 +266,7 @@ class Chat__ModuleName__(BaseChatModel):
""" # noqa: E501 """ # noqa: E501
model_name: str = Field(alias="model") model: str = Field(alias="model_name")
"""The name of the model""" """The name of the model"""
parrot_buffer_length: int parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed.""" """The number of characters from the last message of the prompt to be echoed."""
@ -276,6 +276,10 @@ class Chat__ModuleName__(BaseChatModel):
stop: Optional[List[str]] = None stop: Optional[List[str]] = None
max_retries: int = 2 max_retries: int = 2
model_config = ConfigDict(
populate_by_name=True,
)
@property @property
def _llm_type(self) -> str: def _llm_type(self) -> str:
"""Return type of chat model.""" """Return type of chat model."""
@ -293,7 +297,7 @@ class Chat__ModuleName__(BaseChatModel):
# rules in LLM monitoring applications (e.g., in LangSmith users # rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor # can provide per token pricing for their model and monitor
# costs for the given LLM.) # costs for the given LLM.)
"model_name": self.model_name, "model": self.model,
} }
def _generate( def _generate(

View File

@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
async_client: Any = Field(default=None, exclude=True) #: :meta private: async_client: Any = Field(default=None, exclude=True) #: :meta private:
root_client: Any = Field(default=None, exclude=True) #: :meta private: root_client: Any = Field(default=None, exclude=True) #: :meta private:
root_async_client: Any = Field(default=None, exclude=True) #: :meta private: root_async_client: Any = Field(default=None, exclude=True) #: :meta private:
model_name: str = Field(default="gpt-3.5-turbo", alias="model") model: str = Field(default="gpt-3.5-turbo", alias="model_name")
"""Model name to use.""" """Model name to use."""
temperature: Optional[float] = None temperature: Optional[float] = None
"""What sampling temperature to use.""" """What sampling temperature to use."""
@ -690,6 +690,10 @@ class BaseChatOpenAI(BaseChatModel):
model_config = ConfigDict(populate_by_name=True) model_config = ConfigDict(populate_by_name=True)
@property
def model_name(self) -> str:
return self.model
@model_validator(mode="before") @model_validator(mode="before")
@classmethod @classmethod
def build_extra(cls, values: dict[str, Any]) -> Any: def build_extra(cls, values: dict[str, Any]) -> Any:

View File

@ -84,6 +84,10 @@ class ChatModelTests(BaseStandardTests):
"""Initialization parameters for the chat model.""" """Initialization parameters for the chat model."""
return {} return {}
@property
def chat_model_model_param(self) -> dict:
return self.chat_model_params.get("model", "test-model-name")
@property @property
def standard_chat_model_params(self) -> dict: def standard_chat_model_params(self) -> dict:
""":private:""" """:private:"""
@ -833,6 +837,45 @@ class ChatModelUnitTests(ChatModelTests):
) )
assert model is not None assert model is not None
def test_model_param_name(self) -> None:
"""Tests model initialization with a ``model=`` parameter. This should pass for
all integrations.
.. dropdown:: Troubleshooting
If this test fails, ensure that the model can be initialized with a
``model`` parameter, and that the model parameter can be accessed as
``.model``.
If not, the easiest way to configure this is likely to add
``from pydantic import ConfigDict`` at the top of your file, and add a
``model_config`` class attribute to your model class:
.. code-block:: python
class MyChatModel(BaseChatModel):
model: str = Field(alias="model_name")
model_config = ConfigDict(populate_by_name=True)
# optional property for backwards-compatibility
# for folks accessing chat_model.model_name
@property
def model_name(self) -> str:
return self.model
"""
params = {
**self.standard_chat_model_params,
**self.chat_model_params,
}
if "model_name" in params:
params["model"] = params.pop("model_name")
else:
params["model"] = self.chat_model_model_param
model = self.chat_model_class(**params)
assert model is not None
assert model.model == params["model"] # type: ignore[attr-defined]
def test_init_from_env(self) -> None: def test_init_from_env(self) -> None:
"""Test initialization from environment variables. Relies on the """Test initialization from environment variables. Relies on the
``init_from_env_params`` property. Test is skipped if that property is not ``init_from_env_params`` property. Test is skipped if that property is not

View File

@ -12,7 +12,7 @@ from langchain_core.messages import (
) )
from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field from pydantic import ConfigDict, Field
class ChatParrotLink(BaseChatModel): class ChatParrotLink(BaseChatModel):
@ -35,7 +35,7 @@ class ChatParrotLink(BaseChatModel):
""" """
model_name: str = Field(alias="model") model: str = Field(alias="model_name")
"""The name of the model""" """The name of the model"""
parrot_buffer_length: int parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed.""" """The number of characters from the last message of the prompt to be echoed."""
@ -45,6 +45,10 @@ class ChatParrotLink(BaseChatModel):
stop: Optional[list[str]] = None stop: Optional[list[str]] = None
max_retries: int = 2 max_retries: int = 2
model_config = ConfigDict(
populate_by_name=True,
)
def _generate( def _generate(
self, self,
messages: list[BaseMessage], messages: list[BaseMessage],
@ -81,7 +85,7 @@ class ChatParrotLink(BaseChatModel):
additional_kwargs={}, # Used to add additional payload to the message additional_kwargs={}, # Used to add additional payload to the message
response_metadata={ # Use for response metadata response_metadata={ # Use for response metadata
"time_in_seconds": 3, "time_in_seconds": 3,
"model_name": self.model_name, "model_name": self.model,
}, },
usage_metadata={ usage_metadata={
"input_tokens": ct_input_tokens, "input_tokens": ct_input_tokens,
@ -148,7 +152,7 @@ class ChatParrotLink(BaseChatModel):
chunk = ChatGenerationChunk( chunk = ChatGenerationChunk(
message=AIMessageChunk( message=AIMessageChunk(
content="", content="",
response_metadata={"time_in_sec": 3, "model_name": self.model_name}, response_metadata={"time_in_sec": 3, "model_name": self.model},
) )
) )
if run_manager: if run_manager:
@ -174,5 +178,5 @@ class ChatParrotLink(BaseChatModel):
# rules in LLM monitoring applications (e.g., in LangSmith users # rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor # can provide per token pricing for their model and monitor
# costs for the given LLM.) # costs for the given LLM.)
"model_name": self.model_name, "model": self.model,
} }