mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-22 14:49:29 +00:00
community: sambastudio llms api v2 support (#25063)
- **Description:** SambaStudio GenericV2 API support
This commit is contained in:
parent
8d784db107
commit
c7154a4045
@ -514,7 +514,7 @@ class SSEndpointHandler:
|
||||
response: requests.Response,
|
||||
) -> Generator[Dict, None, None]:
|
||||
"""Process the streaming response"""
|
||||
if "nlp" in self.api_base_uri:
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
try:
|
||||
import sseclient
|
||||
except ImportError:
|
||||
@ -535,14 +535,15 @@ class SSEndpointHandler:
|
||||
yield chunk
|
||||
if close_conn:
|
||||
client.close()
|
||||
elif "generic" in self.api_base_uri:
|
||||
elif (
|
||||
"api/v2/predict/generic" in self.api_base_uri
|
||||
or "api/predict/generic" in self.api_base_uri
|
||||
):
|
||||
try:
|
||||
for line in response.iter_lines():
|
||||
chunk = json.loads(line)
|
||||
if "status_code" not in chunk:
|
||||
chunk["status_code"] = response.status_code
|
||||
if chunk["status_code"] == 200 and chunk.get("error"):
|
||||
chunk["result"] = {"responses": [{"stream_token": ""}]}
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||
@ -583,12 +584,18 @@ class SSEndpointHandler:
|
||||
"""
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if "nlp" in self.api_base_uri:
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "generic" in self.api_base_uri:
|
||||
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||
if params:
|
||||
data = {"items": items, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"items": items}
|
||||
elif "api/predict/generic" in self.api_base_uri:
|
||||
if params:
|
||||
data = {"instances": input, "params": json.loads(params)}
|
||||
else:
|
||||
@ -623,14 +630,22 @@ class SSEndpointHandler:
|
||||
:returns: Prediction results
|
||||
:type: dict
|
||||
"""
|
||||
if "nlp" in self.api_base_uri:
|
||||
if "api/predict/nlp" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
if params:
|
||||
data = {"inputs": input, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"inputs": input}
|
||||
elif "generic" in self.api_base_uri:
|
||||
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||
if isinstance(input, str):
|
||||
input = [input]
|
||||
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||
if params:
|
||||
data = {"items": items, "params": json.loads(params)}
|
||||
else:
|
||||
data = {"items": items}
|
||||
elif "api/predict/generic" in self.api_base_uri:
|
||||
if isinstance(input, list):
|
||||
input = input[0]
|
||||
if params:
|
||||
@ -770,6 +785,9 @@ class SambaStudio(LLM):
|
||||
# _model_kwargs["stop_sequences"] = ",".join(
|
||||
# f'"{x}"' for x in _stop_sequences
|
||||
# )
|
||||
if "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
tuning_params_dict = _model_kwargs
|
||||
else:
|
||||
tuning_params_dict = {
|
||||
k: {"type": type(v).__name__, "value": str(v)}
|
||||
for k, v in (_model_kwargs.items())
|
||||
@ -814,9 +832,11 @@ class SambaStudio(LLM):
|
||||
f"Sambanova /complete call failed with status code "
|
||||
f"{response['status_code']}.\n response {response}"
|
||||
)
|
||||
if "nlp" in self.sambastudio_base_uri:
|
||||
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||
return response["data"][0]["completion"]
|
||||
elif "generic" in self.sambastudio_base_uri:
|
||||
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
return response["items"][0]["value"]["completion"]
|
||||
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||
return response["predictions"][0]["completion"]
|
||||
else:
|
||||
raise ValueError(
|
||||
@ -885,10 +905,15 @@ class SambaStudio(LLM):
|
||||
f"{chunk['status_code']}."
|
||||
f"{chunk}."
|
||||
)
|
||||
if "nlp" in self.sambastudio_base_uri:
|
||||
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||
text = json.loads(chunk["data"])["stream_token"]
|
||||
elif "generic" in self.sambastudio_base_uri:
|
||||
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||
text = chunk["result"]["items"][0]["value"]["stream_token"]
|
||||
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||
if len(chunk["result"]["responses"]) > 0:
|
||||
text = chunk["result"]["responses"][0]["stream_token"]
|
||||
else:
|
||||
text = ""
|
||||
else:
|
||||
raise ValueError(
|
||||
f"handling of endpoint uri: {self.sambastudio_base_uri}"
|
||||
|
Loading…
Reference in New Issue
Block a user