community[patch]: sambanova llm integration improvement (#23137)

- **Description:** sambanova sambaverse integration improvement: removed
input parsing that was changing raw user input, and was making to use
process prompt parameter as true mandatory
This commit is contained in:
Jorge Piedrahita Ortiz
2024-06-19 12:30:14 -05:00
committed by GitHub
parent e162893d7f
commit b3e53ffca0
2 changed files with 50 additions and 38 deletions

View File

@@ -43,7 +43,7 @@ class SVEndpointHandler:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
@@ -87,7 +87,7 @@ class SVEndpointHandler:
"""
Return the full API URL for a given path.
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}{self.API_BASE_PATH}"
@@ -108,23 +108,12 @@ class SVEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
response = self.http_session.post(
self._get_full_url(),
headers={
@@ -152,23 +141,12 @@ class SVEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
parsed_element = {
"conversation_id": "sambaverse-conversation-id",
"messages": [
{
"message_id": 0,
"role": "user",
"content": input,
}
],
}
parsed_input = json.dumps(parsed_element)
if params:
data = {"instance": parsed_input, "params": json.loads(params)}
data = {"instance": input, "params": json.loads(params)}
else:
data = {"instance": parsed_input}
data = {"instance": input}
# Streaming output
response = self.http_session.post(
self._get_full_url(),
@@ -522,7 +500,7 @@ class SSEndpointHandler:
:param requests.Response response: the response object to process
:return: the response dict
:rtype: dict
:type: dict
"""
result: Dict[str, Any] = {}
try:
@@ -581,7 +559,7 @@ class SSEndpointHandler:
:param str path: the sub-path
:returns: the full API URL for the sub-path
:rtype: str
:type: str
"""
return f"{self.host_url}/{self.api_base_uri}/{path}"
@@ -603,7 +581,7 @@ class SSEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if isinstance(input, str):
input = [input]
@@ -645,7 +623,7 @@ class SSEndpointHandler:
:param str input_str: Input string
:param str params: Input params string
:returns: Prediction results
:rtype: dict
:type: dict
"""
if "nlp" in self.api_base_uri:
if isinstance(input, str):