mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-14 17:07:25 +00:00
Add serialisation arguments to Bedrock and ChatBedrock (#13465)
This commit is contained in:
parent
427331d621
commit
ea6e017b85
@ -41,6 +41,22 @@ class BedrockChat(BaseChatModel, BedrockBase):
|
|||||||
"""Return type of chat model."""
|
"""Return type of chat model."""
|
||||||
return "amazon_bedrock_chat"
|
return "amazon_bedrock_chat"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
print(self.region_name)
|
||||||
|
|
||||||
|
if self.region_name:
|
||||||
|
attributes["region_name"] = self.region_name
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
@ -12,6 +12,7 @@ from langchain.utilities.anthropic import (
|
|||||||
get_num_tokens_anthropic,
|
get_num_tokens_anthropic,
|
||||||
get_token_ids_anthropic,
|
get_token_ids_anthropic,
|
||||||
)
|
)
|
||||||
|
from langchain.utils import get_from_dict_or_env
|
||||||
|
|
||||||
HUMAN_PROMPT = "\n\nHuman:"
|
HUMAN_PROMPT = "\n\nHuman:"
|
||||||
ASSISTANT_PROMPT = "\n\nAssistant:"
|
ASSISTANT_PROMPT = "\n\nAssistant:"
|
||||||
@ -195,6 +196,13 @@ class BedrockBase(BaseModel, ABC):
|
|||||||
# use default credentials
|
# use default credentials
|
||||||
session = boto3.Session()
|
session = boto3.Session()
|
||||||
|
|
||||||
|
values["region_name"] = get_from_dict_or_env(
|
||||||
|
values,
|
||||||
|
"region_name",
|
||||||
|
"AWS_DEFAULT_REGION",
|
||||||
|
default=None,
|
||||||
|
)
|
||||||
|
|
||||||
client_params = {}
|
client_params = {}
|
||||||
if values["region_name"]:
|
if values["region_name"]:
|
||||||
client_params["region_name"] = values["region_name"]
|
client_params["region_name"] = values["region_name"]
|
||||||
@ -340,6 +348,20 @@ class Bedrock(LLM, BedrockBase):
|
|||||||
"""Return type of llm."""
|
"""Return type of llm."""
|
||||||
return "amazon_bedrock"
|
return "amazon_bedrock"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def is_lc_serializable(cls) -> bool:
|
||||||
|
"""Return whether this model can be serialized by Langchain."""
|
||||||
|
return True
|
||||||
|
|
||||||
|
@property
|
||||||
|
def lc_attributes(self) -> Dict[str, Any]:
|
||||||
|
attributes: Dict[str, Any] = {}
|
||||||
|
|
||||||
|
if self.region_name:
|
||||||
|
attributes["region_name"] = self.region_name
|
||||||
|
|
||||||
|
return attributes
|
||||||
|
|
||||||
class Config:
|
class Config:
|
||||||
"""Configuration for this pydantic object."""
|
"""Configuration for this pydantic object."""
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user