community: Add Sambanova Cloud Chat model community integration (#26333)

**Description:** Add SambaNova Cloud chat model community integration.
Includes:
- chat model integration (following the Standardize ChatModel docstrings guidelines)
- tests
- docs usage notebook (following the Standardize ChatModel integration docs guidelines)

https://cloud.sambanova.ai/
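
A minimal usage sketch, mirroring the docs notebook below (it assumes `SAMBANOVA_API_KEY` is set in your environment):

```python
import getpass
import os

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud

# Prompt for the key if it is not already exported.
if not os.getenv("SAMBANOVA_API_KEY"):
    os.environ["SAMBANOVA_API_KEY"] = getpass.getpass("SambaNova Cloud API key: ")

llm = ChatSambaNovaCloud(model="llama3-405b", max_tokens=1024, temperature=0.7)
print(llm.invoke("I love programming.").content)
```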

---------

Co-authored-by: luisfucros <luisfucros@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
Authored by Jorge Piedrahita Ortiz on 2024-09-24 09:11:32 -05:00; committed by GitHub
parent 2b83c7c3ab
commit 408a930d55
5 changed files with 856 additions and 0 deletions


@@ -0,0 +1,374 @@
{
"cells": [
{
"cell_type": "raw",
"metadata": {
"vscode": {
"languageId": "raw"
}
},
"source": [
"---\n",
"sidebar_label: SambaNovaCloud\n",
"---"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# ChatSambaNovaCloud\n",
"\n",
"This will help you getting started with SambaNovaCloud [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatSambaNovaCloud features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html).\n",
"\n",
"**[SambaNova](https://sambanova.ai/)'s** [SambaNova Cloud](https://cloud.sambanova.ai/) is a platform for performing inference with open-source models\n",
"\n",
"## Overview\n",
"### Integration details\n",
"\n",
"| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
"| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
"| [ChatSambaNovaCloud](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html) | [langchain-community](https://python.langchain.com/v0.2/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n",
"\n",
"### Model features\n",
"\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
"To access ChatSambaNovaCloud models you will need to create a [SambaNovaCloud](https://cloud.sambanova.ai/) account, get an API key, install the `langchain_community` integration package, and install the `SSEClient` Package.\n",
"\n",
"```bash\n",
"pip install langchain-community\n",
"pip install sseclient-py\n",
"```\n",
"\n",
"### Credentials\n",
"\n",
"Get an API Key from [cloud.sambanova.ai](https://cloud.sambanova.ai/apis) and add it to your environment variables:\n",
"\n",
"``` bash\n",
"export SAMBANOVA_API_KEY=\"your-api-key-here\"\n",
"```"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import getpass\n",
"import os\n",
"\n",
"if not os.getenv(\"SAMBANOVA_API_KEY\"):\n",
" os.environ[\"SAMBANOVA_API_KEY\"] = getpass.getpass(\n",
" \"Enter your SambaNova Cloud API key: \"\n",
" )"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
"# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Installation\n",
"\n",
"The LangChain __SambaNovaCloud__ integration lives in the `langchain_community` package:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"%pip install -qU langchain-community\n",
"%pip install -qu sseclient-py"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Instantiation\n",
"\n",
"Now we can instantiate our model object and generate chat completions:"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.sambanova import ChatSambaNovaCloud\n",
"\n",
"llm = ChatSambaNovaCloud(\n",
" model=\"llama3-405b\", max_tokens=1024, temperature=0.7, top_k=1, top_p=0.01\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Invocation"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content=\"J'adore la programmation.\", response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 11, 'completion_tokens': 9, 'completion_tokens_after_first_per_sec': 97.07042823956884, 'completion_tokens_after_first_per_sec_first_ten': 276.3343994441849, 'completion_tokens_per_sec': 23.775192800224037, 'end_time': 1726158364.7954874, 'is_last_response': True, 'prompt_tokens': 56, 'start_time': 1726158364.3670964, 'time_to_first_token': 0.3459765911102295, 'total_latency': 0.3785458261316473, 'total_tokens': 65, 'total_tokens_per_sec': 171.70972577939582}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726158364}, id='7154b676-9d5a-4b1a-a425-73bbe69f28fc')"
]
},
"execution_count": 3,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"messages = [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
" ),\n",
" (\"human\", \"I love programming.\"),\n",
"]\n",
"ai_msg = llm.invoke(messages)\n",
"ai_msg"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"J'adore la programmation.\n"
]
}
],
"source": [
"print(ai_msg.content)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Chaining\n",
"\n",
"We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Ich liebe Programmieren.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 11, 'completion_tokens': 6, 'completion_tokens_after_first_per_sec': 47.80258530102961, 'completion_tokens_after_first_per_sec_first_ten': 215.59002827036753, 'completion_tokens_per_sec': 5.263977583489829, 'end_time': 1726158506.3777263, 'is_last_response': True, 'prompt_tokens': 51, 'start_time': 1726158505.1611376, 'time_to_first_token': 1.1119918823242188, 'total_latency': 1.1398224830627441, 'total_tokens': 57, 'total_tokens_per_sec': 50.00778704315337}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726158505}, id='226471ac-8c52-44bb-baa7-f9d2f8c54477')"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"from langchain_core.prompts import ChatPromptTemplate\n",
"\n",
"prompt = ChatPromptTemplate(\n",
" [\n",
" (\n",
" \"system\",\n",
" \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
" ),\n",
" (\"human\", \"{input}\"),\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"chain.invoke(\n",
" {\n",
" \"input_language\": \"English\",\n",
" \"output_language\": \"German\",\n",
" \"input\": \"I love programming.\",\n",
" }\n",
")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Streaming"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Yer lookin' fer some info on owls, eh? Alright then, matey, settle yerself down with a pint o' grog and listen close.\n",
"\n",
"Owls be nocturnal birds o' prey, meanin' they do most o' their huntin' at night. They got big, round eyes that be perfect fer seein' in the dark, like a trusty lantern on a dark sea. Their ears be sharp as a cutlass, too, helpin' 'em pinpoint the slightest sound o' a scurvy rodent scurryin' through the underbrush.\n",
"\n",
"These birds be known fer their silent flight, like a ghost ship sailin' through the night. Their feathers be special, with a soft, fringed edge that helps 'em sneak up on their prey. And when they strike, it be swift and deadly, like a pirate's sword.\n",
"\n",
"Owls be found all over the world, from the frozen tundras o' the north to the scorching deserts o' the south. They come in all shapes and sizes, from the tiny elf owl to the great grey owl, which be as big as a small dog.\n",
"\n",
"Now, I know what ye be thinkin', \"Pirate, what about their hootin'?\" Aye, owls be famous fer their hoots, which be a form o' communication. They use different hoots to warn off predators, attract a mate, or even just to say, \"Shiver me timbers, I be happy to be alive!\"\n",
"\n",
"So there ye have it, me hearty. Owls be fascinatin' creatures, and I hope ye found this info as interestin' as a chest overflowin' with gold doubloons. Fair winds and following seas!"
]
}
],
"source": [
"system = \"You are a helpful assistant with pirate accent.\"\n",
"human = \"I want to learn more about this animal: {animal}\"\n",
"prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
"\n",
"chain = prompt | llm\n",
"\n",
"for chunk in chain.stream({\"animal\": \"owl\"}):\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Async"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='The capital of France is Paris.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 13, 'completion_tokens': 8, 'completion_tokens_after_first_per_sec': 86.00726488715989, 'completion_tokens_after_first_per_sec_first_ten': 326.92555640828857, 'completion_tokens_per_sec': 21.74539360394493, 'end_time': 1726159287.9987085, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726159287.5738964, 'time_to_first_token': 0.34342360496520996, 'total_latency': 0.36789400760944074, 'total_tokens': 51, 'total_tokens_per_sec': 138.62688422514893}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726159287}, id='9b4ef015-50a2-434b-b980-29f8aa90c3e8')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"human\",\n",
" \"what is the capital of {country}?\",\n",
" )\n",
" ]\n",
")\n",
"\n",
"chain = prompt | llm\n",
"await chain.ainvoke({\"country\": \"France\"})"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Async Streaming"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Quantum computers use quantum bits (qubits) to process vast amounts of data simultaneously, leveraging quantum mechanics to solve complex problems exponentially faster than classical computers."
]
}
],
"source": [
"prompt = ChatPromptTemplate.from_messages(\n",
" [\n",
" (\n",
" \"human\",\n",
" \"in less than {num_words} words explain me {topic} \",\n",
" )\n",
" ]\n",
")\n",
"chain = prompt | llm\n",
"\n",
"async for chunk in chain.astream({\"num_words\": 30, \"topic\": \"quantum computers\"}):\n",
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## API reference\n",
"\n",
"For detailed documentation of all ChatSambaNovaCloud features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.19"
}
},
"nbformat": 4,
"nbformat_minor": 2
}


@@ -147,6 +147,9 @@ if TYPE_CHECKING:
    from langchain_community.chat_models.promptlayer_openai import (
        PromptLayerChatOpenAI,
    )
    from langchain_community.chat_models.sambanova import (
        ChatSambaNovaCloud,
    )
    from langchain_community.chat_models.snowflake import (
        ChatSnowflakeCortex,
    )
@@ -211,6 +214,7 @@ __all__ = [
    "ChatOpenAI",
    "ChatPerplexity",
    "ChatPremAI",
    "ChatSambaNovaCloud",
    "ChatSparkLLM",
    "ChatSnowflakeCortex",
    "ChatTongyi",
@@ -269,6 +273,7 @@ _module_lookup = {
    "ChatOllama": "langchain_community.chat_models.ollama",
    "ChatOpenAI": "langchain_community.chat_models.openai",
    "ChatPerplexity": "langchain_community.chat_models.perplexity",
    "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova",
    "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
    "ChatSparkLLM": "langchain_community.chat_models.sparkllm",
    "ChatTongyi": "langchain_community.chat_models.tongyi",


@@ -0,0 +1,465 @@
import json
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr


class ChatSambaNovaCloud(BaseChatModel):
    """
    SambaNova Cloud chat model.

    Setup:
        To use, you should have the environment variables:
        ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
        ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API key.
        https://cloud.sambanova.ai/

    Example:
        .. code-block:: python

            ChatSambaNovaCloud(
                sambanova_url = SambaNova Cloud endpoint URL,
                sambanova_api_key = your SambaNova Cloud API key,
                model = model name,
                streaming = set True to use the streaming API,
                max_tokens = max number of tokens to generate,
                temperature = model temperature,
                top_p = model top p,
                top_k = model top k,
                stream_options = include usage to get generation metrics,
            )

    Key init args — completion params:
        model: str
            The name of the model to use, e.g., llama3-8b.
        streaming: bool
            Whether to use streaming.
        max_tokens: int
            Max tokens to generate.
        temperature: float
            Model temperature.
        top_p: float
            Model top p.
        top_k: int
            Model top k.
        stream_options: dict
            Stream options; include usage to get generation metrics.

    Key init args — client params:
        sambanova_url: str
            SambaNova Cloud URL.
        sambanova_api_key: str
            SambaNova Cloud API key.

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_models import ChatSambaNovaCloud

            chat = ChatSambaNovaCloud(
                sambanova_url = SambaNova Cloud endpoint URL,
                sambanova_api_key = your SambaNova Cloud API key,
                model = model name,
                streaming = set True for streaming,
                max_tokens = max number of tokens to generate,
                temperature = model temperature,
                top_p = model top p,
                top_k = model top k,
                stream_options = include usage to get generation metrics,
            )

    Invoke:
        .. code-block:: python

            messages = [
                SystemMessage(content="You are an AI assistant."),
                HumanMessage(content="Tell me a joke."),
            ]
            response = chat.invoke(messages)

    Stream:
        .. code-block:: python

            for chunk in chat.stream(messages):
                print(chunk.content, end="", flush=True)

    Async:
        .. code-block:: python

            response = await chat.ainvoke(messages)

    Token usage:
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata["usage"]["prompt_tokens"])
            print(response.response_metadata["usage"]["total_tokens"])

    Response metadata:
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata)
    """

    sambanova_url: str = Field(default="")
    """SambaNova Cloud URL"""

    sambanova_api_key: SecretStr = Field(default="")
    """SambaNova Cloud API key"""

    model: str = Field(default="llama3-8b")
    """The name of the model"""

    streaming: bool = Field(default=False)
    """Whether to use streaming"""

    max_tokens: int = Field(default=1024)
    """Max tokens to generate"""

    temperature: float = Field(default=0.7)
    """Model temperature"""

    top_p: float = Field(default=0.0)
    """Model top p"""

    top_k: int = Field(default=1)
    """Model top k"""

    stream_options: dict = Field(default={"include_usage": True})
    """Stream options; include usage to get generation metrics"""

    class Config:
        populate_by_name = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return False

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"sambanova_api_key": "sambanova_api_key"}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes and makes it possible to monitor LLMs.
        """
        return {
            "model": self.model,
            "streaming": self.streaming,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "stream_options": self.stream_options,
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "sambanovacloud-chatmodel"

    def __init__(self, **kwargs: Any) -> None:
        """Init and validate environment variables."""
        kwargs["sambanova_url"] = get_from_dict_or_env(
            kwargs,
            "sambanova_url",
            "SAMBANOVA_URL",
            default="https://api.sambanova.ai/v1/chat/completions",
        )
        kwargs["sambanova_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(kwargs, "sambanova_api_key", "SAMBANOVA_API_KEY")
        )
        super().__init__(**kwargs)

    def _handle_request(
        self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Perform a POST request to the LLM API.

        Args:
            messages_dicts: List of role / content dicts to use as input.
            stop: list of stop tokens

        Returns:
            A response dict.
        """
        data = {
            "messages": messages_dicts,
            "max_tokens": self.max_tokens,
            "stop": stop,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        http_session = requests.Session()
        response = http_session.post(
            self.sambanova_url,
            headers={
                "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            },
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}. "
                f"{response.text}."
            )
        response_dict = response.json()
        if response_dict.get("error"):
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}. "
                f"{response_dict}."
            )
        return response_dict

    def _handle_streaming_request(
        self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
    ) -> Iterator[Dict]:
        """
        Perform a streaming POST request to the LLM API.

        Args:
            messages_dicts: List of role / content dicts to use as input.
            stop: list of stop tokens

        Returns:
            An iterator of response dicts.
        """
        try:
            import sseclient
        except ImportError:
            raise ImportError(
                "could not import sseclient library. "
                "Please install it with `pip install sseclient-py`."
            )
        data = {
            "messages": messages_dicts,
            "max_tokens": self.max_tokens,
            "stop": stop,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "stream": True,
            "stream_options": self.stream_options,
        }
        http_session = requests.Session()
        response = http_session.post(
            self.sambanova_url,
            headers={
                "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            },
            json=data,
            stream=True,
        )
        client = sseclient.SSEClient(response)
        if response.status_code != 200:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}. "
                f"{response.text}."
            )
        for event in client.events():
            chunk = {
                "event": event.event,
                "data": event.data,
                "status_code": response.status_code,
            }
            if chunk["event"] == "error_event" or chunk["status_code"] != 200:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{chunk['status_code']}. "
                    f"{chunk}."
                )
            try:
                # Check whether the response is the final event;
                # in that case the event data is '[DONE]'.
                if chunk["data"] != "[DONE]":
                    if isinstance(chunk["data"], str):
                        data = json.loads(chunk["data"])
                    else:
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code "
                            f"{chunk['status_code']}. "
                            f"{chunk}."
                        )
                    if data.get("error"):
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code "
                            f"{chunk['status_code']}. "
                            f"{chunk}."
                        )
                    yield data
            except Exception:
                raise Exception(
                    f"Error getting content chunk raw streamed response: {chunk}"
                )

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """
        Convert a BaseMessage to a dictionary with role / content.

        Args:
            message: BaseMessage

        Returns:
            messages_dict: role / content dict
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, ToolMessage):
            message_dict = {"role": "tool", "content": message.content}
        else:
            raise TypeError(f"Got unknown type {message}")
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """
        Convert a list of BaseMessages to a list of dictionaries with role / content.

        Args:
            messages: list of BaseMessages

        Returns:
            messages_dicts: list of role / content dicts
        """
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        SambaNovaCloud chat model logic.

        Call the SambaNovaCloud API.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            if stream_iter:
                return generate_from_stream(stream_iter)
        messages_dicts = self._create_message_dicts(messages)
        response = self._handle_request(messages_dicts, stop)
        message = AIMessage(
            content=response["choices"][0]["message"]["content"],
            additional_kwargs={},
            response_metadata={
                "finish_reason": response["choices"][0]["finish_reason"],
                "usage": response.get("usage"),
                "model_name": response["model"],
                "system_fingerprint": response["system_fingerprint"],
                "created": response["created"],
            },
            id=response["id"],
        )
        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """
        Stream the output of the SambaNovaCloud chat model.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        messages_dicts = self._create_message_dicts(messages)
        finish_reason = None
        for partial_response in self._handle_streaming_request(messages_dicts, stop):
            if len(partial_response["choices"]) > 0:
                finish_reason = partial_response["choices"][0].get("finish_reason")
                content = partial_response["choices"][0]["delta"]["content"]
                id = partial_response["id"]
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=content, id=id, additional_kwargs={})
                )
            else:
                content = ""
                id = partial_response["id"]
                metadata = {
                    "finish_reason": finish_reason,
                    "usage": partial_response.get("usage"),
                    "model_name": partial_response["model"],
                    "system_fingerprint": partial_response["system_fingerprint"],
                    "created": partial_response["created"],
                }
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(
                        content=content,
                        id=id,
                        response_metadata=metadata,
                        additional_kwargs={},
                    )
                )
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk


@@ -0,0 +1,11 @@
from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud


def test_chat_sambanova_cloud() -> None:
    chat = ChatSambaNovaCloud()
    message = HumanMessage(content="Hello")
    response = chat.invoke([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
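
A streaming smoke test would pair naturally with the above; here is a sketch (not part of this diff, and it assumes a live `SAMBANOVA_API_KEY` like the test above):

```python
from langchain_core.messages import AIMessageChunk, HumanMessage

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud


def test_chat_sambanova_cloud_streaming() -> None:
    # Hypothetical companion test: stream a short reply and verify
    # that each chunk carries string content.
    chat = ChatSambaNovaCloud(streaming=True)
    for chunk in chat.stream([HumanMessage(content="Hello")]):
        assert isinstance(chunk, AIMessageChunk)
        assert isinstance(chunk.content, str)
```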


@@ -34,6 +34,7 @@ EXPECTED_ALL = [
    "ChatOpenAI",
    "ChatPerplexity",
    "ChatPremAI",
    "ChatSambaNovaCloud",
    "ChatSparkLLM",
    "ChatTongyi",
    "ChatVertexAI",
"ChatVertexAI",