mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-08 14:31:55 +00:00
community[minor]: add ChatSnowflakeCortex
chat model (#21490)
**Description:** This PR adds a chat model integration for [Snowflake Cortex](https://docs.snowflake.com/en/user-guide/snowflake-cortex/llm-functions), which gives an instant access to industry-leading large language models (LLMs) trained by researchers at companies like Mistral, Reka, Meta, and Google, including [Snowflake Arctic](https://www.snowflake.com/en/data-cloud/arctic/), an open enterprise-grade model developed by Snowflake. **Dependencies:** Snowflake's [snowpark](https://pypi.org/project/snowflake-snowpark-python/) library is required for using this integration. **Twitter handle:** [@gethouseware](https://twitter.com/gethouseware) - [x] **Add tests and docs**: 1. integration tests: `libs/community/tests/integration_tests/chat_models/test_snowflake.py` 2. unit tests: `libs/community/tests/unit_tests/chat_models/test_snowflake.py` 3. example notebook: `docs/docs/integrations/chat/snowflake.ipynb` - [x] **Lint and test**: Run `make format`, `make lint` and `make test` from the root of the package(s) you've modified. See contribution guidelines for more: https://python.langchain.com/docs/contributing/
This commit is contained in:
@@ -140,6 +140,9 @@ if TYPE_CHECKING:
|
||||
from langchain_community.chat_models.promptlayer_openai import (
|
||||
PromptLayerChatOpenAI,
|
||||
)
|
||||
from langchain_community.chat_models.snowflake import (
|
||||
ChatSnowflakeCortex,
|
||||
)
|
||||
from langchain_community.chat_models.solar import (
|
||||
SolarChat,
|
||||
)
|
||||
@@ -196,6 +199,7 @@ __all__ = [
|
||||
"ChatPerplexity",
|
||||
"ChatPremAI",
|
||||
"ChatSparkLLM",
|
||||
"ChatSnowflakeCortex",
|
||||
"ChatTongyi",
|
||||
"ChatVertexAI",
|
||||
"ChatYandexGPT",
|
||||
@@ -247,6 +251,7 @@ _module_lookup = {
|
||||
"ChatOllama": "langchain_community.chat_models.ollama",
|
||||
"ChatOpenAI": "langchain_community.chat_models.openai",
|
||||
"ChatPerplexity": "langchain_community.chat_models.perplexity",
|
||||
"ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
|
||||
"ChatSparkLLM": "langchain_community.chat_models.sparkllm",
|
||||
"ChatTongyi": "langchain_community.chat_models.tongyi",
|
||||
"ChatVertexAI": "langchain_community.chat_models.vertexai",
|
||||
|
232
libs/community/langchain_community/chat_models/snowflake.py
Normal file
232
libs/community/langchain_community/chat_models/snowflake.py
Normal file
@@ -0,0 +1,232 @@
|
||||
import json
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
ChatMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatResult
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
from langchain_core.utils.utils import build_extra_kwargs
|
||||
|
||||
SUPPORTED_ROLES: List[str] = [
|
||||
"system",
|
||||
"user",
|
||||
"assistant",
|
||||
]
|
||||
|
||||
|
||||
class ChatSnowflakeCortexError(Exception):
|
||||
"""Error with Snowpark client."""
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
"""Convert a LangChain message to a dictionary.
|
||||
|
||||
Args:
|
||||
message: The LangChain message.
|
||||
|
||||
Returns:
|
||||
The dictionary.
|
||||
"""
|
||||
message_dict: Dict[str, Any] = {
|
||||
"content": message.content,
|
||||
}
|
||||
|
||||
# populate role and additional message data
|
||||
if isinstance(message, ChatMessage) and message.role in SUPPORTED_ROLES:
|
||||
message_dict["role"] = message.role
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict["role"] = "system"
|
||||
elif isinstance(message, HumanMessage):
|
||||
message_dict["role"] = "user"
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict["role"] = "assistant"
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
|
||||
def _truncate_at_stop_tokens(
|
||||
text: str,
|
||||
stop: Optional[List[str]],
|
||||
) -> str:
|
||||
"""Truncates text at the earliest stop token found."""
|
||||
if stop is None:
|
||||
return text
|
||||
|
||||
for stop_token in stop:
|
||||
stop_token_idx = text.find(stop_token)
|
||||
if stop_token_idx != -1:
|
||||
text = text[:stop_token_idx]
|
||||
return text
|
||||
|
||||
|
||||
class ChatSnowflakeCortex(BaseChatModel):
|
||||
"""Snowflake Cortex based Chat model
|
||||
|
||||
To use you must have the ``snowflake-snowpark-python`` Python package installed and
|
||||
either:
|
||||
|
||||
1. environment variables set with your snowflake credentials or
|
||||
2. directly passed in as kwargs to the ChatSnowflakeCortex constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import ChatSnowflakeCortex
|
||||
chat = ChatSnowflakeCortex()
|
||||
"""
|
||||
|
||||
_sp_session: Any = None
|
||||
"""Snowpark session object."""
|
||||
|
||||
model: str = "snowflake-arctic"
|
||||
"""Snowflake cortex hosted LLM model name, defaulted to `snowflake-arctic`.
|
||||
Refer to docs for more options."""
|
||||
|
||||
cortex_function: str = "complete"
|
||||
"""Cortex function to use, defaulted to `complete`.
|
||||
Refer to docs for more options."""
|
||||
|
||||
temperature: float = 0.7
|
||||
"""Model temperature. Value should be >= 0 and <= 1.0"""
|
||||
|
||||
max_tokens: Optional[int] = None
|
||||
"""The maximum number of output tokens in the response."""
|
||||
|
||||
top_p: Optional[float] = None
|
||||
"""top_p adjusts the number of choices for each predicted tokens based on
|
||||
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
|
||||
"""
|
||||
|
||||
snowflake_username: Optional[str] = Field(default=None, alias="username")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_USERNAME` if not provided."""
|
||||
snowflake_password: Optional[SecretStr] = Field(default=None, alias="password")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_PASSWORD` if not provided."""
|
||||
snowflake_account: Optional[str] = Field(default=None, alias="account")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_ACCOUNT` if not provided."""
|
||||
snowflake_database: Optional[str] = Field(default=None, alias="database")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_DATABASE` if not provided."""
|
||||
snowflake_schema: Optional[str] = Field(default=None, alias="schema")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_SCHEMA` if not provided."""
|
||||
snowflake_warehouse: Optional[str] = Field(default=None, alias="warehouse")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_WAREHOUSE` if not provided."""
|
||||
snowflake_role: Optional[str] = Field(default=None, alias="role")
|
||||
"""Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""
|
||||
|
||||
@root_validator(pre=True)
|
||||
def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
|
||||
"""Build extra kwargs from additional params that were passed in."""
|
||||
all_required_field_names = get_pydantic_field_names(cls)
|
||||
extra = values.get("model_kwargs", {})
|
||||
values["model_kwargs"] = build_extra_kwargs(
|
||||
extra, values, all_required_field_names
|
||||
)
|
||||
return values
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
try:
|
||||
from snowflake.snowpark import Session
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"`snowflake-snowpark-python` package not found, please install it with "
|
||||
"`pip install snowflake-snowpark-python`"
|
||||
)
|
||||
|
||||
values["snowflake_username"] = get_from_dict_or_env(
|
||||
values, "snowflake_username", "SNOWFLAKE_USERNAME"
|
||||
)
|
||||
values["snowflake_password"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "snowflake_password", "SNOWFLAKE_PASSWORD")
|
||||
)
|
||||
values["snowflake_account"] = get_from_dict_or_env(
|
||||
values, "snowflake_account", "SNOWFLAKE_ACCOUNT"
|
||||
)
|
||||
values["snowflake_database"] = get_from_dict_or_env(
|
||||
values, "snowflake_database", "SNOWFLAKE_DATABASE"
|
||||
)
|
||||
values["snowflake_schema"] = get_from_dict_or_env(
|
||||
values, "snowflake_schema", "SNOWFLAKE_SCHEMA"
|
||||
)
|
||||
values["snowflake_warehouse"] = get_from_dict_or_env(
|
||||
values, "snowflake_warehouse", "SNOWFLAKE_WAREHOUSE"
|
||||
)
|
||||
values["snowflake_role"] = get_from_dict_or_env(
|
||||
values, "snowflake_role", "SNOWFLAKE_ROLE"
|
||||
)
|
||||
|
||||
connection_params = {
|
||||
"account": values["snowflake_account"],
|
||||
"user": values["snowflake_username"],
|
||||
"password": values["snowflake_password"].get_secret_value(),
|
||||
"database": values["snowflake_database"],
|
||||
"schema": values["snowflake_schema"],
|
||||
"warehouse": values["snowflake_warehouse"],
|
||||
"role": values["snowflake_role"],
|
||||
}
|
||||
|
||||
try:
|
||||
values["_sp_session"] = Session.builder.configs(connection_params).create()
|
||||
except Exception as e:
|
||||
raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
|
||||
|
||||
return values
|
||||
|
||||
def __del__(self) -> None:
|
||||
if getattr(self, "_sp_session", None) is not None:
|
||||
self._sp_session.close()
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Get the type of language model used by this chat model."""
|
||||
return f"snowflake-cortex-{self.model}"
|
||||
|
||||
def _generate(
|
||||
self,
|
||||
messages: List[BaseMessage],
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResult:
|
||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||
message_str = str(message_dicts)
|
||||
options = {"temperature": self.temperature}
|
||||
if self.top_p is not None:
|
||||
options["top_p"] = self.top_p
|
||||
if self.max_tokens is not None:
|
||||
options["max_tokens"] = self.max_tokens
|
||||
options_str = str(options)
|
||||
sql_stmt = f"""
|
||||
select snowflake.cortex.{self.cortex_function}(
|
||||
'{self.model}'
|
||||
,{message_str},{options_str}) as llm_response;"""
|
||||
|
||||
try:
|
||||
l_rows = self._sp_session.sql(sql_stmt).collect()
|
||||
except Exception as e:
|
||||
raise ChatSnowflakeCortexError(
|
||||
f"Error while making request to Snowflake Cortex via Snowpark: {e}"
|
||||
)
|
||||
|
||||
response = json.loads(l_rows[0]["LLM_RESPONSE"])
|
||||
ai_message_content = response["choices"][0]["messages"]
|
||||
|
||||
content = _truncate_at_stop_tokens(ai_message_content, stop)
|
||||
message = AIMessage(
|
||||
content=content,
|
||||
response_metadata=response["usage"],
|
||||
)
|
||||
generation = ChatGeneration(message=message)
|
||||
return ChatResult(generations=[generation])
|
@@ -0,0 +1,59 @@
|
||||
"""Test ChatSnowflakeCortex
|
||||
Note: This test must be run with the following environment variables set:
|
||||
SNOWFLAKE_ACCOUNT="YOUR_SNOWFLAKE_ACCOUNT",
|
||||
SNOWFLAKE_USERNAME="YOUR_SNOWFLAKE_USERNAME",
|
||||
SNOWFLAKE_PASSWORD="YOUR_SNOWFLAKE_PASSWORD",
|
||||
SNOWFLAKE_DATABASE="YOUR_SNOWFLAKE_DATABASE",
|
||||
SNOWFLAKE_SCHEMA="YOUR_SNOWFLAKE_SCHEMA",
|
||||
SNOWFLAKE_WAREHOUSE="YOUR_SNOWFLAKE_WAREHOUSE"
|
||||
SNOWFLAKE_ROLE="YOUR_SNOWFLAKE_ROLE",
|
||||
"""
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
from langchain_core.outputs import ChatGeneration, LLMResult
|
||||
|
||||
from langchain_community.chat_models import ChatSnowflakeCortex
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def chat() -> ChatSnowflakeCortex:
|
||||
return ChatSnowflakeCortex()
|
||||
|
||||
|
||||
def test_chat_snowflake_cortex(chat: ChatSnowflakeCortex) -> None:
|
||||
"""Test ChatSnowflakeCortex."""
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat([message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_snowflake_cortex_system_message(chat: ChatSnowflakeCortex) -> None:
|
||||
"""Test ChatSnowflakeCortex for system message"""
|
||||
system_message = SystemMessage(content="You are to chat with the user.")
|
||||
human_message = HumanMessage(content="Hello")
|
||||
response = chat([system_message, human_message])
|
||||
assert isinstance(response, BaseMessage)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_snowflake_cortex_model() -> None:
|
||||
"""Test ChatSnowflakeCortex handles model_name."""
|
||||
chat = ChatSnowflakeCortex(
|
||||
model="foo",
|
||||
)
|
||||
assert chat.model == "foo"
|
||||
|
||||
|
||||
def test_chat_snowflake_cortex_generate(chat: ChatSnowflakeCortex) -> None:
|
||||
"""Test ChatSnowflakeCortex with generate."""
|
||||
message = HumanMessage(content="Hello")
|
||||
response = chat.generate([[message], [message]])
|
||||
assert isinstance(response, LLMResult)
|
||||
assert len(response.generations) == 2
|
||||
for generations in response.generations:
|
||||
for generation in generations:
|
||||
assert isinstance(generation, ChatGeneration)
|
||||
assert isinstance(generation.text, str)
|
||||
assert generation.text == generation.message.content
|
@@ -51,6 +51,7 @@ EXPECTED_ALL = [
|
||||
"QianfanChatEndpoint",
|
||||
"VolcEngineMaasChat",
|
||||
"ChatOctoAI",
|
||||
"ChatSnowflakeCortex",
|
||||
]
|
||||
|
||||
|
||||
|
@@ -0,0 +1,24 @@
|
||||
"""Test ChatSnowflakeCortex."""
|
||||
|
||||
from langchain_core.messages import AIMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_community.chat_models.snowflake import _convert_message_to_dict
|
||||
|
||||
|
||||
def test_messages_to_prompt_dict_with_valid_messages() -> None:
|
||||
messages = [
|
||||
SystemMessage(content="System Prompt"),
|
||||
HumanMessage(content="User message #1"),
|
||||
AIMessage(content="AI message #1"),
|
||||
HumanMessage(content="User message #2"),
|
||||
AIMessage(content="AI message #2"),
|
||||
]
|
||||
result = [_convert_message_to_dict(m) for m in messages]
|
||||
expected = [
|
||||
{"role": "system", "content": "System Prompt"},
|
||||
{"role": "user", "content": "User message #1"},
|
||||
{"role": "assistant", "content": "AI message #1"},
|
||||
{"role": "user", "content": "User message #2"},
|
||||
{"role": "assistant", "content": "AI message #2"},
|
||||
]
|
||||
assert result == expected
|
Reference in New Issue
Block a user