Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-23 15:19:33 +00:00)

community: Add SambaNova Cloud Chat model community integration (#26333)

**Description:** Add SambaNova Cloud Chat model community integration. Includes:

- chat model integration (following the "Standardize ChatModel docstrings" guidelines)
- tests
- docs usage notebook (following the "Standardize ChatModel integration docs" guidelines)

https://cloud.sambanova.ai/

---------

Co-authored-by: luisfucros <luisfucros@gmail.com>
Co-authored-by: ccurme <chester.curme@gmail.com>
This commit is contained in:
parent 2b83c7c3ab
commit 408a930d55
docs/docs/integrations/chat/sambanova.ipynb (new file, 374 lines added)
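At a glance, usage looks like the minimal sketch below (drawn from the docs notebook in this commit; it assumes `langchain-community` and `sseclient-py` are installed and a SambaNova Cloud API key is available):

```python
import getpass
import os

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud

# Prompt for the API key if it is not already in the environment.
if not os.getenv("SAMBANOVA_API_KEY"):
    os.environ["SAMBANOVA_API_KEY"] = getpass.getpass("Enter your SambaNova Cloud API key: ")

# Model name and sampling parameters mirror the notebook's instantiation cell.
llm = ChatSambaNovaCloud(model="llama3-405b", max_tokens=1024, temperature=0.7)
print(llm.invoke([("human", "I love programming.")]).content)
```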
@ -0,0 +1,374 @@
{
 "cells": [
  {
   "cell_type": "raw",
   "metadata": {
    "vscode": {
     "languageId": "raw"
    }
   },
   "source": [
    "---\n",
    "sidebar_label: SambaNovaCloud\n",
    "---"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# ChatSambaNovaCloud\n",
    "\n",
    "This will help you get started with SambaNovaCloud [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatSambaNovaCloud features and configurations, head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html).\n",
    "\n",
    "**[SambaNova](https://sambanova.ai/)'s** [SambaNova Cloud](https://cloud.sambanova.ai/) is a platform for performing inference with open-source models.\n",
    "\n",
    "## Overview\n",
    "### Integration details\n",
    "\n",
    "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n",
    "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n",
    "| [ChatSambaNovaCloud](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html) | [langchain-community](https://python.langchain.com/v0.2/api_reference/community/index.html) | ❌ | ❌ | ❌ |  |  |\n",
    "\n",
    "### Model features\n",
    "\n",
    "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
    "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
    "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ |\n",
    "\n",
    "## Setup\n",
    "\n",
    "To access ChatSambaNovaCloud models you will need to create a [SambaNovaCloud](https://cloud.sambanova.ai/) account, get an API key, install the `langchain_community` integration package, and install the `sseclient-py` package.\n",
    "\n",
    "```bash\n",
    "pip install langchain-community\n",
    "pip install sseclient-py\n",
    "```\n",
    "\n",
    "### Credentials\n",
    "\n",
    "Get an API Key from [cloud.sambanova.ai](https://cloud.sambanova.ai/apis) and add it to your environment variables:\n",
    "\n",
    "```bash\n",
    "export SAMBANOVA_API_KEY=\"your-api-key-here\"\n",
    "```"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "import os\n",
    "\n",
    "if not os.getenv(\"SAMBANOVA_API_KEY\"):\n",
    "    os.environ[\"SAMBANOVA_API_KEY\"] = getpass.getpass(\n",
    "        \"Enter your SambaNova Cloud API key: \"\n",
    "    )"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "If you want automated tracing of your model calls, you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting the cell below:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n",
    "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Installation\n",
    "\n",
    "The LangChain __SambaNovaCloud__ integration lives in the `langchain_community` package:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%pip install -qU langchain-community\n",
    "%pip install -qU sseclient-py"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Instantiation\n",
    "\n",
    "Now we can instantiate our model object and generate chat completions:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain_community.chat_models.sambanova import ChatSambaNovaCloud\n",
    "\n",
    "llm = ChatSambaNovaCloud(\n",
    "    model=\"llama3-405b\", max_tokens=1024, temperature=0.7, top_k=1, top_p=0.01\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Invocation"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content=\"J'adore la programmation.\", response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 11, 'completion_tokens': 9, 'completion_tokens_after_first_per_sec': 97.07042823956884, 'completion_tokens_after_first_per_sec_first_ten': 276.3343994441849, 'completion_tokens_per_sec': 23.775192800224037, 'end_time': 1726158364.7954874, 'is_last_response': True, 'prompt_tokens': 56, 'start_time': 1726158364.3670964, 'time_to_first_token': 0.3459765911102295, 'total_latency': 0.3785458261316473, 'total_tokens': 65, 'total_tokens_per_sec': 171.70972577939582}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726158364}, id='7154b676-9d5a-4b1a-a425-73bbe69f28fc')"
      ]
     },
     "execution_count": 3,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "messages = [\n",
    "    (\n",
    "        \"system\",\n",
    "        \"You are a helpful assistant that translates English to French. Translate the user sentence.\",\n",
    "    ),\n",
    "    (\"human\", \"I love programming.\"),\n",
    "]\n",
    "ai_msg = llm.invoke(messages)\n",
    "ai_msg"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "J'adore la programmation.\n"
     ]
    }
   ],
   "source": [
    "print(ai_msg.content)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Chaining\n",
    "\n",
    "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='Ich liebe Programmieren.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 11, 'completion_tokens': 6, 'completion_tokens_after_first_per_sec': 47.80258530102961, 'completion_tokens_after_first_per_sec_first_ten': 215.59002827036753, 'completion_tokens_per_sec': 5.263977583489829, 'end_time': 1726158506.3777263, 'is_last_response': True, 'prompt_tokens': 51, 'start_time': 1726158505.1611376, 'time_to_first_token': 1.1119918823242188, 'total_latency': 1.1398224830627441, 'total_tokens': 57, 'total_tokens_per_sec': 50.00778704315337}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726158505}, id='226471ac-8c52-44bb-baa7-f9d2f8c54477')"
      ]
     },
     "execution_count": 5,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "from langchain_core.prompts import ChatPromptTemplate\n",
    "\n",
    "prompt = ChatPromptTemplate(\n",
    "    [\n",
    "        (\n",
    "            \"system\",\n",
    "            \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n",
    "        ),\n",
    "        (\"human\", \"{input}\"),\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = prompt | llm\n",
    "chain.invoke(\n",
    "    {\n",
    "        \"input_language\": \"English\",\n",
    "        \"output_language\": \"German\",\n",
    "        \"input\": \"I love programming.\",\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Yer lookin' fer some info on owls, eh? Alright then, matey, settle yerself down with a pint o' grog and listen close.\n",
      "\n",
      "Owls be nocturnal birds o' prey, meanin' they do most o' their huntin' at night. They got big, round eyes that be perfect fer seein' in the dark, like a trusty lantern on a dark sea. Their ears be sharp as a cutlass, too, helpin' 'em pinpoint the slightest sound o' a scurvy rodent scurryin' through the underbrush.\n",
      "\n",
      "These birds be known fer their silent flight, like a ghost ship sailin' through the night. Their feathers be special, with a soft, fringed edge that helps 'em sneak up on their prey. And when they strike, it be swift and deadly, like a pirate's sword.\n",
      "\n",
      "Owls be found all over the world, from the frozen tundras o' the north to the scorching deserts o' the south. They come in all shapes and sizes, from the tiny elf owl to the great grey owl, which be as big as a small dog.\n",
      "\n",
      "Now, I know what ye be thinkin', \"Pirate, what about their hootin'?\" Aye, owls be famous fer their hoots, which be a form o' communication. They use different hoots to warn off predators, attract a mate, or even just to say, \"Shiver me timbers, I be happy to be alive!\"\n",
      "\n",
      "So there ye have it, me hearty. Owls be fascinatin' creatures, and I hope ye found this info as interestin' as a chest overflowin' with gold doubloons. Fair winds and following seas!"
     ]
    }
   ],
   "source": [
    "system = \"You are a helpful assistant with a pirate accent.\"\n",
    "human = \"I want to learn more about this animal: {animal}\"\n",
    "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n",
    "\n",
    "chain = prompt | llm\n",
    "\n",
    "for chunk in chain.stream({\"animal\": \"owl\"}):\n",
    "    print(chunk.content, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Async"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "AIMessage(content='The capital of France is Paris.', response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 13, 'completion_tokens': 8, 'completion_tokens_after_first_per_sec': 86.00726488715989, 'completion_tokens_after_first_per_sec_first_ten': 326.92555640828857, 'completion_tokens_per_sec': 21.74539360394493, 'end_time': 1726159287.9987085, 'is_last_response': True, 'prompt_tokens': 43, 'start_time': 1726159287.5738964, 'time_to_first_token': 0.34342360496520996, 'total_latency': 0.36789400760944074, 'total_tokens': 51, 'total_tokens_per_sec': 138.62688422514893}, 'model_name': 'Meta-Llama-3.1-405B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1726159287}, id='9b4ef015-50a2-434b-b980-29f8aa90c3e8')"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"human\",\n",
    "            \"what is the capital of {country}?\",\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "\n",
    "chain = prompt | llm\n",
    "await chain.ainvoke({\"country\": \"France\"})"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Async Streaming"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Quantum computers use quantum bits (qubits) to process vast amounts of data simultaneously, leveraging quantum mechanics to solve complex problems exponentially faster than classical computers."
     ]
    }
   ],
   "source": [
    "prompt = ChatPromptTemplate.from_messages(\n",
    "    [\n",
    "        (\n",
    "            \"human\",\n",
    "            \"in less than {num_words} words explain {topic} \",\n",
    "        )\n",
    "    ]\n",
    ")\n",
    "chain = prompt | llm\n",
    "\n",
    "async for chunk in chain.astream({\"num_words\": 30, \"topic\": \"quantum computers\"}):\n",
    "    print(chunk.content, end=\"\", flush=True)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## API reference\n",
    "\n",
    "For detailed documentation of all ChatSambaNovaCloud features and configurations, head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaNovaCloud.html"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "langchain",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.19"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
libs/community/langchain_community/chat_models/__init__.py
@ -147,6 +147,9 @@ if TYPE_CHECKING:
    from langchain_community.chat_models.promptlayer_openai import (
        PromptLayerChatOpenAI,
    )
    from langchain_community.chat_models.sambanova import (
        ChatSambaNovaCloud,
    )
    from langchain_community.chat_models.snowflake import (
        ChatSnowflakeCortex,
    )
@ -211,6 +214,7 @@ __all__ = [
    "ChatOpenAI",
    "ChatPerplexity",
    "ChatPremAI",
    "ChatSambaNovaCloud",
    "ChatSparkLLM",
    "ChatSnowflakeCortex",
    "ChatTongyi",
@ -269,6 +273,7 @@ _module_lookup = {
    "ChatOllama": "langchain_community.chat_models.ollama",
    "ChatOpenAI": "langchain_community.chat_models.openai",
    "ChatPerplexity": "langchain_community.chat_models.perplexity",
    "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova",
    "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake",
    "ChatSparkLLM": "langchain_community.chat_models.sparkllm",
    "ChatTongyi": "langchain_community.chat_models.tongyi",
libs/community/langchain_community/chat_models/sambanova.py (new file, 465 lines added)
@ -0,0 +1,465 @@
import json
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import (
    CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import (
    BaseChatModel,
    generate_from_stream,
)
from langchain_core.messages import (
    AIMessage,
    AIMessageChunk,
    BaseMessage,
    ChatMessage,
    HumanMessage,
    SystemMessage,
    ToolMessage,
)
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from pydantic import Field, SecretStr


class ChatSambaNovaCloud(BaseChatModel):
    """
    SambaNova Cloud chat model.

    Setup:
        To use, you should have the environment variables:
        ``SAMBANOVA_URL`` set with your SambaNova Cloud URL.
        ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API key.
        https://cloud.sambanova.ai/

    Example:
        .. code-block:: python

            ChatSambaNovaCloud(
                sambanova_url = SambaNova cloud endpoint URL,
                sambanova_api_key = set with your SambaNova cloud API key,
                model = model name,
                streaming = set True to use the streaming API,
                max_tokens = max number of tokens to generate,
                temperature = model temperature,
                top_p = model top p,
                top_k = model top k,
                stream_options = include usage to get generation metrics
            )

    Key init args — completion params:
        model: str
            The name of the model to use, e.g., llama3-8b.
        streaming: bool
            Whether to use streaming or not
        max_tokens: int
            max tokens to generate
        temperature: float
            model temperature
        top_p: float
            model top p
        top_k: int
            model top k
        stream_options: dict
            stream options, include usage to get generation metrics

    Key init args — client params:
        sambanova_url: str
            SambaNova Cloud URL
        sambanova_api_key: str
            SambaNova Cloud API key

    Instantiate:
        .. code-block:: python

            from langchain_community.chat_models import ChatSambaNovaCloud

            chat = ChatSambaNovaCloud(
                sambanova_url = SambaNova cloud endpoint URL,
                sambanova_api_key = set with your SambaNova cloud API key,
                model = model name,
                streaming = set True for streaming,
                max_tokens = max number of tokens to generate,
                temperature = model temperature,
                top_p = model top p,
                top_k = model top k,
                stream_options = include usage to get generation metrics
            )

    Invoke:
        .. code-block:: python

            messages = [
                SystemMessage(content="you are an AI assistant."),
                HumanMessage(content="tell me a joke."),
            ]
            response = chat.invoke(messages)

    Stream:
        .. code-block:: python

            for chunk in chat.stream(messages):
                print(chunk.content, end="", flush=True)

    Async:
        .. code-block:: python

            response = chat.ainvoke(messages)
            await response

    Token usage:
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata["usage"]["prompt_tokens"])
            print(response.response_metadata["usage"]["total_tokens"])

    Response metadata:
        .. code-block:: python

            response = chat.invoke(messages)
            print(response.response_metadata)
    """

    sambanova_url: str = Field(default="")
    """SambaNova Cloud URL"""

    sambanova_api_key: SecretStr = Field(default="")
    """SambaNova Cloud API key"""

    model: str = Field(default="llama3-8b")
    """The name of the model"""

    streaming: bool = Field(default=False)
    """Whether to use streaming or not"""

    max_tokens: int = Field(default=1024)
    """max tokens to generate"""

    temperature: float = Field(default=0.7)
    """model temperature"""

    top_p: float = Field(default=0.0)
    """model top p"""

    top_k: int = Field(default=1)
    """model top k"""

    stream_options: dict = Field(default={"include_usage": True})
    """stream options, include usage to get generation metrics"""

    class Config:
        populate_by_name = True

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Return whether this model can be serialized by LangChain."""
        return False

    @property
    def lc_secrets(self) -> Dict[str, str]:
        return {"sambanova_api_key": "sambanova_api_key"}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return a dictionary of identifying parameters.

        This information is used by the LangChain callback system, which
        is used for tracing purposes, to make it possible to monitor LLMs.
        """
        return {
            "model": self.model,
            "streaming": self.streaming,
            "max_tokens": self.max_tokens,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "stream_options": self.stream_options,
        }

    @property
    def _llm_type(self) -> str:
        """Get the type of language model used by this chat model."""
        return "sambanovacloud-chatmodel"

    def __init__(self, **kwargs: Any) -> None:
        """Init and validate environment variables."""
        kwargs["sambanova_url"] = get_from_dict_or_env(
            kwargs,
            "sambanova_url",
            "SAMBANOVA_URL",
            default="https://api.sambanova.ai/v1/chat/completions",
        )
        kwargs["sambanova_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(kwargs, "sambanova_api_key", "SAMBANOVA_API_KEY")
        )
        super().__init__(**kwargs)

    def _handle_request(
        self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
    ) -> Dict[str, Any]:
        """
        Perform a post request to the LLM API.

        Args:
            messages_dicts: List of role / content dicts to use as input.
            stop: list of stop tokens

        Returns:
            A response dict.
        """
        data = {
            "messages": messages_dicts,
            "max_tokens": self.max_tokens,
            "stop": stop,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
        }
        http_session = requests.Session()
        response = http_session.post(
            self.sambanova_url,
            headers={
                "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            },
            json=data,
        )
        if response.status_code != 200:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}."
                f"{response.text}."
            )
        response_dict = response.json()
        if response_dict.get("error"):
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}."
                f"{response_dict}."
            )
        return response_dict

    def _handle_streaming_request(
        self, messages_dicts: List[Dict], stop: Optional[List[str]] = None
    ) -> Iterator[Dict]:
        """
        Perform a streaming post request to the LLM API.

        Args:
            messages_dicts: List of role / content dicts to use as input.
            stop: list of stop tokens

        Returns:
            An iterator of response dicts.
        """
        try:
            import sseclient
        except ImportError:
            raise ImportError(
                "could not import sseclient library. "
                "Please install it with `pip install sseclient-py`."
            )
        data = {
            "messages": messages_dicts,
            "max_tokens": self.max_tokens,
            "stop": stop,
            "model": self.model,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "stream": True,
            "stream_options": self.stream_options,
        }
        http_session = requests.Session()
        response = http_session.post(
            self.sambanova_url,
            headers={
                "Authorization": f"Bearer {self.sambanova_api_key.get_secret_value()}",
                "Content-Type": "application/json",
            },
            json=data,
            stream=True,
        )

        client = sseclient.SSEClient(response)

        if response.status_code != 200:
            raise RuntimeError(
                f"Sambanova /complete call failed with status code "
                f"{response.status_code}."
                f"{response.text}."
            )

        for event in client.events():
            chunk = {
                "event": event.event,
                "data": event.data,
                "status_code": response.status_code,
            }

            if chunk["event"] == "error_event" or chunk["status_code"] != 200:
                raise RuntimeError(
                    f"Sambanova /complete call failed with status code "
                    f"{chunk['status_code']}."
                    f"{chunk}."
                )

            try:
                # check if the response is a final event
                # in that case event data response is '[DONE]'
                if chunk["data"] != "[DONE]":
                    if isinstance(chunk["data"], str):
                        data = json.loads(chunk["data"])
                    else:
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code "
                            f"{chunk['status_code']}."
                            f"{chunk}."
                        )
                    if data.get("error"):
                        raise RuntimeError(
                            f"Sambanova /complete call failed with status code "
                            f"{chunk['status_code']}."
                            f"{chunk}."
                        )
                    yield data
            except Exception:
                raise Exception(
                    f"Error getting content chunk raw streamed response: {chunk}"
                )

    def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]:
        """
        Convert a BaseMessage to a dictionary with role / content.

        Args:
            message: BaseMessage

        Returns:
            messages_dict: role / content dict
        """
        if isinstance(message, ChatMessage):
            message_dict = {"role": message.role, "content": message.content}
        elif isinstance(message, SystemMessage):
            message_dict = {"role": "system", "content": message.content}
        elif isinstance(message, HumanMessage):
            message_dict = {"role": "user", "content": message.content}
        elif isinstance(message, AIMessage):
            message_dict = {"role": "assistant", "content": message.content}
        elif isinstance(message, ToolMessage):
            message_dict = {"role": "tool", "content": message.content}
        else:
            raise TypeError(f"Got unknown type {message}")
        return message_dict

    def _create_message_dicts(
        self, messages: List[BaseMessage]
    ) -> List[Dict[str, Any]]:
        """
        Convert a list of BaseMessages to a list of dictionaries with role / content.

        Args:
            messages: list of BaseMessages

        Returns:
            messages_dicts: list of role / content dicts
        """
        message_dicts = [self._convert_message_to_dict(m) for m in messages]
        return message_dicts

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        SambaNovaCloud chat model logic.

        Call SambaNovaCloud API.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        if self.streaming:
            stream_iter = self._stream(
                messages, stop=stop, run_manager=run_manager, **kwargs
            )
            if stream_iter:
                return generate_from_stream(stream_iter)
        messages_dicts = self._create_message_dicts(messages)
        response = self._handle_request(messages_dicts, stop)
        message = AIMessage(
            content=response["choices"][0]["message"]["content"],
            additional_kwargs={},
            response_metadata={
                "finish_reason": response["choices"][0]["finish_reason"],
                "usage": response.get("usage"),
                "model_name": response["model"],
                "system_fingerprint": response["system_fingerprint"],
                "created": response["created"],
            },
            id=response["id"],
        )

        generation = ChatGeneration(message=message)
        return ChatResult(generations=[generation])

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """
        Stream the output of the SambaNovaCloud chat model.

        Args:
            messages: the prompt composed of a list of messages.
            stop: a list of strings on which the model should stop generating.
                If generation stops due to a stop token, the stop token itself
                SHOULD BE INCLUDED as part of the output. This is not enforced
                across models right now, but it's a good practice to follow since
                it makes it much easier to parse the output of the model
                downstream and understand why generation stopped.
            run_manager: A run manager with callbacks for the LLM.
        """
        messages_dicts = self._create_message_dicts(messages)
        finish_reason = None
        for partial_response in self._handle_streaming_request(messages_dicts, stop):
            if len(partial_response["choices"]) > 0:
                finish_reason = partial_response["choices"][0].get("finish_reason")
                content = partial_response["choices"][0]["delta"]["content"]
                id = partial_response["id"]
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(content=content, id=id, additional_kwargs={})
                )
            else:
                content = ""
                id = partial_response["id"]
                metadata = {
                    "finish_reason": finish_reason,
                    "usage": partial_response.get("usage"),
                    "model_name": partial_response["model"],
                    "system_fingerprint": partial_response["system_fingerprint"],
                    "created": partial_response["created"],
                }
                chunk = ChatGenerationChunk(
                    message=AIMessageChunk(
                        content=content,
                        id=id,
                        response_metadata=metadata,
                        additional_kwargs={},
                    )
                )

            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
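Taken together, the constructor above resolves `SAMBANOVA_URL` (falling back to the public endpoint) and requires `SAMBANOVA_API_KEY`, and `_generate` delegates to `_stream` when `streaming=True`. A brief usage sketch of that streaming path (illustrative values only, not part of the commit):

```python
import os

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud

# SAMBANOVA_URL is optional (defaults to https://api.sambanova.ai/v1/chat/completions);
# SAMBANOVA_API_KEY must be set, or passed as sambanova_api_key=...
os.environ["SAMBANOVA_API_KEY"] = "your-api-key-here"  # placeholder value

chat = ChatSambaNovaCloud(model="llama3-8b", streaming=True, max_tokens=256)

# .stream() drives _handle_streaming_request and yields AIMessageChunk content.
for chunk in chat.stream("Say hello in French."):
    print(chunk.content, end="", flush=True)
```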
New file (11 lines added): integration test for ChatSambaNovaCloud
@ -0,0 +1,11 @@
from langchain_core.messages import AIMessage, HumanMessage

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud


def test_chat_sambanova_cloud() -> None:
    chat = ChatSambaNovaCloud()
    message = HumanMessage(content="Hello")
    response = chat.invoke([message])
    assert isinstance(response, AIMessage)
    assert isinstance(response.content, str)
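A streaming counterpart could look like the sketch below (not part of this commit; like the test above, it assumes live API access with `SAMBANOVA_API_KEY` set):

```python
from langchain_core.messages import AIMessageChunk, HumanMessage

from langchain_community.chat_models.sambanova import ChatSambaNovaCloud


def test_chat_sambanova_cloud_streaming() -> None:
    # Exercises the _stream path; each yielded chunk carries a string delta.
    chat = ChatSambaNovaCloud(streaming=True)
    for chunk in chat.stream([HumanMessage(content="Hello")]):
        assert isinstance(chunk, AIMessageChunk)
        assert isinstance(chunk.content, str)
```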
Unit test import list update:
@ -34,6 +34,7 @@ EXPECTED_ALL = [
    "ChatOpenAI",
    "ChatPerplexity",
    "ChatPremAI",
    "ChatSambaNovaCloud",
    "ChatSparkLLM",
    "ChatTongyi",
    "ChatVertexAI",