mirror of
https://github.com/hwchase17/langchain.git
synced 2025-08-10 13:27:36 +00:00
community[patch]: Sambanova integration api update (#21848)
- **Description:** Added SambaStudio generic endpoint compatibility, improved error descriptions and handling, and added streaming examples.
This commit is contained in:
parent
c6da9533ac
commit
e6207ad4f3
@ -99,6 +99,36 @@
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import Sambaverse\n",
|
||||
"\n",
|
||||
"llm = Sambaverse(\n",
|
||||
" sambaverse_model_name=\"Meta/llama-2-7b-chat-hf\",\n",
|
||||
" streaming=True,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" \"process_prompt\": True,\n",
|
||||
" \"select_expert\": \"llama-2-7b-chat-hf\",\n",
|
||||
" # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
@ -147,12 +177,14 @@
|
||||
"import os\n",
|
||||
"\n",
|
||||
"sambastudio_base_url = \"<Your SambaStudio environment URL>\"\n",
|
||||
"# sambastudio_base_uri = \"<Your SambaStudio endpoint base URI>\" # optional, \"api/predict/nlp\" set as default\n",
|
||||
"sambastudio_project_id = \"<Your SambaStudio project id>\"\n",
|
||||
"sambastudio_endpoint_id = \"<Your SambaStudio endpoint id>\"\n",
|
||||
"sambastudio_api_key = \"<Your SambaStudio endpoint API key>\"\n",
|
||||
"\n",
|
||||
"# Set the environment variables\n",
|
||||
"os.environ[\"SAMBASTUDIO_BASE_URL\"] = sambastudio_base_url\n",
|
||||
"# os.environ[\"SAMBASTUDIO_BASE_URI\"] = sambastudio_base_uri\n",
|
||||
"os.environ[\"SAMBASTUDIO_PROJECT_ID\"] = sambastudio_project_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_ENDPOINT_ID\"] = sambastudio_endpoint_id\n",
|
||||
"os.environ[\"SAMBASTUDIO_API_KEY\"] = sambastudio_api_key"
|
||||
@ -188,6 +220,33 @@
|
||||
"\n",
|
||||
"print(llm.invoke(\"Why should I use open source models?\"))"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Streaming response\n",
|
||||
"\n",
|
||||
"from langchain_community.llms.sambanova import SambaStudio\n",
|
||||
"\n",
|
||||
"llm = SambaStudio(\n",
|
||||
" streaming=True,\n",
|
||||
" model_kwargs={\n",
|
||||
" \"do_sample\": True,\n",
|
||||
" \"max_tokens_to_generate\": 1000,\n",
|
||||
" \"temperature\": 0.01,\n",
|
||||
" # \"repetition_penalty\": 1.0,\n",
|
||||
" # \"top_k\": 50,\n",
|
||||
" # \"top_logprobs\": 0,\n",
|
||||
" # \"top_p\": 1.0\n",
|
||||
" },\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"for chunk in llm.stream(\"Why should I use open source models?\"):\n",
|
||||
" print(chunk, end=\"\", flush=True)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
|
@ -216,7 +216,7 @@ class Sambaverse(LLM):
|
||||
)
|
||||
"""
|
||||
|
||||
sambaverse_url: str = "https://sambaverse.sambanova.ai"
|
||||
sambaverse_url: str = ""
|
||||
"""Sambaverse url to use"""
|
||||
|
||||
sambaverse_api_key: str = ""
|
||||
@ -244,7 +244,10 @@ class Sambaverse(LLM):
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key exists in environment."""
|
||||
values["sambaverse_url"] = get_from_dict_or_env(
|
||||
values, "sambaverse_url", "SAMBAVERSE_URL"
|
||||
values,
|
||||
"sambaverse_url",
|
||||
"SAMBAVERSE_URL",
|
||||
default="https://sambaverse.sambanova.ai",
|
||||
)
|
||||
values["sambaverse_api_key"] = get_from_dict_or_env(
|
||||
values, "sambaverse_api_key", "SAMBAVERSE_API_KEY"
|
||||
@ -314,16 +317,24 @@ class Sambaverse(LLM):
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
optional_code = response["error"].get("code")
|
||||
optional_details = response["error"].get("details")
|
||||
optional_message = response["error"].get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}."
|
||||
f"Message: {optional_message}"
|
||||
f"Details: {optional_details}"
|
||||
f"Code: {optional_code}"
|
||||
)
|
||||
error = response.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}."
|
||||
f"{response}."
|
||||
)
|
||||
return response["result"]["responses"][0]["completion"]
|
||||
|
||||
def _handle_completion_requests(
|
||||
@ -364,16 +375,24 @@ class Sambaverse(LLM):
|
||||
self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
|
||||
):
|
||||
if chunk["status_code"] != 200:
|
||||
optional_code = chunk["error"].get("code")
|
||||
optional_details = chunk["error"].get("details")
|
||||
optional_message = chunk["error"].get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"Message: {optional_message}"
|
||||
f"Details: {optional_details}"
|
||||
f"Code: {optional_code}"
|
||||
)
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
)
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
generated_chunk = GenerationChunk(text=text)
|
||||
yield generated_chunk
|
||||
@ -477,19 +496,18 @@ class SSEndpointHandler:
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
"""
|
||||
|
||||
API_BASE_PATH = "/api"
|
||||
|
||||
def __init__(self, host_url: str):
|
||||
def __init__(self, host_url: str, api_base_uri: str):
|
||||
"""
|
||||
Initialize the SSEndpointHandler.
|
||||
|
||||
:param str host_url: Base URL of the DaaS API service
|
||||
:param str api_base_uri: Base URI of the DaaS API service
|
||||
"""
|
||||
self.host_url = host_url
|
||||
self.api_base_uri = api_base_uri
|
||||
self.http_session = requests.Session()
|
||||
|
||||
@staticmethod
|
||||
def _process_response(response: requests.Response) -> Dict:
|
||||
def _process_response(self, response: requests.Response) -> Dict:
|
||||
"""
|
||||
Processes the API response and returns the resulting dict.
|
||||
|
||||
@ -515,28 +533,47 @@ class SSEndpointHandler:
|
||||
result["status_code"] = response.status_code
|
||||
return result
|
||||
|
||||
@staticmethod
|
||||
def _process_streaming_response(
|
||||
self,
|
||||
response: requests.Response,
|
||||
) -> Generator[GenerationChunk, None, None]:
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Process the streaming response"""
|
||||
try:
|
||||
import sseclient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"could not import sseclient library"
|
||||
"Please install it with `pip install sseclient-py`."
|
||||
if "nlp" in self.api_base_uri:
|
||||
try:
|
||||
import sseclient
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
"could not import sseclient library"
|
||||
"Please install it with `pip install sseclient-py`."
|
||||
)
|
||||
client = sseclient.SSEClient(response)
|
||||
close_conn = False
|
||||
for event in client.events():
|
||||
if event.event == "error_event":
|
||||
close_conn = True
|
||||
chunk = {
|
||||
"event": event.event,
|
||||
"data": event.data,
|
||||
"status_code": response.status_code,
|
||||
}
|
||||
yield chunk
|
||||
if close_conn:
|
||||
client.close()
|
||||
elif "generic" in self.api_base_uri:
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
chunk = json.loads(line)
|
||||
if "status_code" not in chunk:
|
||||
chunk["status_code"] = response.status_code
|
||||
if chunk["status_code"] == 200 and chunk.get("error"):
|
||||
chunk["result"] = {"responses": [{"stream_token": ""}]}
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
client = sseclient.SSEClient(response)
|
||||
close_conn = False
|
||||
for event in client.events():
|
||||
if event.event == "error_event":
|
||||
close_conn = True
|
||||
text = json.dumps({"event": event.event, "data": event.data})
|
||||
chunk = GenerationChunk(text=text)
|
||||
yield chunk
|
||||
if close_conn:
|
||||
client.close()
|
||||
|
||||
def _get_full_url(self, path: str) -> str:
|
||||
"""
|
||||
@ -546,7 +583,7 @@ class SSEndpointHandler:
|
||||
:returns: the full API URL for the sub-path
|
||||
:rtype: str
|
||||
"""
|
||||
return f"{self.host_url}{self.API_BASE_PATH}{path}"
|
||||
return f"{self.host_url}/{self.api_base_uri}/{path}"
|
||||
|
||||
def nlp_predict(
|
||||
self,
|
||||
@ -570,16 +607,26 @@ class SSEndpointHandler:
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
if "nlp" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "generic" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"instances": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instances": input}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(f"/predict/nlp/{project}/{endpoint}"),
|
||||
self._get_full_url(f"{project}/{endpoint}"),
|
||||
headers={"key": key},
|
||||
json=data,
|
||||
)
|
||||
return SSEndpointHandler._process_response(response)
|
||||
return self._process_response(response)
|
||||
|
||||
def nlp_predict_stream(
|
||||
self,
|
||||
@ -588,7 +635,7 @@ class SSEndpointHandler:
|
||||
key: str,
|
||||
input: Union[List[str], str],
|
||||
params: Optional[str] = "",
|
||||
) -> Iterator[GenerationChunk]:
|
||||
) -> Iterator[Dict]:
|
||||
"""
|
||||
NLP predict using inline input string.
|
||||
|
||||
@ -600,20 +647,32 @@ class SSEndpointHandler:
|
||||
:returns: Prediction results
|
||||
:rtype: dict
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
if "nlp" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "generic" in self.api_base_uri:
|
||||
if isinstance(input, list):
|
||||
input = input[0]
|
||||
if params:
|
||||
data = {"instance": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"instance": input}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.api_base_uri} not implemented"
|
||||
)
|
||||
# Streaming output
|
||||
response = self.http_session.post(
|
||||
self._get_full_url(f"/predict/nlp/stream/{project}/{endpoint}"),
|
||||
self._get_full_url(f"stream/{project}/{endpoint}"),
|
||||
headers={"key": key},
|
||||
json=data,
|
||||
stream=True,
|
||||
)
|
||||
for chunk in SSEndpointHandler._process_streaming_response(response):
|
||||
for chunk in self._process_streaming_response(response):
|
||||
yield chunk
|
||||
|
||||
|
||||
@ -623,6 +682,7 @@ class SambaStudio(LLM):
|
||||
|
||||
To use, you should have the environment variables
|
||||
``SAMBASTUDIO_BASE_URL`` set with your SambaStudio environment URL.
|
||||
``SAMBASTUDIO_BASE_URI`` set with your SambaStudio api base URI.
|
||||
``SAMBASTUDIO_PROJECT_ID`` set with your SambaStudio project ID.
|
||||
``SAMBASTUDIO_ENDPOINT_ID`` set with your SambaStudio endpoint ID.
|
||||
``SAMBASTUDIO_API_KEY`` set with your SambaStudio endpoint API key.
|
||||
@ -637,6 +697,7 @@ class SambaStudio(LLM):
|
||||
from langchain_community.llms.sambanova import SambaStudio
|
||||
SambaStudio(
|
||||
sambastudio_base_url="your-SambaStudio-environment-URL",
|
||||
sambastudio_base_uri="your-SambaStudio-base-URI",
|
||||
sambastudio_project_id="your-SambaStudio-project-ID",
|
||||
sambastudio_endpoint_id="your-SambaStudio-endpoint-ID",
|
||||
sambastudio_api_key="your-SambaStudio-endpoint-API-key",
|
||||
@ -655,6 +716,9 @@ class SambaStudio(LLM):
|
||||
sambastudio_base_url: str = ""
|
||||
"""Base url to use"""
|
||||
|
||||
sambastudio_base_uri: str = ""
|
||||
"""endpoint base uri"""
|
||||
|
||||
sambastudio_project_id: str = ""
|
||||
"""Project id on sambastudio for model"""
|
||||
|
||||
@ -695,6 +759,12 @@ class SambaStudio(LLM):
|
||||
values["sambastudio_base_url"] = get_from_dict_or_env(
|
||||
values, "sambastudio_base_url", "SAMBASTUDIO_BASE_URL"
|
||||
)
|
||||
values["sambastudio_base_uri"] = get_from_dict_or_env(
|
||||
values,
|
||||
"sambastudio_base_uri",
|
||||
"SAMBASTUDIO_BASE_URI",
|
||||
default="api/predict/nlp",
|
||||
)
|
||||
values["sambastudio_project_id"] = get_from_dict_or_env(
|
||||
values, "sambastudio_project_id", "SAMBASTUDIO_PROJECT_ID"
|
||||
)
|
||||
@ -718,14 +788,17 @@ class SambaStudio(LLM):
|
||||
The tuning parameters as a JSON string.
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
_stop_sequences = _model_kwargs.get("stop_sequences", [])
|
||||
_stop_sequences = stop or _stop_sequences
|
||||
# _model_kwargs['stop_sequences'] = ','.join(
|
||||
# f"'{x}'" for x in _stop_sequences)
|
||||
_kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
|
||||
_stop_sequences = stop or _kwarg_stop_sequences
|
||||
# if not _kwarg_stop_sequences:
|
||||
# _model_kwargs["stop_sequences"] = ",".join(
|
||||
# f'"{x}"' for x in _stop_sequences
|
||||
# )
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (_model_kwargs.items())
|
||||
}
|
||||
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
||||
tuning_params = json.dumps(tuning_params_dict)
|
||||
return tuning_params
|
||||
|
||||
@ -754,12 +827,25 @@ class SambaStudio(LLM):
|
||||
tuning_params,
|
||||
)
|
||||
if response["status_code"] != 200:
|
||||
optional_detail = response["detail"]
|
||||
optional_detail = response.get("detail")
|
||||
if optional_detail:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n Details: {optional_detail}"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n response {response}"
|
||||
)
|
||||
if "nlp" in self.sambastudio_base_uri:
|
||||
return response["data"][0]["completion"]
|
||||
elif "generic" in self.sambastudio_base_uri:
|
||||
return response["predictions"][0]["completion"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}. Details: {optional_detail}"
|
||||
f"handling of endpoint uri: {self.sambastudio_base_uri} not implemented"
|
||||
)
|
||||
return response["data"][0]["completion"]
|
||||
|
||||
def _handle_completion_requests(
|
||||
self, prompt: Union[List[str], str], stop: Optional[List[str]]
|
||||
@ -777,7 +863,9 @@ class SambaStudio(LLM):
|
||||
Raises:
|
||||
ValueError: If the prediction fails.
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
|
||||
ss_endpoint = SSEndpointHandler(
|
||||
self.sambastudio_base_url, self.sambastudio_base_uri
|
||||
)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
return self._handle_nlp_predict(ss_endpoint, prompt, tuning_params)
|
||||
|
||||
@ -802,7 +890,36 @@ class SambaStudio(LLM):
|
||||
prompt,
|
||||
tuning_params,
|
||||
):
|
||||
yield chunk
|
||||
if chunk["status_code"] != 200:
|
||||
error = chunk.get("error")
|
||||
if error:
|
||||
optional_code = error.get("code")
|
||||
optional_details = error.get("details")
|
||||
optional_message = error.get("message")
|
||||
raise ValueError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}.\n"
|
||||
f"Message: {optional_message}\n"
|
||||
f"Details: {optional_details}\n"
|
||||
f"Code: {optional_code}\n"
|
||||
)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
)
|
||||
if "nlp" in self.sambastudio_base_uri:
|
||||
text = json.loads(chunk["data"])["stream_token"]
|
||||
elif "generic" in self.sambastudio_base_uri:
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.sambastudio_base_uri}"
|
||||
f"not implemented"
|
||||
)
|
||||
generated_chunk = GenerationChunk(text=text)
|
||||
yield generated_chunk
|
||||
|
||||
def _stream(
|
||||
self,
|
||||
@ -820,7 +937,9 @@ class SambaStudio(LLM):
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
"""
|
||||
ss_endpoint = SSEndpointHandler(self.sambastudio_base_url)
|
||||
ss_endpoint = SSEndpointHandler(
|
||||
self.sambastudio_base_url, self.sambastudio_base_uri
|
||||
)
|
||||
tuning_params = self._get_tuning_params(stop)
|
||||
try:
|
||||
if self.streaming:
|
||||
|
Loading…
Reference in New Issue
Block a user