Merge 8b40258e72 into 3a487bf720

This commit is contained in:
commit 61127f00d4
@@ -13,7 +13,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field
 
 
 class Chat__ModuleName__(BaseChatModel):
@@ -266,7 +266,7 @@ class Chat__ModuleName__(BaseChatModel):
 
     """  # noqa: E501
 
-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -276,6 +276,10 @@ class Chat__ModuleName__(BaseChatModel):
     stop: Optional[List[str]] = None
     max_retries: int = 2
 
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     @property
     def _llm_type(self) -> str:
         """Return type of chat model."""
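For orientation (this example is not part of the diff): in pydantic v2, `populate_by_name=True` is what lets the renamed field accept either spelling at construction time. A minimal sketch of the pattern the diff introduces:

from pydantic import BaseModel, ConfigDict, Field


class Example(BaseModel):
    # mirrors the diff: canonical field "model", alias "model_name"
    # kept for backwards compatibility
    model: str = Field(alias="model_name")

    model_config = ConfigDict(populate_by_name=True)


Example(model="bird-brain-001")       # init by field name
Example(model_name="bird-brain-001")  # init by alias also works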
@@ -293,7 +297,7 @@ class Chat__ModuleName__(BaseChatModel):
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }
 
     def _generate(
@@ -448,7 +448,7 @@ class BaseChatOpenAI(BaseChatModel):
     async_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_client: Any = Field(default=None, exclude=True)  #: :meta private:
     root_async_client: Any = Field(default=None, exclude=True)  #: :meta private:
-    model_name: str = Field(default="gpt-3.5-turbo", alias="model")
+    model: str = Field(default="gpt-3.5-turbo", alias="model_name")
     """Model name to use."""
     temperature: Optional[float] = None
     """What sampling temperature to use."""
@@ -690,6 +690,10 @@ class BaseChatOpenAI(BaseChatModel):
 
     model_config = ConfigDict(populate_by_name=True)
 
+    @property
+    def model_name(self) -> str:
+        return self.model
+
     @model_validator(mode="before")
     @classmethod
     def build_extra(cls, values: dict[str, Any]) -> Any:
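Assuming the diff above is applied, both spellings keep working for callers of `ChatOpenAI`; a sketch of the intended compatibility surface:

from langchain_openai import ChatOpenAI

chat = ChatOpenAI(model="gpt-4o-mini")
assert chat.model == "gpt-4o-mini"       # new canonical attribute
assert chat.model_name == "gpt-4o-mini"  # read-only property for older code

# the alias keeps the old keyword argument working at init time
legacy = ChatOpenAI(model_name="gpt-4o-mini")
assert legacy.model == "gpt-4o-mini"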
@@ -84,6 +84,10 @@ class ChatModelTests(BaseStandardTests):
         """Initialization parameters for the chat model."""
         return {}
 
+    @property
+    def chat_model_model_param(self) -> str:
+        return self.chat_model_params.get("model", "test-model-name")
+
     @property
     def standard_chat_model_params(self) -> dict:
         """:private:"""
@@ -833,6 +837,45 @@ class ChatModelUnitTests(ChatModelTests):
         )
         assert model is not None
 
+    def test_model_param_name(self) -> None:
+        """Tests model initialization with a ``model=`` parameter. This should pass for
+        all integrations.
+
+        .. dropdown:: Troubleshooting
+
+            If this test fails, ensure that the model can be initialized with a
+            ``model`` parameter, and that the model parameter can be accessed as
+            ``.model``.
+
+            If not, the easiest way to configure this is likely to add
+            ``from pydantic import ConfigDict`` at the top of your file, and add a
+            ``model_config`` class attribute to your model class:
+
+            .. code-block:: python
+
+                class MyChatModel(BaseChatModel):
+                    model: str = Field(alias="model_name")
+                    model_config = ConfigDict(populate_by_name=True)
+
+                    # optional property for backwards-compatibility
+                    # for folks accessing chat_model.model_name
+                    @property
+                    def model_name(self) -> str:
+                        return self.model
+        """
+        params = {
+            **self.standard_chat_model_params,
+            **self.chat_model_params,
+        }
+        if "model_name" in params:
+            params["model"] = params.pop("model_name")
+        else:
+            params["model"] = self.chat_model_model_param
+
+        model = self.chat_model_class(**params)
+        assert model is not None
+        assert model.model == params["model"]  # type: ignore[attr-defined]
+
     def test_init_from_env(self) -> None:
         """Test initialization from environment variables. Relies on the
         ``init_from_env_params`` property. Test is skipped if that property is not
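To see where this test bites, here is a hedged sketch of an integration wiring up the standard unit tests (the package and parameter names follow the hypothetical ParrotLink example used elsewhere in this diff); `test_model_param_name` is inherited and runs alongside the other standard tests:

from typing import Type

from langchain_parrot_link.chat_models import ChatParrotLink  # hypothetical package
from langchain_tests.unit_tests import ChatModelUnitTests


class TestChatParrotLinkUnit(ChatModelUnitTests):
    @property
    def chat_model_class(self) -> Type[ChatParrotLink]:
        return ChatParrotLink

    @property
    def chat_model_params(self) -> dict:
        # "model" is the canonical key after this change; the inherited
        # test_model_param_name also tolerates a legacy "model_name" key
        return {"model": "bird-brain-001", "parrot_buffer_length": 50}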
@@ -12,7 +12,7 @@ from langchain_core.messages import (
 )
 from langchain_core.messages.ai import UsageMetadata
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
-from pydantic import Field
+from pydantic import ConfigDict, Field
 
 
 class ChatParrotLink(BaseChatModel):
@@ -35,7 +35,7 @@ class ChatParrotLink(BaseChatModel):
 
     """
 
-    model_name: str = Field(alias="model")
+    model: str = Field(alias="model_name")
     """The name of the model"""
     parrot_buffer_length: int
     """The number of characters from the last message of the prompt to be echoed."""
@@ -45,6 +45,10 @@ class ChatParrotLink(BaseChatModel):
     stop: Optional[list[str]] = None
     max_retries: int = 2
 
+    model_config = ConfigDict(
+        populate_by_name=True,
+    )
+
     def _generate(
         self,
         messages: list[BaseMessage],
@@ -81,7 +85,7 @@ class ChatParrotLink(BaseChatModel):
             additional_kwargs={},  # Used to add additional payload to the message
             response_metadata={  # Use for response metadata
                 "time_in_seconds": 3,
-                "model_name": self.model_name,
+                "model_name": self.model,
             },
             usage_metadata={
                 "input_tokens": ct_input_tokens,
@@ -148,7 +152,7 @@ class ChatParrotLink(BaseChatModel):
         chunk = ChatGenerationChunk(
             message=AIMessageChunk(
                 content="",
-                response_metadata={"time_in_sec": 3, "model_name": self.model_name},
+                response_metadata={"time_in_sec": 3, "model_name": self.model},
             )
         )
         if run_manager:
@@ -174,5 +178,5 @@ class ChatParrotLink(BaseChatModel):
             # rules in LLM monitoring applications (e.g., in LangSmith users
             # can provide per token pricing for their model and monitor
             # costs for the given LLM.)
-            "model_name": self.model_name,
+            "model": self.model,
         }
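Note the asymmetry the ParrotLink hunks preserve: `response_metadata` keeps its "model_name" key (only the attribute read changes to `self.model`), while `_identifying_params` renames its key to "model". A sketch of the resulting surfaces, assuming the diff is applied to the hypothetical `langchain_parrot_link` package:

from langchain_parrot_link.chat_models import ChatParrotLink  # hypothetical package

llm = ChatParrotLink(model="bird-brain-001", parrot_buffer_length=50)

msg = llm.invoke("hello polly")
print(msg.response_metadata["model_name"])  # still keyed "model_name"
print(llm._identifying_params["model"])     # now keyed "model"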