Mirror of https://github.com/hwchase17/langchain.git (synced 2025-06-26 16:43:35 +00:00)

community: sambaverse api update (#21816)

- **Description:** Fix the Sambaverse integration to make it compatible with the Sambaverse API update, with minor changes in the docs.

parent: 7976fb1663
commit: 700b1c7212
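For context, here is a minimal usage sketch of the integration this commit updates. The model name, expert name, and parameter values follow the docs notebook changed below; treat them as examples, and note that a valid `SAMBAVERSE_API_KEY` is required.

```python
import os

from langchain_community.llms.sambanova import Sambaverse

os.environ["SAMBAVERSE_API_KEY"] = "<your-sambaverse-api-key>"  # placeholder

# Model and expert names are taken from the docs notebook in this commit.
llm = Sambaverse(
    sambaverse_model_name="Meta/llama-2-7b-chat-hf",
    model_kwargs={
        "do_sample": True,
        "max_tokens_to_generate": 1000,
        "temperature": 0.01,
        "process_prompt": True,
        "select_expert": "llama-2-7b-chat-hf",
    },
)
print(llm.invoke("Why should I use open source models?"))
```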
```diff
@@ -22,7 +22,8 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
+    "**Sambaverse** allows you to interact with multiple open-source models. You can view the list of available models and interact with them in the [playground](https://sambaverse.sambanova.ai/playground).\n",
+    " **Please note that Sambaverse's free offering is performance-limited.** Companies that are ready to evaluate the production tokens-per-second performance, volume throughput, and 10x lower total cost of ownership (TCO) of SambaNova should [contact us](https://sambaverse.sambanova.ai/contact-us) for a non-limited evaluation instance."
    ]
   },
   {
```
```diff
@@ -88,9 +89,10 @@
     " \"temperature\": 0.01,\n",
     " \"process_prompt\": True,\n",
     " \"select_expert\": \"llama-2-7b-chat-hf\",\n",
-    " # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
-    " # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
-    " # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    " # \"stop_sequences\": '\\\"sequence1\\\",\\\"sequence2\\\"',\n",
+    " # \"repetition_penalty\": 1.0,\n",
+    " # \"top_k\": 50,\n",
+    " # \"top_p\": 1.0\n",
     " },\n",
     ")\n",
     "\n",
```
```diff
@@ -177,10 +179,10 @@
     " \"do_sample\": True,\n",
     " \"max_tokens_to_generate\": 1000,\n",
     " \"temperature\": 0.01,\n",
-    " # \"repetition_penalty\": {\"type\": \"float\", \"value\": \"1\"},\n",
-    " # \"top_k\": {\"type\": \"int\", \"value\": \"50\"},\n",
-    " # \"top_logprobs\": {\"type\": \"int\", \"value\": \"0\"},\n",
-    " # \"top_p\": {\"type\": \"float\", \"value\": \"1\"}\n",
+    " # \"repetition_penalty\": 1.0,\n",
+    " # \"top_k\": 50,\n",
+    " # \"top_logprobs\": 0,\n",
+    " # \"top_p\": 1.0\n",
     " },\n",
     ")\n",
     "\n",
```
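Both notebook cells now pass tuning values as plain Python numbers; the `{"type": ..., "value": ...}` wrappers are built internally before the API call (see `_get_tuning_params` below). A sketch of the updated user-facing kwargs, with the optional keys left commented out as in the notebook:

```python
# Plain values replace the typed {"type": ..., "value": ...} wrappers,
# which the integration now constructs internally before calling the API.
model_kwargs = {
    "do_sample": True,
    "max_tokens_to_generate": 1000,
    "temperature": 0.01,
    # "repetition_penalty": 1.0,
    # "top_k": 50,
    # "top_p": 1.0,
}
```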
```diff
@@ -47,8 +47,19 @@ class SVEndpointHandler:
         """
         result: Dict[str, Any] = {}
         try:
-            text_result = response.text.strip().split("\n")[-1]
-            result = {"data": json.loads("".join(text_result.split("data: ")[1:]))}
+            lines_result = response.text.strip().split("\n")
+            text_result = lines_result[-1]
+            if response.status_code == 200 and json.loads(text_result).get("error"):
+                completion = ""
+                for line in lines_result[:-1]:
+                    completion += json.loads(line)["result"]["responses"][0][
+                        "stream_token"
+                    ]
+                text_result = lines_result[-2]
+                result = json.loads(text_result)
+                result["result"]["responses"][0]["completion"] = completion
+            else:
+                result = json.loads(text_result)
         except Exception as e:
             result["detail"] = str(e)
         if "status_code" not in result:
```
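The updated `_process_response` treats the body as newline-delimited JSON: if the final line reports an error despite a 200 status, the completion is reassembled from the per-line stream tokens. A standalone sketch of that parsing logic, with the payload shape assumed from the diff:

```python
import json

# Hypothetical newline-delimited payload matching the shape the diff assumes:
# each line carries one "stream_token", and a trailing line may report an
# error even when the HTTP status code is 200.
raw = "\n".join(
    [
        json.dumps({"result": {"responses": [{"stream_token": "Hello "}]}}),
        json.dumps({"result": {"responses": [{"stream_token": "world"}]}}),
        json.dumps({"error": {"message": "stream interrupted"}}),
    ]
)

lines_result = raw.strip().split("\n")
if json.loads(lines_result[-1]).get("error"):
    # Reassemble the completion from the tokens that did arrive.
    completion = "".join(
        json.loads(line)["result"]["responses"][0]["stream_token"]
        for line in lines_result[:-1]
    )
    print(completion)  # -> "Hello world"
```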
```diff
@@ -58,25 +69,19 @@ class SVEndpointHandler:
     @staticmethod
     def _process_streaming_response(
         response: requests.Response,
-    ) -> Generator[GenerationChunk, None, None]:
+    ) -> Generator[Dict, None, None]:
         """Process the streaming response"""
         try:
-            import sseclient
-        except ImportError:
-            raise ImportError(
-                "could not import sseclient library"
-                "Please install it with `pip install sseclient-py`."
-            )
-        client = sseclient.SSEClient(response)
-        close_conn = False
-        for event in client.events():
-            if event.event == "error_event":
-                close_conn = True
-            text = json.dumps({"event": event.event, "data": event.data})
-            chunk = GenerationChunk(text=text)
+            for line in response.iter_lines():
+                chunk = json.loads(line)
+                if "status_code" not in chunk:
+                    chunk["status_code"] = response.status_code
+                if chunk["status_code"] == 200 and chunk.get("error"):
+                    chunk["result"] = {"responses": [{"stream_token": ""}]}
+                    return chunk
                 yield chunk
-        if close_conn:
-            client.close()
+        except Exception as e:
+            raise RuntimeError(f"Error processing streaming response: {e}")

     def _get_full_url(self) -> str:
         """
```
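With this change the handler reads the HTTP stream directly via `response.iter_lines()`, so `sseclient-py` is no longer needed for streaming, and each yielded chunk is a plain dict rather than a `GenerationChunk`. A consuming sketch under the same assumptions; the URL, headers, and payload are placeholders for what `SVEndpointHandler` builds internally:

```python
import json

import requests

url = "https://sambaverse.sambanova.ai/api/predict"  # placeholder URL
headers = {"key": "<your-sambaverse-api-key>"}  # placeholder headers
payload = {"instance": "<serialized conversation>", "params": {}}  # placeholder

with requests.post(url, headers=headers, json=payload, stream=True) as response:
    for line in response.iter_lines():
        if not line:
            continue
        chunk = json.loads(line)
        # Each streamed line is a self-contained JSON object with one token.
        token = chunk["result"]["responses"][0]["stream_token"]
        print(token, end="", flush=True)
```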
```diff
@@ -105,25 +110,21 @@ class SVEndpointHandler:
         :returns: Prediction results
         :rtype: dict
         """
-        if isinstance(input, str):
-            input = [input]
-        parsed_input = []
-        for element in input:
         parsed_element = {
             "conversation_id": "sambaverse-conversation-id",
             "messages": [
                 {
                     "message_id": 0,
                     "role": "user",
-                    "content": element,
+                    "content": input,
                 }
             ],
         }
-        parsed_input.append(json.dumps(parsed_element))
+        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"inputs": parsed_input, "params": json.loads(params)}
+            data = {"instance": parsed_input, "params": json.loads(params)}
         else:
-            data = {"inputs": parsed_input}
+            data = {"instance": parsed_input}
         response = self.http_session.post(
             self._get_full_url(),
             headers={
```
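The request body now sends a single serialized conversation under `instance` instead of a list under `inputs`. A sketch of the updated payload; the typed `params` entries mirror what `_get_tuning_params` produces:

```python
import json

prompt = "Tell me a joke"

# One serialized conversation under "instance" replaces the old
# list of inputs under "inputs".
parsed_input = json.dumps(
    {
        "conversation_id": "sambaverse-conversation-id",
        "messages": [{"message_id": 0, "role": "user", "content": prompt}],
    }
)
data = {
    "instance": parsed_input,
    "params": {
        "select_expert": {"type": "str", "value": "llama-2-7b-chat-hf"},
        "max_tokens_to_generate": {"type": "int", "value": "100"},
    },
}
print(json.dumps(data, indent=2))
```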
```diff
@@ -141,7 +142,7 @@ class SVEndpointHandler:
         sambaverse_model_name: Optional[str],
         input: Union[List[str], str],
         params: Optional[str] = "",
-    ) -> Iterator[GenerationChunk]:
+    ) -> Iterator[Dict]:
         """
         NLP predict using inline input string.

```
```diff
@@ -153,25 +154,21 @@ class SVEndpointHandler:
         :returns: Prediction results
         :rtype: dict
         """
-        if isinstance(input, str):
-            input = [input]
-        parsed_input = []
-        for element in input:
         parsed_element = {
             "conversation_id": "sambaverse-conversation-id",
             "messages": [
                 {
                     "message_id": 0,
                     "role": "user",
-                    "content": element,
+                    "content": input,
                 }
             ],
         }
-        parsed_input.append(json.dumps(parsed_element))
+        parsed_input = json.dumps(parsed_element)
         if params:
-            data = {"inputs": parsed_input, "params": json.loads(params)}
+            data = {"instance": parsed_input, "params": json.loads(params)}
         else:
-            data = {"inputs": parsed_input}
+            data = {"instance": parsed_input}
         # Streaming output
         response = self.http_session.post(
             self._get_full_url(),
```
```diff
@@ -213,7 +210,7 @@ class Sambaverse(LLM):
                 "max_tokens_to_generate": 100,
                 "temperature": 0.7,
                 "top_p": 1.0,
-                "repetition_penalty": 1,
+                "repetition_penalty": 1.0,
                 "top_k": 50,
             },
         )
```
```diff
@@ -279,13 +276,17 @@ class Sambaverse(LLM):
             The tuning parameters as a JSON string.
         """
         _model_kwargs = self.model_kwargs or {}
-        _stop_sequences = _model_kwargs.get("stop_sequences", [])
-        _stop_sequences = stop or _stop_sequences
-        _model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in _stop_sequences)
+        _kwarg_stop_sequences = _model_kwargs.get("stop_sequences", [])
+        _stop_sequences = stop or _kwarg_stop_sequences
+        if not _kwarg_stop_sequences:
+            _model_kwargs["stop_sequences"] = ",".join(
+                f'"{x}"' for x in _stop_sequences
+            )
         tuning_params_dict = {
             k: {"type": type(v).__name__, "value": str(v)}
             for k, v in (_model_kwargs.items())
         }
+        _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
         tuning_params = json.dumps(tuning_params_dict)
         return tuning_params

```
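Stop sequences supplied at call time are now folded into the serialized params only when none were configured via `model_kwargs`, and the original kwarg value is restored afterwards so repeated calls do not see the joined string. A minimal sketch of the same serialization, with example values:

```python
import json

# Example values; the real ones come from self.model_kwargs and the stop
# argument passed to the LLM call.
model_kwargs = {"max_tokens_to_generate": 100, "temperature": 0.7}
stop = ["###"]

kwarg_stop_sequences = model_kwargs.get("stop_sequences", [])
stop_sequences = stop or kwarg_stop_sequences
if not kwarg_stop_sequences:
    # Fold the runtime stop list in only when the user did not already
    # configure stop_sequences through model_kwargs.
    model_kwargs["stop_sequences"] = ",".join(f'"{x}"' for x in stop_sequences)

tuning_params = json.dumps(
    {k: {"type": type(v).__name__, "value": str(v)} for k, v in model_kwargs.items()}
)
model_kwargs["stop_sequences"] = kwarg_stop_sequences  # restore the original
print(tuning_params)
```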
```diff
@@ -313,14 +314,17 @@ class Sambaverse(LLM):
             self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
         )
         if response["status_code"] != 200:
-            optional_details = response["details"]
-            optional_message = response["message"]
+            optional_code = response["error"].get("code")
+            optional_details = response["error"].get("details")
+            optional_message = response["error"].get("message")
             raise ValueError(
                 f"Sambanova /complete call failed with status code "
-                f"{response['status_code']}. Details: {optional_details}"
-                f"{response['status_code']}. Message: {optional_message}"
+                f"{response['status_code']}."
+                f"Message: {optional_message}"
+                f"Details: {optional_details}"
+                f"Code: {optional_code}"
             )
-        return response["data"]["completion"]
+        return response["result"]["responses"][0]["completion"]

     def _handle_completion_requests(
         self, prompt: Union[List[str], str], stop: Optional[List[str]]
```
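Both the success and error paths now follow the new response schema. A sketch of the two shapes this code expects, assumed from the diff:

```python
# Success: the completion now lives under result.responses[0].completion
# rather than data.completion (shape assumed from the diff).
ok_response = {
    "status_code": 200,
    "result": {"responses": [{"completion": "Why did the chicken cross the road?"}]},
}
completion = ok_response["result"]["responses"][0]["completion"]

# Failure: error metadata is nested under an "error" object (values are
# illustrative, not documented).
err_response = {
    "status_code": 429,
    "error": {"code": 42, "message": "rate limited", "details": "free tier"},
}
```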
```diff
@@ -359,7 +363,20 @@ class Sambaverse(LLM):
         for chunk in sdk.nlp_predict_stream(
             self.sambaverse_api_key, self.sambaverse_model_name, prompt, tuning_params
         ):
-            yield chunk
+            if chunk["status_code"] != 200:
+                optional_code = chunk["error"].get("code")
+                optional_details = chunk["error"].get("details")
+                optional_message = chunk["error"].get("message")
+                raise ValueError(
+                    f"Sambanova /complete call failed with status code "
+                    f"{chunk['status_code']}."
+                    f"Message: {optional_message}"
+                    f"Details: {optional_details}"
+                    f"Code: {optional_code}"
+                )
+            text = chunk["result"]["responses"][0]["stream_token"]
+            generated_chunk = GenerationChunk(text=text)
+            yield generated_chunk

     def _stream(
         self,
```
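Since the endpoint handler now yields raw dicts, `_stream` performs the error check and wraps each token in a `GenerationChunk` itself; downstream streaming usage is unchanged. A usage sketch (model and expert names are examples, and a valid `SAMBAVERSE_API_KEY` is assumed to be set):

```python
from langchain_community.llms.sambanova import Sambaverse

# Assumes SAMBAVERSE_API_KEY is set in the environment; names are examples.
llm = Sambaverse(
    sambaverse_model_name="Meta/llama-2-7b-chat-hf",
    streaming=True,
    model_kwargs={"select_expert": "llama-2-7b-chat-hf", "temperature": 0.01},
)
for chunk in llm.stream("Why should I use open source models?"):
    print(chunk, end="", flush=True)
```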