This commit is contained in:
Erick Friis 2025-02-04 18:16:10 -08:00
parent 6d58ccb013
commit 3b441f312a
2 changed files with 14 additions and 6 deletions

View File

@ -13,7 +13,7 @@ from langchain_core.messages import (
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
from pydantic import ConfigDict, Field
class Chat__ModuleName__(BaseChatModel):
@ -266,7 +266,7 @@ class Chat__ModuleName__(BaseChatModel):
""" # noqa: E501
model_name: str = Field(alias="model")
model: str = Field(alias="model_name")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
@ -276,6 +276,10 @@ class Chat__ModuleName__(BaseChatModel):
stop: Optional[List[str]] = None
max_retries: int = 2
model_config = ConfigDict(
populate_by_name=True,
)
@property
def _llm_type(self) -> str:
"""Return type of chat model."""
@ -293,7 +297,7 @@ class Chat__ModuleName__(BaseChatModel):
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
"model": self.model,
}
def _generate(

View File

@ -11,7 +11,7 @@ from langchain_core.messages import (
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from pydantic import Field
from pydantic import ConfigDict, Field
class ChatParrotLink(BaseChatModel):
@ -33,7 +33,7 @@ class ChatParrotLink(BaseChatModel):
[HumanMessage(content="world")]])
"""
model_name: str = Field(alias="model")
model: str = Field(alias="model_name")
"""The name of the model"""
parrot_buffer_length: int
"""The number of characters from the last message of the prompt to be echoed."""
@ -43,6 +43,10 @@ class ChatParrotLink(BaseChatModel):
stop: Optional[List[str]] = None
max_retries: int = 2
model_config = ConfigDict(
populate_by_name=True,
)
def _generate(
self,
messages: List[BaseMessage],
@ -163,5 +167,5 @@ class ChatParrotLink(BaseChatModel):
# rules in LLM monitoring applications (e.g., in LangSmith users
# can provide per token pricing for their model and monitor
# costs for the given LLM.)
"model_name": self.model_name,
"model": self.model,
}