community: ChatSnowflakeCortex - Add streaming functionality (#27753)

Description:
snowflake.py
Add _stream and _stream_content methods to enable streaming
functionality
fix pydantic issues and added functionality with the overall langchain
version upgrade
added bind_tools method for agentic workflows support through langgraph
updated the _generate method to account for agentic workflows support
through langgraph
cosmetic changes to comments and if conditions

snowflake.ipynb
Added _stream example
cosmetic changes to comments
fixed lint errors

check_pydantic.sh
Decreased counter from 126 to 125 as suggested when formatting

---------

Co-authored-by: Prathamesh Nimkar <prathamesh.nimkar@snowflake.com>
Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Prathamesh Nimkar 2024-12-11 21:35:40 -05:00 committed by GitHub
parent d834c6b618
commit ca054ed1b1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 223 additions and 69 deletions

View File

@ -22,24 +22,16 @@
},
{
"cell_type": "code",
"execution_count": 1,
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"outputs": [],
"source": [
"%pip install --upgrade --quiet snowflake-snowpark-python"
]
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
@ -73,14 +65,14 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models import ChatSnowflakeCortex\n",
"from langchain_core.messages import HumanMessage, SystemMessage\n",
"\n",
"# By default, we'll be using the cortex provided model: `snowflake-arctic`, with function: `complete`\n",
"# By default, we'll be using the cortex provided model: `mistral-large`, with function: `complete`\n",
"chat = ChatSnowflakeCortex()"
]
},
@ -92,16 +84,16 @@
"\n",
"```python\n",
"chat = ChatSnowflakeCortex(\n",
" # change default cortex model and function\n",
" model=\"snowflake-arctic\",\n",
" # Change the default cortex model and function\n",
" model=\"mistral-large\",\n",
" cortex_function=\"complete\",\n",
"\n",
" # change default generation parameters\n",
" # Change the default generation parameters\n",
" temperature=0,\n",
" max_tokens=10,\n",
" top_p=0.95,\n",
"\n",
" # specify snowflake credentials\n",
" # Specify your Snowflake Credentials\n",
" account=\"YOUR_SNOWFLAKE_ACCOUNT\",\n",
" username=\"YOUR_SNOWFLAKE_USERNAME\",\n",
" password=\"YOUR_SNOWFLAKE_PASSWORD\",\n",
@ -117,28 +109,13 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Calling the model\n",
"We can now call the model using the `invoke` or `generate` method.\n",
"\n",
"#### Generation"
"### Calling the chat model\n",
"We can now call the chat model using the `invoke` or `stream` methods."
]
},
{
"cell_type": "code",
"execution_count": 9,
"cell_type": "markdown",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\" Large language models are artificial intelligence systems designed to understand, generate, and manipulate human language. These models are typically based on deep learning techniques and are trained on vast amounts of text data to learn patterns and structures in language. They can perform a wide range of language-related tasks, such as language translation, text generation, sentiment analysis, and answering questions. Some well-known large language models include Google's BERT, OpenAI's GPT series, and Facebook's RoBERTa. These models have shown remarkable performance in various natural language processing tasks, and their applications continue to expand as research in AI progresses.\", response_metadata={'completion_tokens': 131, 'prompt_tokens': 29, 'total_tokens': 160}, id='run-5435bd0a-83fd-4295-b237-66cbd1b5c0f3-0')"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" SystemMessage(content=\"You are a friendly assistant.\"),\n",
@ -151,14 +128,31 @@
"cell_type": "markdown",
"metadata": {},
"source": [
"### Streaming\n",
"`ChatSnowflakeCortex` doesn't support streaming as of now. Support for streaming will be coming in the later versions!"
"### Stream"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Sample input prompt\n",
"messages = [\n",
" SystemMessage(content=\"You are a friendly assistant.\"),\n",
" HumanMessage(content=\"What are large language models?\"),\n",
"]\n",
"\n",
"# Invoke the stream method and print each chunk as it arrives\n",
"print(\"Stream Method Response:\")\n",
"for chunk in chat._stream(messages):\n",
" print(chunk.message.content)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"display_name": "langchain",
"language": "python",
"name": "python3"
},
@ -172,7 +166,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.9"
"version": "3.9.20"
}
},
"nbformat": 4,

View File

@ -1,22 +1,35 @@
import json
from typing import Any, Dict, List, Optional
from typing import (
Any,
Callable,
Dict,
Iterator,
List,
Literal,
Optional,
Sequence,
Type,
Union,
)
from langchain_core.callbacks.manager import CallbackManagerForLLMRun
from langchain_core.language_models import BaseChatModel
from langchain_core.messages import (
AIMessage,
AIMessageChunk,
BaseMessage,
ChatMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.tools import BaseTool
from langchain_core.utils import (
convert_to_secret_str,
get_from_dict_or_env,
get_pydantic_field_names,
pre_init,
)
from langchain_core.utils.function_calling import convert_to_openai_tool
from langchain_core.utils.utils import _build_model_kwargs
from pydantic import Field, SecretStr, model_validator
@ -44,7 +57,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
"content": message.content,
}
# populate role and additional message data
# Populate role and additional message data
if isinstance(message, ChatMessage) and message.role in SUPPORTED_ROLES:
message_dict["role"] = message.role
elif isinstance(message, SystemMessage):
@ -76,8 +89,8 @@ def _truncate_at_stop_tokens(
class ChatSnowflakeCortex(BaseChatModel):
"""Snowflake Cortex based Chat model
To use you must have the ``snowflake-snowpark-python`` Python package installed and
either:
To use the chat model, you must have the ``snowflake-snowpark-python`` Python
package installed and either:
1. environment variables set with your snowflake credentials or
2. directly passed in as kwargs to the ChatSnowflakeCortex constructor.
@ -89,24 +102,30 @@ class ChatSnowflakeCortex(BaseChatModel):
chat = ChatSnowflakeCortex()
"""
_sp_session: Any = None
# test_tools: Dict[str, Any] = Field(default_factory=dict)
test_tools: Dict[str, Union[Dict[str, Any], Type, Callable, BaseTool]] = Field(
default_factory=dict
)
session: Any = None
"""Snowpark session object."""
model: str = "snowflake-arctic"
"""Snowflake cortex hosted LLM model name, defaulted to `snowflake-arctic`.
Refer to docs for more options."""
model: str = "mistral-large"
"""Snowflake cortex hosted LLM model name, defaulted to `mistral-large`.
Refer to docs for more options. Also note, not all models support
agentic workflows."""
cortex_function: str = "complete"
"""Cortex function to use, defaulted to `complete`.
Refer to docs for more options."""
temperature: float = 0.7
temperature: float = 0
"""Model temperature. Value should be >= 0 and <= 1.0"""
max_tokens: Optional[int] = None
"""The maximum number of output tokens in the response."""
top_p: Optional[float] = None
top_p: Optional[float] = 0
"""top_p adjusts the number of choices for each predicted tokens based on
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
"""
@ -126,6 +145,27 @@ class ChatSnowflakeCortex(BaseChatModel):
snowflake_role: Optional[str] = Field(default=None, alias="role")
"""Automatically inferred from env var `SNOWFLAKE_ROLE` if not provided."""
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type, Callable, BaseTool]],
*,
tool_choice: Optional[
Union[dict, str, Literal["auto", "any", "none"], bool]
] = "auto",
**kwargs: Any,
) -> "ChatSnowflakeCortex":
"""Bind tool-like objects to this chat model, ensuring they conform to
expected formats."""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
# self.test_tools.update(formatted_tools)
formatted_tools_dict = {
tool["name"]: tool for tool in formatted_tools if "name" in tool
}
self.test_tools.update(formatted_tools_dict)
return self
@model_validator(mode="before")
@classmethod
def build_extra(cls, values: Dict[str, Any]) -> Any:
@ -134,14 +174,15 @@ class ChatSnowflakeCortex(BaseChatModel):
values = _build_model_kwargs(values, all_required_field_names)
return values
@pre_init
@model_validator(mode="before")
def validate_environment(cls, values: Dict) -> Dict:
try:
from snowflake.snowpark import Session
except ImportError:
raise ImportError(
"`snowflake-snowpark-python` package not found, please install it with "
"`pip install snowflake-snowpark-python`"
"""`snowflake-snowpark-python` package not found, please install:
`pip install snowflake-snowpark-python`
"""
)
values["snowflake_username"] = get_from_dict_or_env(
@ -174,18 +215,19 @@ class ChatSnowflakeCortex(BaseChatModel):
"schema": values["snowflake_schema"],
"warehouse": values["snowflake_warehouse"],
"role": values["snowflake_role"],
"client_session_keep_alive": "True",
}
try:
values["_sp_session"] = Session.builder.configs(connection_params).create()
values["session"] = Session.builder.configs(connection_params).create()
except Exception as e:
raise ChatSnowflakeCortexError(f"Failed to create session: {e}")
return values
def __del__(self) -> None:
if getattr(self, "_sp_session", None) is not None:
self._sp_session.close()
if getattr(self, "session", None) is not None:
self.session.close()
@property
def _llm_type(self) -> str:
@ -200,23 +242,55 @@ class ChatSnowflakeCortex(BaseChatModel):
**kwargs: Any,
) -> ChatResult:
message_dicts = [_convert_message_to_dict(m) for m in messages]
message_str = str(message_dicts)
options = {"temperature": self.temperature}
if self.top_p is not None:
options["top_p"] = self.top_p
if self.max_tokens is not None:
options["max_tokens"] = self.max_tokens
options_str = str(options)
# Check for tool invocation in the messages and prepare for tool use
tool_output = None
for message in messages:
if (
isinstance(message.content, dict)
and isinstance(message, SystemMessage)
and "invoke_tool" in message.content
):
tool_info = json.loads(message.content.get("invoke_tool"))
tool_name = tool_info.get("tool_name")
if tool_name in self.test_tools:
tool_args = tool_info.get("args", {})
tool_output = self.test_tools[tool_name](**tool_args)
break
# Prepare messages for SQL query
if tool_output:
message_dicts.append(
{"tool_output": str(tool_output)}
) # Ensure tool_output is a string
# JSON dump the message_dicts and options without additional escaping
message_json = json.dumps(message_dicts)
options = {
"temperature": self.temperature,
"top_p": self.top_p if self.top_p is not None else 1.0,
"max_tokens": self.max_tokens if self.max_tokens is not None else 2048,
}
options_json = json.dumps(options) # JSON string of options
# Form the SQL statement using JSON literals
sql_stmt = f"""
select snowflake.cortex.{self.cortex_function}(
'{self.model}'
,{message_str},{options_str}) as llm_response;"""
'{self.model}',
parse_json('{message_json}'),
parse_json('{options_json}')
) as llm_response;
"""
try:
l_rows = self._sp_session.sql(sql_stmt).collect()
# Use the Snowflake Cortex Complete function
self.session.sql(
f"USE WAREHOUSE {self.session.get_current_warehouse()};"
).collect()
l_rows = self.session.sql(sql_stmt).collect()
except Exception as e:
raise ChatSnowflakeCortexError(
f"Error while making request to Snowflake Cortex via Snowpark: {e}"
f"Error while making request to Snowflake Cortex: {e}"
)
response = json.loads(l_rows[0]["LLM_RESPONSE"])
@ -229,3 +303,89 @@ class ChatSnowflakeCortex(BaseChatModel):
)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
def _stream_content(
self, content: str, stop: Optional[List[str]]
) -> Iterator[ChatGenerationChunk]:
"""
Stream the output of the model in chunks to return ChatGenerationChunk.
"""
chunk_size = 50 # Define a reasonable chunk size for streaming
truncated_content = _truncate_at_stop_tokens(content, stop)
for i in range(0, len(truncated_content), chunk_size):
chunk_content = truncated_content[i : i + chunk_size]
# Create and yield a ChatGenerationChunk with partial content
yield ChatGenerationChunk(message=AIMessageChunk(content=chunk_content))
def _stream(
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
"""Stream the output of the model in chunks to return ChatGenerationChunk."""
message_dicts = [_convert_message_to_dict(m) for m in messages]
# Check for and potentially use a tool before streaming
for message in messages:
if (
isinstance(message, str)
and isinstance(message, SystemMessage)
and "invoke_tool" in message.content
):
tool_info = json.loads(message.content)
tool_list = tool_info.get("invoke_tools", [])
for tool in tool_list:
tool_name = tool.get("tool_name")
tool_args = tool.get("args", {})
if tool_name in self.test_tools:
tool_args = tool_info.get("args", {})
tool_result = self.test_tools[tool_name](**tool_args)
additional_context = {"tool_output": tool_result}
message_dicts.append(
additional_context
) # Append tool result to message dicts
# JSON dump the message_dicts and options without additional escaping
message_json = json.dumps(message_dicts)
options = {
"temperature": self.temperature,
"top_p": self.top_p if self.top_p is not None else 1.0,
"max_tokens": self.max_tokens if self.max_tokens is not None else 2048,
# "stream": True,
}
options_json = json.dumps(options) # JSON string of options
# Form the SQL statement using JSON literals
sql_stmt = f"""
select snowflake.cortex.{self.cortex_function}(
'{self.model}',
parse_json('{message_json}'),
parse_json('{options_json}')
) as llm_stream_response;
"""
try:
# Use the Snowflake Cortex Complete function
self.session.sql(
f"USE WAREHOUSE {self.session.get_current_warehouse()};"
).collect()
result = self.session.sql(sql_stmt).collect()
# Iterate over the generator to yield streaming responses
for row in result:
response = json.loads(row["LLM_STREAM_RESPONSE"])
ai_message_content = response["choices"][0]["messages"]
# Stream response content in chunks
for chunk in self._stream_content(ai_message_content, stop):
yield chunk
except Exception as e:
raise ChatSnowflakeCortexError(
f"Error while making request to Snowflake Cortex stream: {e}"
)

View File

@ -20,7 +20,7 @@ count=$(git grep -E '(@root_validator)|(@validator)|(@field_validator)|(@pre_ini
# PRs that increase the current count will not be accepted.
# PRs that decrease update the code in the repository
# and allow decreasing the count of are welcome!
current_count=125
current_count=124
if [ "$count" -gt "$current_count" ]; then
echo "The PR seems to be introducing new usage of @root_validator and/or @field_validator."