community[patch]: Fix sparkllm embeddings api bug. (#19122)

- **Description:** Fix sparkllm embeddings api bug.
@baskaryan PTAL
This commit is contained in:
Guangdong Liu
2024-03-16 06:08:49 +08:00
committed by GitHub
parent b9c62fb905
commit cced3eb9bc
2 changed files with 73 additions and 36 deletions

View File

@@ -70,14 +70,21 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
api_key=self.spark_api_key.get_secret_value(),
api_secret=self.spark_api_secret.get_secret_value(),
)
content = self._get_body(self.spark_app_id.get_secret_value(), texts)
response = requests.post(
url, json=content, headers={"content-type": "application/json"}
).text
res_arr = self._parser_message(response)
if res_arr is not None:
return res_arr.tolist()
return None
embed_result: list = []
for text in texts:
query_context = {"messages": [{"content": text, "role": "user"}]}
content = self._get_body(
self.spark_app_id.get_secret_value(), query_context
)
response = requests.post(
url, json=content, headers={"content-type": "application/json"}
).text
res_arr = self._parser_message(response)
if res_arr is not None:
embed_result.append(res_arr.tolist())
else:
embed_result.append(None)
return embed_result
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
"""Public method to get embeddings for a list of documents.
@@ -145,7 +152,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
return u
@staticmethod
def _get_body(appid: str, text: List[str]) -> Dict[str, Any]:
def _get_body(appid: str, text: dict) -> Dict[str, Any]:
body = {
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
"parameter": {"emb": {"feature": {"encoding": "utf8"}}},