mirror of
https://github.com/hwchase17/langchain.git
synced 2025-06-23 07:09:31 +00:00
community: sambastudio llms api v2 support (#25063)
- **Description:** SambaStudio GenericV2 API support
This commit is contained in:
parent
8d784db107
commit
c7154a4045
@ -514,7 +514,7 @@ class SSEndpointHandler:
|
|||||||
response: requests.Response,
|
response: requests.Response,
|
||||||
) -> Generator[Dict, None, None]:
|
) -> Generator[Dict, None, None]:
|
||||||
"""Process the streaming response"""
|
"""Process the streaming response"""
|
||||||
if "nlp" in self.api_base_uri:
|
if "api/predict/nlp" in self.api_base_uri:
|
||||||
try:
|
try:
|
||||||
import sseclient
|
import sseclient
|
||||||
except ImportError:
|
except ImportError:
|
||||||
@ -535,14 +535,15 @@ class SSEndpointHandler:
|
|||||||
yield chunk
|
yield chunk
|
||||||
if close_conn:
|
if close_conn:
|
||||||
client.close()
|
client.close()
|
||||||
elif "generic" in self.api_base_uri:
|
elif (
|
||||||
|
"api/v2/predict/generic" in self.api_base_uri
|
||||||
|
or "api/predict/generic" in self.api_base_uri
|
||||||
|
):
|
||||||
try:
|
try:
|
||||||
for line in response.iter_lines():
|
for line in response.iter_lines():
|
||||||
chunk = json.loads(line)
|
chunk = json.loads(line)
|
||||||
if "status_code" not in chunk:
|
if "status_code" not in chunk:
|
||||||
chunk["status_code"] = response.status_code
|
chunk["status_code"] = response.status_code
|
||||||
if chunk["status_code"] == 200 and chunk.get("error"):
|
|
||||||
chunk["result"] = {"responses": [{"stream_token": ""}]}
|
|
||||||
yield chunk
|
yield chunk
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Error processing streaming response: {e}")
|
raise RuntimeError(f"Error processing streaming response: {e}")
|
||||||
@ -583,12 +584,18 @@ class SSEndpointHandler:
|
|||||||
"""
|
"""
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
input = [input]
|
input = [input]
|
||||||
if "nlp" in self.api_base_uri:
|
if "api/predict/nlp" in self.api_base_uri:
|
||||||
if params:
|
if params:
|
||||||
data = {"inputs": input, "params": json.loads(params)}
|
data = {"inputs": input, "params": json.loads(params)}
|
||||||
else:
|
else:
|
||||||
data = {"inputs": input}
|
data = {"inputs": input}
|
||||||
elif "generic" in self.api_base_uri:
|
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||||
|
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||||
|
if params:
|
||||||
|
data = {"items": items, "params": json.loads(params)}
|
||||||
|
else:
|
||||||
|
data = {"items": items}
|
||||||
|
elif "api/predict/generic" in self.api_base_uri:
|
||||||
if params:
|
if params:
|
||||||
data = {"instances": input, "params": json.loads(params)}
|
data = {"instances": input, "params": json.loads(params)}
|
||||||
else:
|
else:
|
||||||
@ -623,14 +630,22 @@ class SSEndpointHandler:
|
|||||||
:returns: Prediction results
|
:returns: Prediction results
|
||||||
:type: dict
|
:type: dict
|
||||||
"""
|
"""
|
||||||
if "nlp" in self.api_base_uri:
|
if "api/predict/nlp" in self.api_base_uri:
|
||||||
if isinstance(input, str):
|
if isinstance(input, str):
|
||||||
input = [input]
|
input = [input]
|
||||||
if params:
|
if params:
|
||||||
data = {"inputs": input, "params": json.loads(params)}
|
data = {"inputs": input, "params": json.loads(params)}
|
||||||
else:
|
else:
|
||||||
data = {"inputs": input}
|
data = {"inputs": input}
|
||||||
elif "generic" in self.api_base_uri:
|
elif "api/v2/predict/generic" in self.api_base_uri:
|
||||||
|
if isinstance(input, str):
|
||||||
|
input = [input]
|
||||||
|
items = [{"id": f"item{i}", "value": item} for i, item in enumerate(input)]
|
||||||
|
if params:
|
||||||
|
data = {"items": items, "params": json.loads(params)}
|
||||||
|
else:
|
||||||
|
data = {"items": items}
|
||||||
|
elif "api/predict/generic" in self.api_base_uri:
|
||||||
if isinstance(input, list):
|
if isinstance(input, list):
|
||||||
input = input[0]
|
input = input[0]
|
||||||
if params:
|
if params:
|
||||||
@ -770,10 +785,13 @@ class SambaStudio(LLM):
|
|||||||
# _model_kwargs["stop_sequences"] = ",".join(
|
# _model_kwargs["stop_sequences"] = ",".join(
|
||||||
# f'"{x}"' for x in _stop_sequences
|
# f'"{x}"' for x in _stop_sequences
|
||||||
# )
|
# )
|
||||||
tuning_params_dict = {
|
if "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||||
k: {"type": type(v).__name__, "value": str(v)}
|
tuning_params_dict = _model_kwargs
|
||||||
for k, v in (_model_kwargs.items())
|
else:
|
||||||
}
|
tuning_params_dict = {
|
||||||
|
k: {"type": type(v).__name__, "value": str(v)}
|
||||||
|
for k, v in (_model_kwargs.items())
|
||||||
|
}
|
||||||
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
# _model_kwargs["stop_sequences"] = _kwarg_stop_sequences
|
||||||
tuning_params = json.dumps(tuning_params_dict)
|
tuning_params = json.dumps(tuning_params_dict)
|
||||||
return tuning_params
|
return tuning_params
|
||||||
@ -814,9 +832,11 @@ class SambaStudio(LLM):
|
|||||||
f"Sambanova /complete call failed with status code "
|
f"Sambanova /complete call failed with status code "
|
||||||
f"{response['status_code']}.\n response {response}"
|
f"{response['status_code']}.\n response {response}"
|
||||||
)
|
)
|
||||||
if "nlp" in self.sambastudio_base_uri:
|
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||||
return response["data"][0]["completion"]
|
return response["data"][0]["completion"]
|
||||||
elif "generic" in self.sambastudio_base_uri:
|
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||||
|
return response["items"][0]["value"]["completion"]
|
||||||
|
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||||
return response["predictions"][0]["completion"]
|
return response["predictions"][0]["completion"]
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
@ -885,10 +905,15 @@ class SambaStudio(LLM):
|
|||||||
f"{chunk['status_code']}."
|
f"{chunk['status_code']}."
|
||||||
f"{chunk}."
|
f"{chunk}."
|
||||||
)
|
)
|
||||||
if "nlp" in self.sambastudio_base_uri:
|
if "api/predict/nlp" in self.sambastudio_base_uri:
|
||||||
text = json.loads(chunk["data"])["stream_token"]
|
text = json.loads(chunk["data"])["stream_token"]
|
||||||
elif "generic" in self.sambastudio_base_uri:
|
elif "api/v2/predict/generic" in self.sambastudio_base_uri:
|
||||||
text = chunk["result"]["responses"][0]["stream_token"]
|
text = chunk["result"]["items"][0]["value"]["stream_token"]
|
||||||
|
elif "api/predict/generic" in self.sambastudio_base_uri:
|
||||||
|
if len(chunk["result"]["responses"]) > 0:
|
||||||
|
text = chunk["result"]["responses"][0]["stream_token"]
|
||||||
|
else:
|
||||||
|
text = ""
|
||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"handling of endpoint uri: {self.sambastudio_base_uri}"
|
f"handling of endpoint uri: {self.sambastudio_base_uri}"
|
||||||
|
Loading…
Reference in New Issue
Block a user