Mirror of https://github.com/hwchase17/langchain.git (synced 2025-09-03 03:59:42 +00:00)

multiple: add stop attribute (#22573)
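
This commit gives ChatAI21, ChatAnthropic, ChatFireworks, and ChatGroq a class-level `stop` attribute for default stop sequences, forwards it when building each provider request, and adds a shared `test_stop_sequence` integration test that covers both the per-call path and the constructor-level path.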

@@ -35,6 +35,8 @@ class ChatAI21(BaseChatModel, AI21Base):
     You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
     num_results: int = 1
     """The number of responses to generate for a given prompt."""
+    stop: Optional[List[str]] = None
+    """Default stop sequences."""
 
     max_tokens: int = 16
     """The maximum number of tokens to generate for each response."""

@@ -97,6 +99,8 @@ class ChatAI21(BaseChatModel, AI21Base):
             "top_k_return": self.top_k_return,
             "n": self.n,
         }
+        if self.stop:
+            base_params["stop_sequences"] = self.stop
 
         if self.count_penalty is not None:
            base_params["count_penalty"] = self.count_penalty.to_dict()
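
Taken together, the two ChatAI21 hunks let a default stop list ride along on every request as `stop_sequences`. A minimal usage sketch, assuming the `langchain_ai21` package; the model name is a placeholder and does not appear in the diff:

    from langchain_ai21 import ChatAI21

    # "j2-ultra" is illustrative; pass any AI21 chat model name.
    model = ChatAI21(model="j2-ultra", stop=["\n\n"])

    # Every request now carries stop_sequences=["\n\n"] without the
    # caller having to repeat it on each invoke().
    result = model.invoke("Write one paragraph about tide pools.")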

@@ -492,6 +492,9 @@ class ChatAnthropic(BaseChatModel):
     max_retries: int = 2
     """Number of retries allowed for requests sent to the Anthropic Completion API."""
 
+    stop: Optional[List[str]] = Field(None, alias="stop_sequences")
+    """Default stop sequences."""
+
     anthropic_api_url: Optional[str] = Field(None, alias="base_url")
     """Base URL for API requests. Only specify if using a proxy or service emulator.

@@ -611,6 +614,7 @@ class ChatAnthropic(BaseChatModel):
     ) -> Dict:
         # get system prompt if any
         system, formatted_messages = _format_messages(messages)
+        stop_sequences = stop or self.stop
         rtn = {
             "model": self.model,
             "max_tokens": self.max_tokens,

@@ -618,7 +622,7 @@ class ChatAnthropic(BaseChatModel):
             "temperature": self.temperature,
             "top_k": self.top_k,
             "top_p": self.top_p,
-            "stop_sequences": stop,
+            "stop_sequences": stop_sequences,
             "system": system,
             **self.model_kwargs,
             **kwargs,
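
The new `stop_sequences = stop or self.stop` line makes the per-call argument win over the instance default. A minimal sketch of the resulting precedence, assuming the `langchain_anthropic` package and an illustrative model name:

    from langchain_anthropic import ChatAnthropic

    model = ChatAnthropic(model="claude-3-sonnet-20240229", stop=["Human:"])

    model.invoke("hi")                # payload carries stop_sequences=["Human:"]
    model.invoke("hi", stop=["END"])  # per-call value wins: stop_sequences=["END"]

One consequence of using `or` rather than an explicit `is not None` check: a caller passing `stop=[]` falls back to the instance default, since an empty list is falsy.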

@@ -300,6 +300,8 @@ class ChatFireworks(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    stop: Optional[List[str]] = Field(None, alias="stop_sequences")
+    """Default stop sequences."""
 
     class Config:
         """Configuration for this pydantic object."""

@@ -354,6 +356,7 @@ class ChatFireworks(BaseChatModel):
             "stream": self.streaming,
             "n": self.n,
             "temperature": self.temperature,
+            "stop": self.stop,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:

@@ -443,8 +446,6 @@ class ChatFireworks(BaseChatModel):
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params = self._default_params
         if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params
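
Because `stop` now always appears in `_default_params`, the old guard that raised `ValueError` whenever `stop` showed up in both the input and the defaults had to be dropped; a per-call `stop` now silently overrides the default. A self-contained sketch of the new merge behavior (`merge_stop` is a hypothetical helper for illustration, not part of the diff); the ChatGroq hunks below apply the identical change:

    from typing import Any, Dict, List, Optional

    def merge_stop(
        default_params: Dict[str, Any], stop: Optional[List[str]]
    ) -> Dict[str, Any]:
        # Mirrors the hunk above: the per-call value, when given, wins silently.
        params = dict(default_params)
        if stop is not None:
            params["stop"] = stop
        return params

    assert merge_stop({"stop": ["default"]}, ["override"])["stop"] == ["override"]
    assert merge_stop({"stop": ["default"]}, None)["stop"] == ["default"]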

@@ -123,6 +123,8 @@ class ChatGroq(BaseChatModel):
     """Number of chat completions to generate for each prompt."""
     max_tokens: Optional[int] = None
     """Maximum number of tokens to generate."""
+    stop: Optional[List[str]] = Field(None, alias="stop_sequences")
+    """Default stop sequences."""
     default_headers: Union[Mapping[str, str], None] = None
     default_query: Union[Mapping[str, object], None] = None
     # Configure a custom httpx client. See the

@@ -428,6 +430,7 @@ class ChatGroq(BaseChatModel):
             "stream": self.streaming,
             "n": self.n,
             "temperature": self.temperature,
+            "stop": self.stop,
             **self.model_kwargs,
         }
         if self.max_tokens is not None:

@@ -461,8 +464,6 @@ class ChatGroq(BaseChatModel):
     ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
         params = self._default_params
         if stop is not None:
-            if "stop" in params:
-                raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop
         message_dicts = [_convert_message_to_dict(m) for m in messages]
         return message_dicts, params

@@ -144,6 +144,17 @@ class ChatModelIntegrationTests(ABC):
         assert isinstance(result.usage_metadata["output_tokens"], int)
         assert isinstance(result.usage_metadata["total_tokens"], int)
 
+    def test_stop_sequence(
+        self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
+    ) -> None:
+        model = chat_model_class(**chat_model_params)
+        result = model.invoke("hi", stop=["you"])
+        assert isinstance(result, AIMessage)
+
+        model = chat_model_class(**chat_model_params, stop=["you"])
+        result = model.invoke("hi")
+        assert isinstance(result, AIMessage)
+
     def test_tool_message_histories_string_content(
         self,
         chat_model_class: Type[BaseChatModel],
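
The new `test_stop_sequence` exercises both paths: `stop` passed at call time and `stop` set on the constructor. Partner packages pick it up by subclassing the shared suite and supplying the two fixtures the test signature expects. A sketch of that wiring, using ChatGroq from this same commit; the import path of the shared suite and the model name are assumptions:

    from typing import Type

    import pytest
    from langchain_core.language_models import BaseChatModel
    from langchain_groq import ChatGroq
    # Import path assumed from the class name in the hunk above:
    from langchain_standard_tests.integration_tests import ChatModelIntegrationTests


    class TestGroqStandard(ChatModelIntegrationTests):
        @pytest.fixture
        def chat_model_class(self) -> Type[BaseChatModel]:
            return ChatGroq

        @pytest.fixture
        def chat_model_params(self) -> dict:
            # Illustrative model name; any Groq-hosted chat model works.
            return {"model": "llama3-8b-8192"}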