diff --git a/docs/docs/integrations/chat/sambastudio.ipynb b/docs/docs/integrations/chat/sambastudio.ipynb new file mode 100644 index 00000000000..9301d68fe33 --- /dev/null +++ b/docs/docs/integrations/chat/sambastudio.ipynb @@ -0,0 +1,383 @@ +{ + "cells": [ + { + "cell_type": "raw", + "metadata": { + "vscode": { + "languageId": "raw" + } + }, + "source": [ + "---\n", + "sidebar_label: SambaStudio\n", + "---" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# ChatSambaStudio\n", + "\n", + "This will help you get started with SambaStudio [chat models](/docs/concepts/#chat-models). For detailed documentation of all ChatSambaStudio features and configurations head to the [API reference](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaStudio.html).\n", + "\n", + "**[SambaNova](https://sambanova.ai/)'s** [SambaStudio](https://docs.sambanova.ai/sambastudio/latest/sambastudio-intro.html) is a rich, GUI-based platform that provides the functionality to train, deploy, and manage models in SambaNova [DataScale](https://sambanova.ai/products/datascale) systems.\n", + "\n", + "## Overview\n", + "### Integration details\n", + "\n", + "| Class | Package | Local | Serializable | JS support | Package downloads | Package latest |\n", + "| :--- | :--- | :---: | :---: | :---: | :---: | :---: |\n", + "| [ChatSambaStudio](https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaStudio.html) | [langchain-community](https://python.langchain.com/api_reference/community/index.html) | ❌ | ❌ | ❌ | ![PyPI - Downloads](https://img.shields.io/pypi/dm/langchain_community?style=flat-square&label=%20) | ![PyPI - Version](https://img.shields.io/pypi/v/langchain_community?style=flat-square&label=%20) |\n", + "\n", + "### Model features\n", + "\n", + "| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n", + "| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n", + "| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n", + "\n", + "## Setup\n", + "\n", + "To access ChatSambaStudio models you will need to [deploy an endpoint](https://docs.sambanova.ai/sambastudio/latest/language-models.html) in your SambaStudio platform, install the `langchain_community` integration package, and install the `sseclient-py` package.\n", + "\n", + "```bash\n", + "pip install langchain-community\n", + "pip install sseclient-py\n", + "```\n", + "\n", + "### Credentials\n", + "\n", + "Get the URL and API Key from your SambaStudio deployed endpoint and add them to your environment variables:\n", + "\n", + "```bash\n", + "export SAMBASTUDIO_URL=\"your-url-here\"\n", + "export SAMBASTUDIO_API_KEY=\"your-api-key-here\"\n", + "```" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "import os\n", + "\n", + "if not os.getenv(\"SAMBASTUDIO_URL\"):\n", + " os.environ[\"SAMBASTUDIO_URL\"] = getpass.getpass(\"Enter your SambaStudio URL: \")\n", + "if not os.getenv(\"SAMBASTUDIO_API_KEY\"):\n", + " os.environ[\"SAMBASTUDIO_API_KEY\"] = getpass.getpass(\n", + " \"Enter your SambaStudio API key: \"\n", + " )" + ] +
}, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If you want to get automated tracing of your model calls you can also set your [LangSmith](https://docs.smith.langchain.com/) API key by uncommenting below:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# os.environ[\"LANGCHAIN_TRACING_V2\"] = \"true\"\n", + "# os.environ[\"LANGCHAIN_API_KEY\"] = getpass.getpass(\"Enter your LangSmith API key: \")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Installation\n", + "\n", + "The LangChain __SambaStudio__ integration lives in the `langchain_community` package:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%pip install -qU langchain-community\n", + "%pip install -qU sseclient-py" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Instantiation\n", + "\n", + "Now we can instantiate our model object and generate chat completions:" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.chat_models.sambanova import ChatSambaStudio\n", + "\n", + "llm = ChatSambaStudio(\n", + " model=\"Meta-Llama-3-70B-Instruct-4096\", # set if using a CoE endpoint\n", + " max_tokens=1024,\n", + " temperature=0.7,\n", + " top_k=1,\n", + " top_p=0.01,\n", + " do_sample=True,\n", + " process_prompt=True, # set if using a CoE endpoint\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Invocation" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content=\"J'adore la programmation.\", response_metadata={'id': 'item0', 'partial': False, 'value': {'completion': \"J'adore la programmation.\", 'logprobs': {'text_offset': [], 'top_logprobs': []}, 'prompt': '<|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant that translates English to French. Translate the user sentence.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nI love programming.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n', 'stop_reason': 'end_of_text', 'tokens': ['J', \"'\", 'ad', 'ore', ' la', ' programm', 'ation', '.'], 'total_tokens_count': 43}, 'params': {}, 'status': None}, id='item0')" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "messages = [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates English to French. 
Translate the user sentence.\",\n", + " ),\n", + " (\"human\", \"I love programming.\"),\n", + "]\n", + "ai_msg = llm.invoke(messages)\n", + "ai_msg" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "J'adore la programmation.\n" + ] + } + ], + "source": [ + "print(ai_msg.content)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Chaining\n", + "\n", + "We can [chain](/docs/how_to/sequence/) our model with a prompt template like so:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='Ich liebe das Programmieren.', response_metadata={'id': 'item0', 'partial': False, 'value': {'completion': 'Ich liebe das Programmieren.', 'logprobs': {'text_offset': [], 'top_logprobs': []}, 'prompt': '<|start_header_id|>system<|end_header_id|>\\n\\nYou are a helpful assistant that translates English to German.<|eot_id|><|start_header_id|>user<|end_header_id|>\\n\\nI love programming.<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n', 'stop_reason': 'end_of_text', 'tokens': ['Ich', ' liebe', ' das', ' Programm', 'ieren', '.'], 'total_tokens_count': 36}, 'params': {}, 'status': None}, id='item0')" + ] + }, + "execution_count": 5, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "from langchain_core.prompts import ChatPromptTemplate\n", + "\n", + "prompt = ChatPromptTemplate(\n", + " [\n", + " (\n", + " \"system\",\n", + " \"You are a helpful assistant that translates {input_language} to {output_language}.\",\n", + " ),\n", + " (\"human\", \"{input}\"),\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | llm\n", + "chain.invoke(\n", + " {\n", + " \"input_language\": \"English\",\n", + " \"output_language\": \"German\",\n", + " \"input\": \"I love programming.\",\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Streaming" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Arrr, ye landlubber! Ye be wantin' to learn about owls, eh? Well, matey, settle yerself down with a pint o' grog and listen close, for I be tellin' ye about these fascinatin' creatures o' the night!\n", + "\n", + "Owls be birds, but not just any birds, me hearty! They be nocturnal, meanin' they do their huntin' at night, when the rest o' the world be sleepin'. And they be experts at it, too! Their big, round eyes be designed for seein' in the dark, with a special reflective layer called the tapetum lucidum that helps 'em spot prey in the shadows. It's like havin' a built-in lantern, savvy?\n", + "\n", + "But that be not all, me matey! Owls also have acute hearin', which helps 'em pinpoint the slightest sounds in the dark. And their ears be asymmetrical, meanin' one ear be higher than the other, which gives 'em better depth perception. It's like havin' a built-in sonar system, arrr!\n", + "\n", + "Now, ye might be wonderin' how owls fly so silently, like ghosts in the night. Well, it be because o' their special feathers, me hearty! They have soft, fringed feathers on their wings that help reduce noise and turbulence, makin' 'em the sneakiest flyers on the seven seas... 
er, skies!\n", + "\n", + "Owls come in all shapes and sizes, from the tiny elf owl to the great grey owl, which be one o' the largest owl species in the world. And they be found on every continent, except Antarctica, o' course. They be solitary creatures, but some species be known to form long-term monogamous relationships, like the barn owl and its mate.\n", + "\n", + "So, there ye have it, me hearty! Owls be amazin' creatures, with their clever adaptations and stealthy ways. Now, go forth and spread the word about these magnificent birds o' the night! And remember, if ye ever encounter an owl in the wild, be sure to show respect and keep a weather eye open, or ye might just find yerself on the receivin' end o' a silent, flyin' tackle! Arrr!" + ] + } + ], + "source": [ + "system = \"You are a helpful assistant with pirate accent.\"\n", + "human = \"I want to learn more about this animal: {animal}\"\n", + "prompt = ChatPromptTemplate.from_messages([(\"system\", system), (\"human\", human)])\n", + "\n", + "chain = prompt | llm\n", + "\n", + "for chunk in chain.stream({\"animal\": \"owl\"}):\n", + " print(chunk.content, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "AIMessage(content='The capital of France is Paris.', response_metadata={'id': 'item0', 'partial': False, 'value': {'completion': 'The capital of France is Paris.', 'logprobs': {'text_offset': [], 'top_logprobs': []}, 'prompt': '<|start_header_id|>user<|end_header_id|>\\n\\nwhat is the capital of France?<|eot_id|><|start_header_id|>assistant<|end_header_id|>\\n\\n', 'stop_reason': 'end_of_text', 'tokens': ['The', ' capital', ' of', ' France', ' is', ' Paris', '.'], 'total_tokens_count': 24}, 'params': {}, 'status': None}, id='item0')" + ] + }, + "execution_count": 7, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"human\",\n", + " \"what is the capital of {country}?\",\n", + " )\n", + " ]\n", + ")\n", + "\n", + "chain = prompt | llm\n", + "await chain.ainvoke({\"country\": \"France\"})" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Async Streaming" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Quantum computers use quantum bits (qubits) to process multiple possibilities simultaneously, exponentially faster than classical computers, enabling breakthroughs in fields like cryptography, optimization, and simulation." 
+ ] + } + ], + "source": [ + "prompt = ChatPromptTemplate.from_messages(\n", + " [\n", + " (\n", + " \"human\",\n", + " \"in less than {num_words} words explain me {topic} \",\n", + " )\n", + " ]\n", + ")\n", + "chain = prompt | llm\n", + "\n", + "async for chunk in chain.astream({\"num_words\": 30, \"topic\": \"quantum computers\"}):\n", + " print(chunk.content, end=\"\", flush=True)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## API reference\n", + "\n", + "For detailed documentation of all ChatSambaStudio features and configurations head to the API reference: https://api.python.langchain.com/en/latest/chat_models/langchain_community.chat_models.sambanova.ChatSambaStudio.html" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "langchain", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.19" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/community/langchain_community/chat_models/__init__.py b/libs/community/langchain_community/chat_models/__init__.py index 021f6f5c9d1..db507637528 100644 --- a/libs/community/langchain_community/chat_models/__init__.py +++ b/libs/community/langchain_community/chat_models/__init__.py @@ -149,6 +149,7 @@ if TYPE_CHECKING: ) from langchain_community.chat_models.sambanova import ( ChatSambaNovaCloud, + ChatSambaStudio, ) from langchain_community.chat_models.snowflake import ( ChatSnowflakeCortex, @@ -215,6 +216,7 @@ __all__ = [ "ChatPerplexity", "ChatPremAI", "ChatSambaNovaCloud", + "ChatSambaStudio", "ChatSparkLLM", "ChatSnowflakeCortex", "ChatTongyi", @@ -274,6 +276,7 @@ _module_lookup = { "ChatOpenAI": "langchain_community.chat_models.openai", "ChatPerplexity": "langchain_community.chat_models.perplexity", "ChatSambaNovaCloud": "langchain_community.chat_models.sambanova", + "ChatSambaStudio": "langchain_community.chat_models.sambanova", "ChatSnowflakeCortex": "langchain_community.chat_models.snowflake", "ChatSparkLLM": "langchain_community.chat_models.sparkllm", "ChatTongyi": "langchain_community.chat_models.tongyi", diff --git a/libs/community/langchain_community/chat_models/sambanova.py b/libs/community/langchain_community/chat_models/sambanova.py index deaee597879..cd95f4beefa 100644 --- a/libs/community/langchain_community/chat_models/sambanova.py +++ b/libs/community/langchain_community/chat_models/sambanova.py @@ -1,5 +1,5 @@ import json -from typing import Any, Dict, Iterator, List, Optional +from typing import Any, Dict, Iterator, List, Optional, Tuple import requests from langchain_core.callbacks import ( @@ -13,6 +13,7 @@ from langchain_core.messages import ( AIMessage, AIMessageChunk, BaseMessage, + BaseMessageChunk, ChatMessage, HumanMessage, SystemMessage, @@ -21,6 +22,46 @@ from langchain_core.messages import ( from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from pydantic import Field, SecretStr +from requests import Response + + +def _convert_message_to_dict(message: BaseMessage) -> Dict[str, Any]: + """ + convert a BaseMessage to a dictionary with Role / content + + Args: + message: BaseMessage + + Returns: + messages_dict: role / content dict + """ + if isinstance(message, ChatMessage): + message_dict = {"role": 
message.role, "content": message.content} + elif isinstance(message, SystemMessage): + message_dict = {"role": "system", "content": message.content} + elif isinstance(message, HumanMessage): + message_dict = {"role": "user", "content": message.content} + elif isinstance(message, AIMessage): + message_dict = {"role": "assistant", "content": message.content} + elif isinstance(message, ToolMessage): + message_dict = {"role": "tool", "content": message.content} + else: + raise TypeError(f"Got unknown type {message}") + return message_dict + + +def _create_message_dicts(messages: List[BaseMessage]) -> List[Dict[str, Any]]: + """ + Convert a list of BaseMessages to a list of dictionaries with Role / content + + Args: + messages: list of BaseMessages + + Returns: + messages_dicts: list of role / content dicts + """ + message_dicts = [_convert_message_to_dict(m) for m in messages] + return message_dicts class ChatSambaNovaCloud(BaseChatModel): @@ -28,7 +69,7 @@ class ChatSambaNovaCloud(BaseChatModel): SambaNova Cloud chat model. Setup: - To use, you should have the environment variables + To use, you should have the environment variables: ``SAMBANOVA_URL`` set with your SambaNova Cloud URL. ``SAMBANOVA_API_KEY`` set with your SambaNova Cloud API Key. http://cloud.sambanova.ai/ @@ -38,7 +79,6 @@ class ChatSambaNovaCloud(BaseChatModel): sambanova_url = SambaNova cloud endpoint URL, sambanova_api_key = set with your SambaNova cloud API key, model = model name, - streaming = set True for use streaming API max_tokens = max number of tokens to generate, temperature = model temperature, top_p = model top p, @@ -48,9 +88,9 @@ class ChatSambaNovaCloud(BaseChatModel): Key init args — completion params: model: str - The name of the model to use, e.g., llama3-8b. + The name of the model to use, e.g., Meta-Llama-3-70B-Instruct. streaming: bool - Whether to use streaming or not + Whether to use streaming handler when using non streaming methods max_tokens: int max tokens to generate temperature: float @@ -77,7 +117,6 @@ class ChatSambaNovaCloud(BaseChatModel): sambanova_url = SambaNova cloud endpoint URL, sambanova_api_key = set with your SambaNova cloud API key, model = model name, - streaming = set True for streaming max_tokens = max number of tokens to generate, temperature = model temperature, top_p = model top p, @@ -123,11 +162,11 @@ class ChatSambaNovaCloud(BaseChatModel): sambanova_api_key: SecretStr = Field(default="") """SambaNova Cloud api key""" - model: str = Field(default="llama3-8b") + model: str = Field(default="Meta-Llama-3.1-8B-Instruct") """The name of the model""" streaming: bool = Field(default=False) - """Whether to use streaming or not""" + """Whether to use streaming handler when using non streaming methods""" max_tokens: int = Field(default=1024) """max tokens to generate""" @@ -135,10 +174,10 @@ class ChatSambaNovaCloud(BaseChatModel): temperature: float = Field(default=0.7) """model temperature""" - top_p: float = Field(default=0.0) + top_p: Optional[float] = Field() """model top p""" - top_k: int = Field(default=1) + top_k: Optional[int] = Field() """model top k""" stream_options: dict = Field(default={"include_usage": True}) @@ -225,15 +264,15 @@ class ChatSambaNovaCloud(BaseChatModel): if response.status_code != 200: raise RuntimeError( f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{response.text}." 
+ f"{response.status_code}.", + f"{response.text}.", ) response_dict = response.json() if response_dict.get("error"): raise RuntimeError( f"Sambanova /complete call failed with status code " - f"{response.status_code}." - f"{response_dict}." + f"{response.status_code}.", + f"{response_dict}.", ) return response_dict @@ -247,7 +286,7 @@ class ChatSambaNovaCloud(BaseChatModel): messages_dicts: List of role / content dicts to use as input. stop: list of stop tokens - Returns: + Yields: An iterator of response dicts. """ try: @@ -289,82 +328,38 @@ class ChatSambaNovaCloud(BaseChatModel): ) for event in client.events(): - chunk = { - "event": event.event, - "data": event.data, - "status_code": response.status_code, - } - - if chunk["event"] == "error_event" or chunk["status_code"] != 200: + if event.event == "error_event": raise RuntimeError( f"Sambanova /complete call failed with status code " - f"{chunk['status_code']}." - f"{chunk}." + f"{response.status_code}." + f"{event.data}." ) try: # check if the response is a final event # in that case event data response is '[DONE]' - if chunk["data"] != "[DONE]": - if isinstance(chunk["data"], str): - data = json.loads(chunk["data"]) + if event.data != "[DONE]": + if isinstance(event.data, str): + data = json.loads(event.data) else: raise RuntimeError( f"Sambanova /complete call failed with status code " - f"{chunk['status_code']}." - f"{chunk}." + f"{response.status_code}." + f"{event.data}." ) if data.get("error"): raise RuntimeError( f"Sambanova /complete call failed with status code " - f"{chunk['status_code']}." - f"{chunk}." + f"{response.status_code}." + f"{event.data}." ) yield data - except Exception: - raise Exception( - f"Error getting content chunk raw streamed response: {chunk}" + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"data: {event.data}" ) - def _convert_message_to_dict(self, message: BaseMessage) -> Dict[str, Any]: - """ - convert a BaseMessage to a dictionary with Role / content - - Args: - message: BaseMessage - - Returns: - messages_dict: role / content dict - """ - if isinstance(message, ChatMessage): - message_dict = {"role": message.role, "content": message.content} - elif isinstance(message, SystemMessage): - message_dict = {"role": "system", "content": message.content} - elif isinstance(message, HumanMessage): - message_dict = {"role": "user", "content": message.content} - elif isinstance(message, AIMessage): - message_dict = {"role": "assistant", "content": message.content} - elif isinstance(message, ToolMessage): - message_dict = {"role": "tool", "content": message.content} - else: - raise TypeError(f"Got unknown type {message}") - return message_dict - - def _create_message_dicts( - self, messages: List[BaseMessage] - ) -> List[Dict[str, Any]]: - """ - convert a lit of BaseMessages to a list of dictionaries with Role / content - - Args: - messages: list of BaseMessages - - Returns: - messages_dicts: list of role / content dicts - """ - message_dicts = [self._convert_message_to_dict(m) for m in messages] - return message_dicts - def _generate( self, messages: List[BaseMessage], @@ -373,9 +368,7 @@ class ChatSambaNovaCloud(BaseChatModel): **kwargs: Any, ) -> ChatResult: """ - SambaNovaCloud chat model logic. - - Call SambaNovaCloud API. + Call SambaNovaCloud models. Args: messages: the prompt composed of a list of messages. 
@@ -386,6 +379,9 @@ class ChatSambaNovaCloud(BaseChatModel): it makes it much easier to parse the output of the model downstream and understand why generation stopped. run_manager: A run manager with callbacks for the LLM. + + Returns: + result: ChatResult with model generation """ if self.streaming: stream_iter = self._stream( @@ -393,7 +389,7 @@ class ChatSambaNovaCloud(BaseChatModel): ) if stream_iter: return generate_from_stream(stream_iter) - messages_dicts = self._create_message_dicts(messages) + messages_dicts = _create_message_dicts(messages) response = self._handle_request(messages_dicts, stop) message = AIMessage( content=response["choices"][0]["message"]["content"], @@ -430,8 +426,11 @@ class ChatSambaNovaCloud(BaseChatModel): it makes it much easier to parse the output of the model downstream and understand why generation stopped. run_manager: A run manager with callbacks for the LLM. + + Yields: + chunk: ChatGenerationChunk with model partial generation """ - messages_dicts = self._create_message_dicts(messages) + messages_dicts = _create_message_dicts(messages) finish_reason = None for partial_response in self._handle_streaming_request(messages_dicts, stop): if len(partial_response["choices"]) > 0: @@ -463,3 +462,751 @@ class ChatSambaNovaCloud(BaseChatModel): if run_manager: run_manager.on_llm_new_token(chunk.text, chunk=chunk) yield chunk + + +class ChatSambaStudio(BaseChatModel): + """ + SambaStudio chat model. + + Setup: + To use, you should have the environment variables: + ``SAMBASTUDIO_URL`` set with your SambaStudio deployed endpoint URL. + ``SAMBASTUDIO_API_KEY`` set with your SambaStudio deployed endpoint Key. + https://docs.sambanova.ai/sambastudio/latest/index.html + Example: + .. code-block:: python + ChatSambaStudio( + sambastudio_url = set with your SambaStudio deployed endpoint URL, + sambastudio_api_key = set with your SambaStudio deployed endpoint Key, + model = model or expert name (set for CoE endpoints), + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + do_sample = whether to do sampling, + process_prompt = whether to process prompt + (set for CoE generic v1 and v2 endpoints), + stream_options = include usage to get generation metrics, + special_tokens = start, start_role, end_role, end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt + set to false or for StandAlone v1 and v2 endpoints), + model_kwargs: Optional = Extra keyword arguments to pass to the model. + ) + + Key init args — completion params: + model: str + The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096 + (set for CoE endpoints). + streaming: bool + Whether to use streaming handler when using non streaming methods + max_tokens: int + max tokens to generate + temperature: float + model temperature + top_p: float + model top p + top_k: int + model top k + do_sample: bool + whether to do sampling + process_prompt: bool + whether to process prompt (set for CoE generic v1 and v2 endpoints) + stream_options: dict + stream options, include usage to get generation metrics + special_tokens: dict + start, start_role, end_role and end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt set to false + or for StandAlone v1 and v2 endpoints) default to llama3 special tokens + model_kwargs: dict + Extra keyword arguments to pass to the model. 
+ + Key init args — client params: + sambastudio_url: str + SambaStudio endpoint Url + sambastudio_api_key: str + SambaStudio endpoint api key + + Instantiate: + .. code-block:: python + + from langchain_community.chat_models import ChatSambaStudio + + chat = ChatSambaStudio( + sambastudio_url = set with your SambaStudio deployed endpoint URL, + sambastudio_api_key = set with your SambaStudio deployed endpoint Key, + model = model or expert name (set for CoE endpoints), + max_tokens = max number of tokens to generate, + temperature = model temperature, + top_p = model top p, + top_k = model top k, + do_sample = whether to do sampling, + process_prompt = whether to process prompt + (set for CoE generic v1 and v2 endpoints), + stream_options = include usage to get generation metrics, + special_tokens = start, start_role, end_role, and end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt + set to false or for StandAlone v1 and v2 endpoints), + model_kwargs: Optional = Extra keyword arguments to pass to the model. + ) + Invoke: + .. code-block:: python + messages = [ + SystemMessage(content="you are an AI assistant."), + HumanMessage(content="tell me a joke."), + ] + response = chat.invoke(messages) + + Stream: + .. code-block:: python + + for chunk in chat.stream(messages): + print(chunk.content, end="", flush=True) + + Async: + .. code-block:: python + + response = await chat.ainvoke(messages) + + Token usage: + .. code-block:: python + response = chat.invoke(messages) + print(response.response_metadata["usage"]["prompt_tokens"]) + print(response.response_metadata["usage"]["total_tokens"]) + + Response metadata: + .. code-block:: python + + response = chat.invoke(messages) + print(response.response_metadata) + """ + + sambastudio_url: str = Field(default="") + """SambaStudio Url""" + + sambastudio_api_key: SecretStr = Field(default="") + """SambaStudio api key""" + + base_url: str = Field(default="", exclude=True) + """SambaStudio non streaming Url""" + + streaming_url: str = Field(default="", exclude=True) + """SambaStudio streaming Url""" + + model: Optional[str] = Field() + """The name of the model or expert to use (for CoE endpoints)""" + + streaming: bool = Field(default=False) + """Whether to use streaming handler when using non streaming methods""" + + max_tokens: int = Field(default=1024) + """max tokens to generate""" + + temperature: Optional[float] = Field(default=0.7) + """model temperature""" + + top_p: Optional[float] = Field() + """model top p""" + + top_k: Optional[int] = Field() + """model top k""" + + do_sample: Optional[bool] = Field() + """whether to do sampling""" + + process_prompt: Optional[bool] = Field() + """whether to process prompt (for CoE generic v1 and v2 endpoints)""" + + stream_options: dict = Field(default={"include_usage": True}) + """stream options, include usage to get generation metrics""" + + special_tokens: dict = Field( + default={ + "start": "<|begin_of_text|>", + "start_role": "<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>", + "end_role": "<|eot_id|>", + "end": "<|start_header_id|>assistant<|end_header_id|>\n", + } + ) + """start, start_role, end_role and end special tokens + (set for CoE generic v1 and v2 endpoints when process prompt set to false + or for StandAlone v1 and v2 endpoints) + default to llama3 special tokens""" + + model_kwargs: Optional[Dict[str, Any]] = None + """Keyword arguments to pass to the model.""" + + class Config: + populate_by_name = True + + @classmethod + def 
is_lc_serializable(cls) -> bool: + """Return whether this model can be serialized by Langchain.""" + return False + + @property + def lc_secrets(self) -> Dict[str, str]: + return { + "sambastudio_url": "sambastudio_url", + "sambastudio_api_key": "sambastudio_api_key", + } + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Return a dictionary of identifying parameters. + + This information is used by the LangChain callback system, which + is used for tracing purposes make it possible to monitor LLMs. + """ + return { + "model": self.model, + "streaming": self.streaming, + "max_tokens": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + "process_prompt": self.process_prompt, + "stream_options": self.stream_options, + "special_tokens": self.special_tokens, + "model_kwargs": self.model_kwargs, + } + + @property + def _llm_type(self) -> str: + """Get the type of language model used by this chat model.""" + return "sambastudio-chatmodel" + + def __init__(self, **kwargs: Any) -> None: + """init and validate environment variables""" + kwargs["sambastudio_url"] = get_from_dict_or_env( + kwargs, "sambastudio_url", "SAMBASTUDIO_URL" + ) + + kwargs["sambastudio_api_key"] = convert_to_secret_str( + get_from_dict_or_env(kwargs, "sambastudio_api_key", "SAMBASTUDIO_API_KEY") + ) + kwargs["base_url"], kwargs["streaming_url"] = self._get_sambastudio_urls( + kwargs["sambastudio_url"] + ) + super().__init__(**kwargs) + + def _get_role(self, message: BaseMessage) -> str: + """ + Get the role of LangChain BaseMessage + + Args: + message: LangChain BaseMessage + + Returns: + str: Role of the LangChain BaseMessage + """ + if isinstance(message, ChatMessage): + role = message.role + elif isinstance(message, SystemMessage): + role = "system" + elif isinstance(message, HumanMessage): + role = "user" + elif isinstance(message, AIMessage): + role = "assistant" + elif isinstance(message, ToolMessage): + role = "tool" + else: + raise TypeError(f"Got unknown type {message}") + return role + + def _messages_to_string(self, messages: List[BaseMessage]) -> str: + """ + Convert a list of BaseMessages to a: + - dumped json string with Role / content dict structure + when process_prompt is true, + - string with special tokens if process_prompt is false + for generic V1 and V2 endpoints + + Args: + messages: list of BaseMessages + + Returns: + str: string to send as model input depending on process_prompt param + """ + if self.process_prompt: + messages_dict: Dict[str, Any] = { + "conversation_id": "sambaverse-conversation-id", + "messages": [], + } + for message in messages: + messages_dict["messages"].append( + { + "message_id": message.id, + "role": self._get_role(message), + "content": message.content, + } + ) + messages_string = json.dumps(messages_dict) + else: + messages_string = self.special_tokens["start"] + for message in messages: + messages_string += self.special_tokens["start_role"].format( + role=self._get_role(message) + ) + messages_string += f" {message.content} " + messages_string += self.special_tokens["end_role"] + messages_string += self.special_tokens["end"] + + return messages_string + + def _get_sambastudio_urls(self, url: str) -> Tuple[str, str]: + """ + Get streaming and non streaming URLs from the given URL + + Args: + url: string with sambastudio base or streaming endpoint url + + Returns: + base_url: string with url to do non streaming calls + streaming_url: string with url to do streaming calls + """ 
+ if "openai" in url: + base_url = url + stream_url = url + else: + if "stream" in url: + base_url = url.replace("stream/", "") + stream_url = url + else: + base_url = url + if "generic" in url: + stream_url = "generic/stream".join(url.split("generic")) + else: + raise ValueError("Unsupported URL") + return base_url, stream_url + + def _handle_request( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + streaming: Optional[bool] = False, + ) -> Response: + """ + Performs a post request to the LLM API. + + Args: + messages_dicts: List of role / content dicts to use as input. + stop: list of stop tokens + streaming: wether to do a streaming call + + Returns: + A request Response object + """ + + # create request payload for openai compatible API + if "openai" in self.sambastudio_url: + messages_dicts = _create_message_dicts(messages) + data = { + "messages": messages_dicts, + "max_tokens": self.max_tokens, + "stop": stop, + "model": self.model, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "stream": streaming, + "stream_options": self.stream_options, + } + data = {key: value for key, value in data.items() if value is not None} + headers = { + "Authorization": f"Bearer " + f"{self.sambastudio_api_key.get_secret_value()}", + "Content-Type": "application/json", + } + + # create request payload for generic v1 API + elif "api/v2/predict/generic" in self.sambastudio_url: + items = [{"id": "item0", "value": self._messages_to_string(messages)}] + params: Dict[str, Any] = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = {key: value for key, value in params.items() if value is not None} + data = {"items": items, "params": params} + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + # create request payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + params = { + "select_expert": self.model, + "process_prompt": self.process_prompt, + "max_tokens_to_generate": self.max_tokens, + "temperature": self.temperature, + "top_p": self.top_p, + "top_k": self.top_k, + "do_sample": self.do_sample, + } + if self.model_kwargs is not None: + params = {**params, **self.model_kwargs} + params = { + key: {"type": type(value).__name__, "value": str(value)} + for key, value in params.items() + if value is not None + } + if streaming: + data = { + "instance": self._messages_to_string(messages), + "params": params, + } + else: + data = { + "instances": [self._messages_to_string(messages)], + "params": params, + } + headers = {"key": self.sambastudio_api_key.get_secret_value()} + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + http_session = requests.Session() + if streaming: + response = http_session.post( + self.streaming_url, headers=headers, json=data, stream=True + ) + else: + response = http_session.post( + self.base_url, headers=headers, json=data, stream=False + ) + if response.status_code != 200: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{response.text}." 
+ ) + return response + + def _process_response(self, response: Response) -> AIMessage: + """ + Process a non streaming response from the api + + Args: + response: A request Response object + + Returns + generation: an AIMessage with model generation + """ + + # Extract json payload form response + try: + response_dict = response.json() + except Exception as e: + raise RuntimeError( + f"Sambanova /complete call failed couldn't get JSON response {e}" + f"response: {response.text}" + ) + + # process response payload for openai compatible API + if "openai" in self.sambastudio_url: + content = response_dict["choices"][0]["message"]["content"] + id = response_dict["id"] + response_metadata = { + "finish_reason": response_dict["choices"][0]["finish_reason"], + "usage": response_dict.get("usage"), + "model_name": response_dict["model"], + "system_fingerprint": response_dict["system_fingerprint"], + "created": response_dict["created"], + } + + # process response payload for generic v2 API + elif "api/v2/predict/generic" in self.sambastudio_url: + content = response_dict["items"][0]["value"]["completion"] + id = response_dict["items"][0]["id"] + response_metadata = response_dict["items"][0] + + # process response payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + content = response_dict["predictions"][0]["completion"] + id = None + response_metadata = response_dict + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + return AIMessage( + content=content, + additional_kwargs={}, + response_metadata=response_metadata, + id=id, + ) + + def _process_stream_response( + self, response: Response + ) -> Iterator[BaseMessageChunk]: + """ + Process a streaming response from the api + + Args: + response: An iterable request Response object + + Yields: + generation: an AIMessageChunk with model partial generation + """ + + try: + import sseclient + except ImportError: + raise ImportError( + "could not import sseclient library" + "Please install it with `pip install sseclient-py`." + ) + + # process response payload for openai compatible API + if "openai" in self.sambastudio_url: + finish_reason = "" + client = sseclient.SSEClient(response) + for event in client.events(): + if event.event == "error_event": + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + try: + # check if the response is not a final event ("[DONE]") + if event.data != "[DONE]": + if isinstance(event.data, str): + data = json.loads(event.data) + else: + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." + ) + if data.get("error"): + raise RuntimeError( + f"Sambanova /complete call failed with status code " + f"{response.status_code}." + f"{event.data}." 
+ ) + if len(data["choices"]) > 0: + finish_reason = data["choices"][0].get("finish_reason") + content = data["choices"][0]["delta"]["content"] + id = data["id"] + metadata = {} + else: + content = "" + id = data["id"] + metadata = { + "finish_reason": finish_reason, + "usage": data.get("usage"), + "model_name": data["model"], + "system_fingerprint": data["system_fingerprint"], + "created": data["created"], + } + yield AIMessageChunk( + content=content, + id=id, + response_metadata=metadata, + additional_kwargs={}, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"data: {event.data}" + ) + + # process response payload for generic v2 API + elif "api/v2/predict/generic" in self.sambastudio_url: + for line in response.iter_lines(): + try: + data = json.loads(line) + content = data["result"]["items"][0]["value"]["stream_token"] + id = data["result"]["items"][0]["id"] + if data["result"]["items"][0]["value"]["is_last_response"]: + metadata = { + "finish_reason": data["result"]["items"][0]["value"].get( + "stop_reason" + ), + "prompt": data["result"]["items"][0]["value"].get("prompt"), + "usage": { + "prompt_tokens_count": data["result"]["items"][0][ + "value" + ].get("prompt_tokens_count"), + "completion_tokens_count": data["result"]["items"][0][ + "value" + ].get("completion_tokens_count"), + "total_tokens_count": data["result"]["items"][0][ + "value" + ].get("total_tokens_count"), + "start_time": data["result"]["items"][0]["value"].get( + "start_time" + ), + "end_time": data["result"]["items"][0]["value"].get( + "end_time" + ), + "model_execution_time": data["result"]["items"][0][ + "value" + ].get("model_execution_time"), + "time_to_first_token": data["result"]["items"][0][ + "value" + ].get("time_to_first_token"), + "throughput_after_first_token": data["result"]["items"][ + 0 + ]["value"].get("throughput_after_first_token"), + "batch_size_used": data["result"]["items"][0][ + "value" + ].get("batch_size_used"), + }, + } + else: + metadata = {} + yield AIMessageChunk( + content=content, + id=id, + response_metadata=metadata, + additional_kwargs={}, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"line: {line}" + ) + + # process response payload for generic v1 API + elif "api/predict/generic" in self.sambastudio_url: + for line in response.iter_lines(): + try: + data = json.loads(line) + content = data["result"]["responses"][0]["stream_token"] + id = None + if data["result"]["responses"][0]["is_last_response"]: + metadata = { + "finish_reason": data["result"]["responses"][0].get( + "stop_reason" + ), + "prompt": data["result"]["responses"][0].get("prompt"), + "usage": { + "prompt_tokens_count": data["result"]["responses"][ + 0 + ].get("prompt_tokens_count"), + "completion_tokens_count": data["result"]["responses"][ + 0 + ].get("completion_tokens_count"), + "total_tokens_count": data["result"]["responses"][ + 0 + ].get("total_tokens_count"), + "start_time": data["result"]["responses"][0].get( + "start_time" + ), + "end_time": data["result"]["responses"][0].get( + "end_time" + ), + "model_execution_time": data["result"]["responses"][ + 0 + ].get("model_execution_time"), + "time_to_first_token": data["result"]["responses"][ + 0 + ].get("time_to_first_token"), + "throughput_after_first_token": data["result"][ + "responses" + ][0].get("throughput_after_first_token"), + "batch_size_used": data["result"]["responses"][0].get( + "batch_size_used" + ), + }, + } + else: + 
metadata = {} + yield AIMessageChunk( + content=content, + id=id, + response_metadata=metadata, + additional_kwargs={}, + ) + + except Exception as e: + raise RuntimeError( + f"Error getting content chunk raw streamed response: {e}" + f"line: {line}" + ) + + else: + raise ValueError( + f"Unsupported URL{self.sambastudio_url}" + "only openai, generic v1 and generic v2 APIs are supported" + ) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> ChatResult: + """ + Call SambaStudio models. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. + + Returns: + result: ChatResult with model generation + """ + if self.streaming: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + if stream_iter: + return generate_from_stream(stream_iter) + response = self._handle_request(messages, stop, streaming=False) + message = self._process_response(response) + generation = ChatGeneration(message=message) + return ChatResult(generations=[generation]) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + """ + Stream the output of the SambaStudio model. + + Args: + messages: the prompt composed of a list of messages. + stop: a list of strings on which the model should stop generating. + If generation stops due to a stop token, the stop token itself + SHOULD BE INCLUDED as part of the output. This is not enforced + across models right now, but it's a good practice to follow since + it makes it much easier to parse the output of the model + downstream and understand why generation stopped. + run_manager: A run manager with callbacks for the LLM. 
+ + Yields: + chunk: ChatGenerationChunk with model partial generation + """ + response = self._handle_request(messages, stop, streaming=True) + for ai_message_chunk in self._process_stream_response(response): + chunk = ChatGenerationChunk(message=ai_message_chunk) + if run_manager: + run_manager.on_llm_new_token(chunk.text, chunk=chunk) + yield chunk diff --git a/libs/community/tests/integration_tests/chat_models/test_sambanova.py b/libs/community/tests/integration_tests/chat_models/test_sambanova.py index 965a8156f2f..2683b5dc399 100644 --- a/libs/community/tests/integration_tests/chat_models/test_sambanova.py +++ b/libs/community/tests/integration_tests/chat_models/test_sambanova.py @@ -1,6 +1,9 @@ from langchain_core.messages import AIMessage, HumanMessage -from langchain_community.chat_models.sambanova import ChatSambaNovaCloud +from langchain_community.chat_models.sambanova import ( + ChatSambaNovaCloud, + ChatSambaStudio, +) def test_chat_sambanova_cloud() -> None: @@ -9,3 +12,11 @@ def test_chat_sambanova_cloud() -> None: response = chat.invoke([message]) assert isinstance(response, AIMessage) assert isinstance(response.content, str) + + +def test_chat_sambastudio() -> None: + chat = ChatSambaStudio() + message = HumanMessage(content="Hello") + response = chat.invoke([message]) + assert isinstance(response, AIMessage) + assert isinstance(response.content, str) diff --git a/libs/community/tests/unit_tests/chat_models/test_imports.py b/libs/community/tests/unit_tests/chat_models/test_imports.py index d8399ed8315..6be5a41d9ec 100644 --- a/libs/community/tests/unit_tests/chat_models/test_imports.py +++ b/libs/community/tests/unit_tests/chat_models/test_imports.py @@ -35,6 +35,7 @@ EXPECTED_ALL = [ "ChatPerplexity", "ChatPremAI", "ChatSambaNovaCloud", + "ChatSambaStudio", "ChatSparkLLM", "ChatTongyi", "ChatVertexAI",
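# Usage sketch: a minimal, hypothetical example of the ChatSambaStudio class
# introduced above, targeting a CoE generic v2 endpoint with prompt processing
# disabled, so the prompt string is built from the llama3-style special tokens.
# The endpoint URL, API key, and model/expert name below are placeholders, not
# values taken from this change.
import os

from langchain_community.chat_models.sambanova import ChatSambaStudio
from langchain_core.messages import HumanMessage, SystemMessage

os.environ["SAMBASTUDIO_URL"] = (
    "https://your-sambastudio-host/api/v2/predict/generic/project-id/endpoint-id"
)
os.environ["SAMBASTUDIO_API_KEY"] = "your-api-key-here"

llm = ChatSambaStudio(
    model="Meta-Llama-3-70B-Instruct-4096",  # expert name, used for CoE endpoints
    max_tokens=512,
    temperature=0.7,
    top_p=0.01,
    top_k=1,
    do_sample=True,
    process_prompt=False,  # send one formatted prompt string instead of a messages JSON payload
    special_tokens={  # llama3-style defaults; override for other chat templates
        "start": "<|begin_of_text|>",
        "start_role": "<|begin_of_text|><|start_header_id|>{role}<|end_header_id|>",
        "end_role": "<|eot_id|>",
        "end": "<|start_header_id|>assistant<|end_header_id|>\n",
    },
)

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Tell me a joke about owls."),
]

# Non streaming call returns an AIMessage.
print(llm.invoke(messages).content)

# Streaming call yields AIMessageChunk objects.
for chunk in llm.stream(messages):
    print(chunk.content, end="", flush=True)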