diff --git a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
index 9703b50358a..4a47e01b642 100644
--- a/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
+++ b/libs/cli/langchain_cli/integration_template/integration_template/chat_models.py
@@ -13,7 +13,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field
 
 
 class Chat__ModuleName__(BaseChatModel):
@@ -266,7 +266,7 @@ class Chat__ModuleName__(BaseChatModel):
 
     """  # noqa: E501
 
-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -276,6 +276,10 @@ class Chat__ModuleName__(BaseChatModel):
     stop: Optional[List[str]] = None
     max_retries: int = 2
 
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
@@ -293,7 +297,7 @@ class Chat__ModuleName__(BaseChatModel):
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }
 
     def _generate(
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index 1bc9b66d880..cfb39ba1d87 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
+    model: str = Field(default="gpt-3.5-turbo", alias="model_name")
     """Model name to use."""
     temperature: Optional[float] = None
     """What sampling temperature to use."""
@@ -690,6 +690,10 @@ class BaseChatOpenAI(BaseChatModel):
 
     model_config = ConfigDict(populate_by_name=True)
 
+    @property
+    def model_name(self) -> str:
+        return self.model
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: dict[str, Any]) -> Any:
diff --git a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py
index 320d2b491f1..3ee4d16dbea 100644
--- a/libs/standard-tests/langchain_tests/unit_tests/chat_models.py
+++ b/libs/standard-tests/langchain_tests/unit_tests/chat_models.py
@@ -84,6 +84,10 @@ class ChatModelTests(BaseStandardTests):
         """Initialization parameters for the chat model."""
         return {}
 
+    @property
+    def chat_model_model_param(self) -> str:
+        return self.chat_model_params.get("model", "test-model-name")
+
     @property
     def standard_chat_model_params(self) -> dict:
         """:private:"""
@@ -833,6 +837,45 @@ class ChatModelUnitTests(ChatModelTests):
         )
         assert model is not None
 
+    def test_model_param_name(self) -> None:
+        """Tests model initialization with a ``model=`` parameter. This should pass for
+        all integrations.
+
+        .. dropdown:: Troubleshooting
+
+            If this test fails, ensure that the model can be initialized with a
+            ``model`` parameter, and that the model parameter can be accessed as
+            ``.model``.
+
+            If not, the easiest way to configure this is likely to add
+            ``from pydantic import ConfigDict`` at the top of your file, and add a
+            ``model_config`` class attribute to your model class:
+
+            .. code-block:: python
+
+                class MyChatModel(BaseChatModel):
+                    model: str = Field(alias="model_name")
+                    model_config = ConfigDict(populate_by_name=True)
+
+                    # optional property for backwards-compatibility
+                    # for folks accessing chat_model.model_name
+                    @property
+                    def model_name(self) -> str:
+                        return self.model
+        """
+        params = {
+            **self.standard_chat_model_params,
+            **self.chat_model_params,
+        }
+        if "model_name" in params:
+            params["model"] = params.pop("model_name")
+        else:
+            params["model"] = self.chat_model_model_param
+
+        model = self.chat_model_class(**params)
+        assert model is not None
+        assert model.model == params["model"]  # type: ignore[attr-defined]
+
     def test_init_from_env(self) -> None:
         """Test initialization from environment variables. Relies on the
         ``init_from_env_params`` property. Test is skipped if that property is not
diff --git a/libs/standard-tests/tests/unit_tests/custom_chat_model.py b/libs/standard-tests/tests/unit_tests/custom_chat_model.py
index cc9be763989..91fd17fc6ec 100644
--- a/libs/standard-tests/tests/unit_tests/custom_chat_model.py
+++ b/libs/standard-tests/tests/unit_tests/custom_chat_model.py
@@ -12,7 +12,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field
 
 
 class ChatParrotLink(BaseChatModel):
@@ -35,7 +35,7 @@ class ChatParrotLink(BaseChatModel):
 
     """
 
-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -45,6 +45,10 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[list[str]] = None
    max_retries: int = 2
 
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -81,7 +85,7 @@ class ChatParrotLink(BaseChatModel):
             additional_kwargs={},  # Used to add additional payload to the message
             response_metadata={  # Use for response metadata
                 "time_in_seconds": 3,
-                "model_name": self.model_name,
+                "model_name": self.model,
             },
             usage_metadata={
                 "input_tokens": ct_input_tokens,
@@ -148,7 +152,7 @@ class ChatParrotLink(BaseChatModel):
         chunk = ChatGenerationChunk(
             message=AIMessageChunk(
                 content="",
-                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
+                response_metadata={"time_in_sec": 3, "model_name": self.model},
             )
         )
         if run_manager:
@@ -174,5 +178,5 @@ class ChatParrotLink(BaseChatModel):
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }
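
The pattern this diff applies everywhere is the same: a field named ``model`` with ``Field(alias="model_name")``, ``populate_by_name=True`` so either keyword is accepted at construction time, and (where backwards compatibility matters) a read-only ``model_name`` property. A minimal, self-contained sketch of the behavior, using a plain pydantic ``BaseModel`` as a stand-in for the chat model classes above (``_ExampleModel`` is illustrative only, not part of this change):

.. code-block:: python

    from pydantic import BaseModel, ConfigDict, Field


    class _ExampleModel(BaseModel):
        """Toy stand-in for a chat model class using the pattern from this diff."""

        model: str = Field(alias="model_name")
        model_config = ConfigDict(populate_by_name=True)

        @property
        def model_name(self) -> str:
            # Backwards-compatible accessor for callers still reading .model_name
            return self.model


    # Either keyword populates the same field, thanks to the alias plus
    # populate_by_name=True...
    assert _ExampleModel(model="my-model").model == "my-model"
    assert _ExampleModel(model_name="my-model").model == "my-model"
    # ...and old call sites that read .model_name keep working.
    assert _ExampleModel(model="my-model").model_name == "my-model"

This is also what ``test_model_param_name`` in the standard tests exercises: it normalizes any ``model_name`` entry in ``chat_model_params`` to ``model`` and then asserts the constructed model exposes ``.model``.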