community: SamabaStudio Tool Calling and Structured Output (#28025)

Description: Add tool calling and structured output support for
SambaStudio chat models, docs included

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
This commit is contained in:
Jorge Piedrahita Ortiz 2024-12-16 01:15:19 -05:00 committed by GitHub
parent c87f24d85d
commit 558b65ea32
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 684 additions and 57 deletions

View File

@ -34,7 +34,7 @@
"\n",
"| [Tool calling](/docs/how_to/tool_calling) | [Structured output](/docs/how_to/structured_output/) | JSON mode | [Image input](/docs/how_to/multimodal_inputs/) | Audio input | Video input | [Token-level streaming](/docs/how_to/chat_streaming/) | Native async | [Token usage](/docs/how_to/chat_token_usage_tracking/) | [Logprobs](/docs/how_to/logprobs/) |\n",
"| :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: | :---: |\n",
"| ❌ | ❌ | ❌ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"| ✅ | ✅ | ✅ | ❌ | ❌ | ❌ | ✅ | ✅ | ✅ | ❌ | \n",
"\n",
"## Setup\n",
"\n",
@ -119,20 +119,20 @@
},
{
"cell_type": "code",
"execution_count": 2,
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain_community.chat_models.sambanova import ChatSambaStudio\n",
"\n",
"llm = ChatSambaStudio(\n",
" model=\"Meta-Llama-3-70B-Instruct-4096\", # set if using a CoE endpoint\n",
" model=\"Meta-Llama-3-70B-Instruct-4096\", # set if using a Bundle endpoint\n",
" max_tokens=1024,\n",
" temperature=0.7,\n",
" top_k=1,\n",
" top_p=0.01,\n",
" do_sample=True,\n",
" process_prompt=\"True\", # set if using a CoE endpoint\n",
" process_prompt=\"True\", # set if using a Bundle endpoint\n",
")"
]
},
@ -349,6 +349,134 @@
" print(chunk.content, end=\"\", flush=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool calling"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from datetime import datetime\n",
"\n",
"from langchain_core.messages import HumanMessage, SystemMessage, ToolMessage\n",
"from langchain_core.tools import tool\n",
"\n",
"\n",
"@tool\n",
"def get_time(kind: str = \"both\") -> str:\n",
" \"\"\"Returns current date, current time or both.\n",
" Args:\n",
" kind: date, time or both\n",
" \"\"\"\n",
" if kind == \"date\":\n",
" date = datetime.now().strftime(\"%m/%d/%Y\")\n",
" return f\"Current date: {date}\"\n",
" elif kind == \"time\":\n",
" time = datetime.now().strftime(\"%H:%M:%S\")\n",
" return f\"Current time: {time}\"\n",
" else:\n",
" date = datetime.now().strftime(\"%m/%d/%Y\")\n",
" time = datetime.now().strftime(\"%H:%M:%S\")\n",
" return f\"Current date: {date}, Current time: {time}\"\n",
"\n",
"\n",
"tools = [get_time]\n",
"\n",
"\n",
"def invoke_tools(tool_calls, messages):\n",
" available_functions = {tool.name: tool for tool in tools}\n",
" for tool_call in tool_calls:\n",
" selected_tool = available_functions[tool_call[\"name\"]]\n",
" tool_output = selected_tool.invoke(tool_call[\"args\"])\n",
" print(f\"Tool output: {tool_output}\")\n",
" messages.append(ToolMessage(tool_output, tool_call_id=tool_call[\"id\"]))\n",
" return messages"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_with_tools = llm.bind_tools(tools=tools)\n",
"messages = [\n",
" HumanMessage(\n",
" content=\"I need to schedule a meeting for two weeks from today. Can you tell me the exact date of the meeting?\"\n",
" )\n",
"]"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Intermediate model response: [{'name': 'get_time', 'args': {'kind': 'date'}, 'id': 'call_4092d5dd21cd4eb494', 'type': 'tool_call'}]\n",
"Tool output: Current date: 11/07/2024\n",
"final response: The meeting will be exactly two weeks from today, which would be 25/07/2024.\n"
]
}
],
"source": [
"response = llm_with_tools.invoke(messages)\n",
"while len(response.tool_calls) > 0:\n",
" print(f\"Intermediate model response: {response.tool_calls}\")\n",
" messages.append(response)\n",
" messages = invoke_tools(response.tool_calls, messages)\n",
"response = llm_with_tools.invoke(messages)\n",
"\n",
"print(f\"final response: {response.content}\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Structured Outputs"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"Joke(setup='Why did the cat join a band?', punchline='Because it wanted to be the purr-cussionist!')"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"from pydantic import BaseModel, Field\n",
"\n",
"\n",
"class Joke(BaseModel):\n",
" \"\"\"Joke to tell user.\"\"\"\n",
"\n",
" setup: str = Field(description=\"The setup of the joke\")\n",
" punchline: str = Field(description=\"The punchline to the joke\")\n",
"\n",
"\n",
"structured_llm = llm.with_structured_output(Joke)\n",
"\n",
"structured_llm.invoke(\"Tell me a joke about cats\")"
]
},
{
"cell_type": "markdown",
"metadata": {},

View File

@ -958,25 +958,27 @@ class ChatSambaStudio(BaseChatModel):
Setup:
To use, you should have the environment variables:
``SAMBASTUDIO_URL`` set with your SambaStudio deployed endpoint URL.
``SAMBASTUDIO_API_KEY`` set with your SambaStudio deployed endpoint Key.
`SAMBASTUDIO_URL` set with your SambaStudio deployed endpoint URL.
`SAMBASTUDIO_API_KEY` set with your SambaStudio deployed endpoint Key.
https://docs.sambanova.ai/sambastudio/latest/index.html
Example:
.. code-block:: python
ChatSambaStudio(
sambastudio_url = set with your SambaStudio deployed endpoint URL,
sambastudio_api_key = set with your SambaStudio deployed endpoint Key.
model = model or expert name (set for CoE endpoints),
model = model or expert name (set for Bundle endpoints),
max_tokens = max number of tokens to generate,
temperature = model temperature,
top_p = model top p,
top_k = model top k,
do_sample = wether to do sample
process_prompt = wether to process prompt
(set for CoE generic v1 and v2 endpoints)
(set for Bundle generic v1 and v2 endpoints)
stream_options = include usage to get generation metrics
special_tokens = start, start_role, end_role, end special tokens
(set for CoE generic v1 and v2 endpoints when process prompt
(set for Bundle generic v1 and v2 endpoints when process prompt
set to false or for StandAlone v1 and v2 endpoints)
model_kwargs: Optional = Extra Key word arguments to pass to the model.
)
@ -984,7 +986,7 @@ class ChatSambaStudio(BaseChatModel):
Key init args completion params:
model: str
The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096
(set for CoE endpoints).
(set for Bundle endpoints).
streaming: bool
Whether to use streaming
max_tokens: inthandler when using non streaming methods
@ -998,12 +1000,12 @@ class ChatSambaStudio(BaseChatModel):
do_sample: bool
wether to do sample
process_prompt:
wether to process prompt (set for CoE generic v1 and v2 endpoints)
wether to process prompt (set for Bundle generic v1 and v2 endpoints)
stream_options: dict
stream options, include usage to get generation metrics
special_tokens: dict
start, start_role, end_role and end special tokens
(set for CoE generic v1 and v2 endpoints when process prompt set to false
(set for Bundle generic v1 and v2 endpoints when process prompt set to false
or for StandAlone v1 and v2 endpoints) default to llama3 special tokens
model_kwargs: dict
Extra Key word arguments to pass to the model.
@ -1022,22 +1024,24 @@ class ChatSambaStudio(BaseChatModel):
chat = ChatSambaStudio=(
sambastudio_url = set with your SambaStudio deployed endpoint URL,
sambastudio_api_key = set with your SambaStudio deployed endpoint Key.
model = model or expert name (set for CoE endpoints),
model = model or expert name (set for Bundle endpoints),
max_tokens = max number of tokens to generate,
temperature = model temperature,
top_p = model top p,
top_k = model top k,
do_sample = wether to do sample
process_prompt = wether to process prompt
(set for CoE generic v1 and v2 endpoints)
(set for Bundle generic v1 and v2 endpoints)
stream_options = include usage to get generation metrics
special_tokens = start, start_role, end_role, and special tokens
(set for CoE generic v1 and v2 endpoints when process prompt
(set for Bundle generic v1 and v2 endpoints when process prompt
set to false or for StandAlone v1 and v2 endpoints)
model_kwargs: Optional = Extra Key word arguments to pass to the model.
)
Invoke:
.. code-block:: python
messages = [
SystemMessage(content="your are an AI assistant."),
HumanMessage(content="tell me a joke."),
@ -1047,26 +1051,77 @@ class ChatSambaStudio(BaseChatModel):
Stream:
.. code-block:: python
for chunk in chat.stream(messages):
print(chunk.content, end="", flush=True)
for chunk in chat.stream(messages):
print(chunk.content, end="", flush=True)
Async:
.. code-block:: python
response = chat.ainvoke(messages)
await response
response = chat.ainvoke(messages)
await response
Tool calling:
.. code-block:: python
from pydantic import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(
...,
description="The city and state, e.g. Los Angeles, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Should I bring my umbrella today in LA?")
ai_msg.tool_calls
.. code-block:: python
[
{
'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': 'call_adf61180ea2b4d228a'
}
]
Structured output:
.. code-block:: python
from typing import Optional
from pydantic import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
setup: str = Field(description="The setup of the joke")
punchline: str = Field(description="The punchline to the joke")
structured_model = llm.with_structured_output(Joke)
structured_model.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(setup="Why did the cat join a band?",
punchline="Because it wanted to be the purr-cussionist!")
See `ChatSambaStudio.with_structured_output()` for more.
Token usage:
.. code-block:: python
response = chat.invoke(messages)
print(response.response_metadata["usage"]["prompt_tokens"]
print(response.response_metadata["usage"]["total_tokens"]
response = chat.invoke(messages)
print(response.response_metadata["usage"]["prompt_tokens"]
print(response.response_metadata["usage"]["total_tokens"]
Response metadata
.. code-block:: python
response = chat.invoke(messages)
print(response.response_metadata)
response = chat.invoke(messages)
print(response.response_metadata)
"""
sambastudio_url: str = Field(default="")
@ -1082,7 +1137,7 @@ class ChatSambaStudio(BaseChatModel):
"""SambaStudio streaming Url"""
model: Optional[str] = Field(default=None)
"""The name of the model or expert to use (for CoE endpoints)"""
"""The name of the model or expert to use (for Bundle endpoints)"""
streaming: bool = Field(default=False)
"""Whether to use streaming handler when using non streaming methods"""
@ -1103,7 +1158,7 @@ class ChatSambaStudio(BaseChatModel):
"""whether to do sampling"""
process_prompt: Optional[bool] = Field(default=True)
"""whether process prompt (for CoE generic v1 and v2 endpoints)"""
"""whether process prompt (for Bundle generic v1 and v2 endpoints)"""
stream_options: Dict[str, Any] = Field(default={"include_usage": True})
"""stream options, include usage to get generation metrics"""
@ -1117,13 +1172,16 @@ class ChatSambaStudio(BaseChatModel):
}
)
"""start, start_role, end_role and end special tokens
(set for CoE generic v1 and v2 endpoints when process prompt set to false
(set for Bundle generic v1 and v2 endpoints when process prompt set to false
or for StandAlone v1 and v2 endpoints)
default to llama3 special tokens"""
model_kwargs: Optional[Dict[str, Any]] = None
"""Key word arguments to pass to the model."""
additional_headers: Dict[str, Any] = Field(default={})
"""Additional headers to send in request"""
class Config:
populate_by_name = True
@ -1179,6 +1237,358 @@ class ChatSambaStudio(BaseChatModel):
)
super().__init__(**kwargs)
def bind_tools(
self,
tools: Sequence[Union[Dict[str, Any], Type[Any], Callable[..., Any], BaseTool]],
*,
tool_choice: Optional[Union[Dict[str, Any], bool, str]] = None,
parallel_tool_calls: Optional[bool] = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
"""Bind tool-like objects to this chat model
tool_choice: does not currently support "any", choice like
should be one of ["auto", "none", "required"]
"""
formatted_tools = [convert_to_openai_tool(tool) for tool in tools]
if tool_choice:
if isinstance(tool_choice, str):
# tool_choice is a tool/function name
if tool_choice not in ("auto", "none", "required"):
tool_choice = "auto"
elif isinstance(tool_choice, bool):
if tool_choice:
tool_choice = "required"
elif isinstance(tool_choice, dict):
raise ValueError(
"tool_choice must be one of ['auto', 'none', 'required']"
)
else:
raise ValueError(
f"Unrecognized tool_choice type. Expected str, bool"
f"Received: {tool_choice}"
)
else:
tool_choice = "auto"
kwargs["tool_choice"] = tool_choice
kwargs["parallel_tool_calls"] = parallel_tool_calls
return super().bind(tools=formatted_tools, **kwargs)
def with_structured_output(
self,
schema: Optional[Union[Dict[str, Any], Type[BaseModel]]] = None,
*,
method: Literal[
"function_calling", "json_mode", "json_schema"
] = "function_calling",
include_raw: bool = False,
**kwargs: Any,
) -> Runnable[LanguageModelInput, Union[Dict[str, Any], BaseModel]]:
"""Model wrapper that returns outputs formatted to match the given schema.
Args:
schema:
The output schema. Can be passed in as:
- an OpenAI function/tool schema,
- a JSON Schema,
- a TypedDict class,
- or a Pydantic class.
If `schema` is a Pydantic class then the model output will be a
Pydantic instance of that class, and the model-generated fields will be
validated by the Pydantic class. Otherwise the model output will be a
dict and will not be validated. See :meth:`langchain_core.utils.function_calling.convert_to_openai_tool`
for more on how to properly specify types and descriptions of
schema fields when specifying a Pydantic or TypedDict class.
method:
The method for steering model generation, either "function_calling"
"json_mode" or "json_schema".
If "function_calling" then the schema will be converted
to an OpenAI function and the returned model will make use of the
function-calling API. If "json_mode" or "json_schema" then OpenAI's
JSON mode will be used.
Note that if using "json_mode" or "json_schema" then you must include instructions
for formatting the output into the desired schema into the model call.
include_raw:
If False then only the parsed structured output is returned. If
an error occurs during model output parsing it will be raised. If True
then both the raw model response (a BaseMessage) and the parsed model
response will be returned. If an error occurs during output parsing it
will be caught and returned as well. The final output is always a dict
with keys "raw", "parsed", and "parsing_error".
Returns:
A Runnable that takes same inputs as a :class:`langchain_core.language_models.chat.BaseChatModel`.
If `include_raw` is False and `schema` is a Pydantic class, Runnable outputs
an instance of `schema` (i.e., a Pydantic object).
Otherwise, if `include_raw` is False then Runnable outputs a dict.
If `include_raw` is True, then Runnable outputs a dict with keys:
- `"raw"`: BaseMessage
- `"parsed"`: None if there was a parsing error, otherwise the type depends on the `schema` as described above.
- `"parsing_error"`: Optional[BaseException]
Example: schema=Pydantic class, method="function_calling", include_raw=False:
.. code-block:: python
from typing import Optional
from langchain_community.chat_models import ChatSambaStudio
from pydantic import BaseModel, Field
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str = Field(
description="A justification for the answer."
)
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> AnswerWithJustification(
# answer='They weigh the same',
# justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same.'
# )
Example: schema=Pydantic class, method="function_calling", include_raw=True:
.. code-block:: python
from langchain_community.chat_models import ChatSambaStudio
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'function': {'arguments': '{"answer": "They weigh the same.", "justification": "A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount."}', 'name': 'AnswerWithJustification'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'function'}]}, response_metadata={'finish_reason': 'tool_calls', 'usage': {'acceptance_rate': 5, 'completion_tokens': 53, 'completion_tokens_after_first_per_sec': 343.7964936837758, 'completion_tokens_after_first_per_sec_first_ten': 439.1205661878638, 'completion_tokens_per_sec': 162.8511306784833, 'end_time': 1731527851.0698032, 'is_last_response': True, 'prompt_tokens': 213, 'start_time': 1731527850.7137961, 'time_to_first_token': 0.20475482940673828, 'total_latency': 0.32545061111450196, 'total_tokens': 266, 'total_tokens_per_sec': 817.3283162354066}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731527850}, id='95667eaf-447f-4b53-bb6e-b6e1094ded88', tool_calls=[{'name': 'AnswerWithJustification', 'args': {'answer': 'They weigh the same.', 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'}, 'id': 'call_17a431fc6a4240e1bd', 'type': 'tool_call'}]),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'),
# 'parsing_error': None
# }
Example: schema=TypedDict class, method="function_calling", include_raw=False:
.. code-block:: python
# IMPORTANT: If you are using Python <=3.8, you need to import Annotated
# from typing_extensions, not from typing.
from typing_extensions import Annotated, TypedDict
from langchain_community.chat_models import ChatSambaStudio
class AnswerWithJustification(TypedDict):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: Annotated[
Optional[str], None, "A justification for the answer."
]
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'
# }
Example: schema=OpenAI function schema, method="function_calling", include_raw=False:
.. code-block:: python
from langchain_community.chat_models import ChatSambaStudio
oai_schema = {
'name': 'AnswerWithJustification',
'description': 'An answer to the user question along with justification for the answer.',
'parameters': {
'type': 'object',
'properties': {
'answer': {'type': 'string'},
'justification': {'description': 'A justification for the answer.', 'type': 'string'}
},
'required': ['answer']
}
}
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(oai_schema)
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'A pound is a unit of weight or mass, so one pound of bricks and one pound of feathers both weigh the same amount.'
# }
Example: schema=Pydantic class, method="json_mode", include_raw=True:
.. code-block::
from langchain_community.chat_models import ChatSambaStudio
from pydantic import BaseModel
class AnswerWithJustification(BaseModel):
answer: str
justification: str
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(
AnswerWithJustification,
method="json_mode",
include_raw=True
)
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'),
# 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'),
# 'parsing_error': None
# }
Example: schema=None, method="json_mode", include_raw=True:
.. code-block::
from langchain_community.chat_models import ChatSambaStudio
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(method="json_mode", include_raw=True)
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 4.722222222222222, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 357.1315485254867, 'completion_tokens_after_first_per_sec_first_ten': 416.83279609305305, 'completion_tokens_per_sec': 240.92819585198137, 'end_time': 1731528164.8474727, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528164.4906917, 'time_to_first_token': 0.13837409019470215, 'total_latency': 0.3278985247892492, 'total_tokens': 149, 'total_tokens_per_sec': 454.4088757208256}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528164}, id='15261eaf-8a25-42ef-8ed5-f63d8bf5b1b0'),
# 'parsed': {
# 'answer': 'They are the same weight',
# 'justification': 'A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'},
# },
# 'parsing_error': None
# }
Example: schema=None, method="json_schema", include_raw=True:
.. code-block::
from langchain_community.chat_models import ChatSambaStudio
class AnswerWithJustification(BaseModel):
answer: str
justification: str
llm = ChatSambaStudio(model="Meta-Llama-3.1-70B-Instruct", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, method="json_schema", include_raw=True)
structured_llm.invoke(
"Answer the following question. "
"Make sure to return a JSON blob with keys 'answer' and 'justification'.\n\n"
"What's heavier a pound of bricks or a pound of feathers?"
)
# -> {
# 'raw': AIMessage(content='{\n "answer": "They are the same weight",\n "justification": "A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities."\n}', additional_kwargs={}, response_metadata={'finish_reason': 'stop', 'usage': {'acceptance_rate': 5.3125, 'completion_tokens': 79, 'completion_tokens_after_first_per_sec': 292.65701089829776, 'completion_tokens_after_first_per_sec_first_ten': 346.43324678555325, 'completion_tokens_per_sec': 200.012158915008, 'end_time': 1731528071.1708555, 'is_last_response': True, 'prompt_tokens': 70, 'start_time': 1731528070.737394, 'time_to_first_token': 0.16693782806396484, 'total_latency': 0.3949759876026827, 'total_tokens': 149, 'total_tokens_per_sec': 377.2381225105847}, 'model_name': 'Meta-Llama-3.1-70B-Instruct', 'system_fingerprint': 'fastcoe', 'created': 1731528070}, id='83208297-3eb9-4021-a856-ca78a15758df'),
# 'parsed': AnswerWithJustification(answer='They are the same weight', justification='A pound is a unit of weight or mass, so a pound of bricks and a pound of feathers both weigh the same amount, one pound. The difference is in their density and volume. A pound of feathers would take up more space than a pound of bricks due to the difference in their densities.'),
# 'parsing_error': None
# }
""" # noqa: E501
if kwargs:
raise ValueError(f"Received unsupported arguments {kwargs}")
is_pydantic_schema = _is_pydantic_class(schema)
if method == "function_calling":
if schema is None:
raise ValueError(
"schema must be specified when method is 'function_calling'. "
"Received None."
)
tool_name = convert_to_openai_tool(schema)["function"]["name"]
llm = self.bind_tools([schema], tool_choice=tool_name)
if is_pydantic_schema:
output_parser: OutputParserLike[Any] = PydanticToolsParser(
tools=[schema], # type: ignore[list-item]
first_tool_only=True,
)
else:
output_parser = JsonOutputKeyToolsParser(
key_name=tool_name, first_tool_only=True
)
elif method == "json_mode":
llm = self
# TODO bind response format when json mode available by API
# llm = self.bind(response_format={"type": "json_object"})
if is_pydantic_schema:
schema = cast(Type[BaseModel], schema)
output_parser = PydanticOutputParser(pydantic_object=schema)
else:
output_parser = JsonOutputParser()
elif method == "json_schema":
if schema is None:
raise ValueError(
"schema must be specified when method is not 'json_mode'. "
"Received None."
)
llm = self
# TODO bind response format when json schema available by API,
# update example
# llm = self.bind(
# response_format={"type": "json_object", "json_schema": schema}
# )
if is_pydantic_schema:
schema = cast(Type[BaseModel], schema)
output_parser = PydanticOutputParser(pydantic_object=schema)
else:
output_parser = JsonOutputParser()
else:
raise ValueError(
f"Unrecognized method argument. Expected one of 'function_calling' or "
f"'json_mode'. Received: '{method}'"
)
if include_raw:
parser_assign = RunnablePassthrough.assign(
parsed=itemgetter("raw") | output_parser, parsing_error=lambda _: None
)
parser_none = RunnablePassthrough.assign(parsed=lambda _: None)
parser_with_fallback = parser_assign.with_fallbacks(
[parser_none], exception_key="parsing_error"
)
return RunnableMap(raw=llm) | parser_with_fallback
else:
return llm | output_parser
def _get_role(self, message: BaseMessage) -> str:
"""
Get the role of LangChain BaseMessage
@ -1189,9 +1599,7 @@ class ChatSambaStudio(BaseChatModel):
Returns:
str: Role of the LangChain BaseMessage
"""
if isinstance(message, ChatMessage):
role = message.role
elif isinstance(message, SystemMessage):
if isinstance(message, SystemMessage):
role = "system"
elif isinstance(message, HumanMessage):
role = "user"
@ -1199,11 +1607,13 @@ class ChatSambaStudio(BaseChatModel):
role = "assistant"
elif isinstance(message, ToolMessage):
role = "tool"
elif isinstance(message, ChatMessage):
role = message.role
else:
raise TypeError(f"Got unknown type {message}")
return role
def _messages_to_string(self, messages: List[BaseMessage]) -> str:
def _messages_to_string(self, messages: List[BaseMessage], **kwargs: Any) -> str:
"""
Convert a list of BaseMessages to a:
- dumped json string with Role / content dict structure
@ -1221,17 +1631,48 @@ class ChatSambaStudio(BaseChatModel):
messages_dict: Dict[str, Any] = {
"conversation_id": "sambaverse-conversation-id",
"messages": [],
**kwargs,
}
for message in messages:
messages_dict["messages"].append(
{
if isinstance(message, AIMessage):
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
}
)
if "tool_calls" in message.additional_kwargs:
message_dict["tool_calls"] = message.additional_kwargs[
"tool_calls"
]
if message_dict["content"] == "":
message_dict["content"] = None
elif isinstance(message, ToolMessage):
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
"tool_call_id": message.tool_call_id,
}
else:
message_dict = {
"message_id": message.id,
"role": self._get_role(message),
"content": message.content,
}
messages_dict["messages"].append(message_dict)
messages_string = json.dumps(messages_dict)
else:
if "tools" in kwargs.keys():
raise NotImplementedError(
"tool calling not supported in API Generic V2 "
"without process_prompt, switch to OpenAI compatible API "
"or Generic V2 API with process_prompt=True"
)
messages_string = self.special_tokens["start"]
for message in messages:
messages_string += self.special_tokens["start_role"].format(
@ -1254,7 +1695,7 @@ class ChatSambaStudio(BaseChatModel):
base_url: string with url to do non streaming calls
streaming_url: string with url to do streaming calls
"""
if "openai" in url:
if "chat/completions" in url:
base_url = url
stream_url = url
else:
@ -1274,6 +1715,7 @@ class ChatSambaStudio(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
streaming: Optional[bool] = False,
**kwargs: Any,
) -> Response:
"""
Performs a post request to the LLM API.
@ -1288,7 +1730,7 @@ class ChatSambaStudio(BaseChatModel):
"""
# create request payload for openai compatible API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
messages_dicts = _create_message_dicts(messages)
data = {
"messages": messages_dicts,
@ -1300,17 +1742,21 @@ class ChatSambaStudio(BaseChatModel):
"top_k": self.top_k,
"stream": streaming,
"stream_options": self.stream_options,
**kwargs,
}
data = {key: value for key, value in data.items() if value is not None}
headers = {
"Authorization": f"Bearer "
f"{self.sambastudio_api_key.get_secret_value()}",
"Content-Type": "application/json",
**self.additional_headers,
}
# create request payload for generic v1 API
# create request payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
items = [{"id": "item0", "value": self._messages_to_string(messages)}]
items = [
{"id": "item0", "value": self._messages_to_string(messages, **kwargs)}
]
params: Dict[str, Any] = {
"select_expert": self.model,
"process_prompt": self.process_prompt,
@ -1324,10 +1770,18 @@ class ChatSambaStudio(BaseChatModel):
params = {**params, **self.model_kwargs}
params = {key: value for key, value in params.items() if value is not None}
data = {"items": items, "params": params}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
headers = {
"key": self.sambastudio_api_key.get_secret_value(),
**self.additional_headers,
}
# create request payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
if "tools" in kwargs.keys():
raise NotImplementedError(
"tool calling not supported in API Generic V1, "
"switch to OpenAI compatible API or Generic V2 API"
)
params = {
"select_expert": self.model,
"process_prompt": self.process_prompt,
@ -1336,6 +1790,7 @@ class ChatSambaStudio(BaseChatModel):
"top_p": self.top_p,
"top_k": self.top_k,
"do_sample": self.do_sample,
**kwargs,
}
if self.model_kwargs is not None:
params = {**params, **self.model_kwargs}
@ -1354,7 +1809,10 @@ class ChatSambaStudio(BaseChatModel):
"instances": [self._messages_to_string(messages)],
"params": params,
}
headers = {"key": self.sambastudio_api_key.get_secret_value()}
headers = {
"key": self.sambastudio_api_key.get_secret_value(),
**self.additional_headers,
}
else:
raise ValueError(
@ -1399,9 +1857,15 @@ class ChatSambaStudio(BaseChatModel):
f"response: {response.text}"
)
additional_kwargs: Dict[str, Any] = {}
tool_calls = []
invalid_tool_calls = []
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
content = response_dict["choices"][0]["message"]["content"]
if "chat/completions" in self.sambastudio_url:
content = response_dict["choices"][0]["message"].get("content", "")
if content is None:
content = ""
id = response_dict["id"]
response_metadata = {
"finish_reason": response_dict["choices"][0]["finish_reason"],
@ -1410,12 +1874,44 @@ class ChatSambaStudio(BaseChatModel):
"system_fingerprint": response_dict["system_fingerprint"],
"created": response_dict["created"],
}
raw_tool_calls = response_dict["choices"][0]["message"].get("tool_calls")
if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
if isinstance(raw_tool_call["function"]["arguments"], dict):
raw_tool_call["function"]["arguments"] = json.dumps(
raw_tool_call["function"].get("arguments", {})
)
try:
tool_calls.append(
parse_tool_call(raw_tool_call, return_id=True)
)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
# process response payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
content = response_dict["items"][0]["value"]["completion"]
id = response_dict["items"][0]["id"]
response_metadata = response_dict["items"][0]
raw_tool_calls = response_dict["items"][0]["value"].get("tool_calls")
if raw_tool_calls:
additional_kwargs["tool_calls"] = raw_tool_calls
for raw_tool_call in raw_tool_calls:
if isinstance(raw_tool_call["function"]["arguments"], dict):
raw_tool_call["function"]["arguments"] = json.dumps(
raw_tool_call["function"].get("arguments", {})
)
try:
tool_calls.append(
parse_tool_call(raw_tool_call, return_id=True)
)
except Exception as e:
invalid_tool_calls.append(
make_invalid_tool_call(raw_tool_call, str(e))
)
# process response payload for generic v1 API
elif "api/predict/generic" in self.sambastudio_url:
@ -1431,7 +1927,9 @@ class ChatSambaStudio(BaseChatModel):
return AIMessage(
content=content,
additional_kwargs={},
additional_kwargs=additional_kwargs,
tool_calls=tool_calls,
invalid_tool_calls=invalid_tool_calls,
response_metadata=response_metadata,
id=id,
)
@ -1458,7 +1956,7 @@ class ChatSambaStudio(BaseChatModel):
)
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
finish_reason = ""
client = sseclient.SSEClient(response)
for event in client.events():
@ -1674,7 +2172,7 @@ class ChatSambaStudio(BaseChatModel):
)
if stream_iter:
return generate_from_stream(stream_iter)
response = self._handle_request(messages, stop, streaming=False)
response = self._handle_request(messages, stop, streaming=False, **kwargs)
message = self._process_response(response)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@ -1702,7 +2200,7 @@ class ChatSambaStudio(BaseChatModel):
Yields:
chunk: ChatGenerationChunk with model partial generation
"""
response = self._handle_request(messages, stop, streaming=True)
response = self._handle_request(messages, stop, streaming=True, **kwargs)
for ai_message_chunk in self._process_stream_response(response):
chunk = ChatGenerationChunk(message=ai_message_chunk)
if run_manager:

View File

@ -27,20 +27,20 @@ class SambaStudio(LLM):
sambastudio_url="your-SambaStudio-environment-URL",
sambastudio_api_key="your-SambaStudio-API-key,
model_kwargs={
"model" : model or expert name (set for CoE endpoints),
"model" : model or expert name (set for Bundle endpoints),
"max_tokens" : max number of tokens to generate,
"temperature" : model temperature,
"top_p" : model top p,
"top_k" : model top k,
"do_sample" : wether to do sample
"process_prompt": wether to process prompt
(set for CoE generic v1 and v2 endpoints)
(set for Bundle generic v1 and v2 endpoints)
},
)
Key init args completion params:
model: str
The name of the model to use, e.g., Meta-Llama-3-70B-Instruct-4096
(set for CoE endpoints).
(set for Bundle endpoints).
streaming: bool
Whether to use streaming handler when using non streaming methods
model_kwargs: dict
@ -56,7 +56,8 @@ class SambaStudio(LLM):
do_sample: bool
wether to do sample
process_prompt:
wether to process prompt (set for CoE generic v1 and v2 endpoints)
wether to process prompt
(set for Bundle generic v1 and v2 endpoints)
Key init args client params:
sambastudio_url: str
SambaStudio endpoint Url
@ -72,14 +73,14 @@ class SambaStudio(LLM):
sambastudio_url = set with your SambaStudio deployed endpoint URL,
sambastudio_api_key = set with your SambaStudio deployed endpoint Key,
model_kwargs = {
"model" : model or expert name (set for CoE endpoints),
"model" : model or expert name (set for Bundle endpoints),
"max_tokens" : max number of tokens to generate,
"temperature" : model temperature,
"top_p" : model top p,
"top_k" : model top k,
"do_sample" : wether to do sample
"process_prompt" : wether to process prompt
(set for CoE generic v1 and v2 endpoints)
(set for Bundle generic v1 and v2 endpoints)
}
)
@ -174,7 +175,7 @@ class SambaStudio(LLM):
base_url: string with url to do non streaming calls
streaming_url: string with url to do streaming calls
"""
if "openai" in url:
if "chat/completions" in url:
base_url = url
stream_url = url
else:
@ -213,7 +214,7 @@ class SambaStudio(LLM):
_model_kwargs["stop_sequences"] = _stop_sequences
# set the parameters structure depending of the API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
if "select_expert" in _model_kwargs.keys():
_model_kwargs["model"] = _model_kwargs.pop("select_expert")
if "max_tokens_to_generate" in _model_kwargs.keys():
@ -278,7 +279,7 @@ class SambaStudio(LLM):
params = self._get_tuning_params(stop)
# create request payload for openAI v1 API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
messages_dict = [{"role": "user", "content": prompt[0]}]
data = {"messages": messages_dict, "stream": streaming, **params}
data = {key: value for key, value in data.items() if value is not None}
@ -377,7 +378,7 @@ class SambaStudio(LLM):
)
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
completion = response_dict["choices"][0]["message"]["content"]
# process response payload for generic v2 API
elif "api/v2/predict/generic" in self.sambastudio_url:
@ -412,7 +413,7 @@ class SambaStudio(LLM):
)
# process response payload for openai compatible API
if "openai" in self.sambastudio_url:
if "chat/completions" in self.sambastudio_url:
client = sseclient.SSEClient(response)
for event in client.events():
if event.event == "error_event":