mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-14 05:56:40 +00:00
community[patch]: Fix sparkllm embeddings api bug. (#19122)
- **Description:** Fix sparkllm embeddings api bug. @baskaryan PTAL
This commit is contained in:
@@ -70,14 +70,21 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
api_key=self.spark_api_key.get_secret_value(),
|
||||
api_secret=self.spark_api_secret.get_secret_value(),
|
||||
)
|
||||
content = self._get_body(self.spark_app_id.get_secret_value(), texts)
|
||||
response = requests.post(
|
||||
url, json=content, headers={"content-type": "application/json"}
|
||||
).text
|
||||
res_arr = self._parser_message(response)
|
||||
if res_arr is not None:
|
||||
return res_arr.tolist()
|
||||
return None
|
||||
embed_result: list = []
|
||||
for text in texts:
|
||||
query_context = {"messages": [{"content": text, "role": "user"}]}
|
||||
content = self._get_body(
|
||||
self.spark_app_id.get_secret_value(), query_context
|
||||
)
|
||||
response = requests.post(
|
||||
url, json=content, headers={"content-type": "application/json"}
|
||||
).text
|
||||
res_arr = self._parser_message(response)
|
||||
if res_arr is not None:
|
||||
embed_result.append(res_arr.tolist())
|
||||
else:
|
||||
embed_result.append(None)
|
||||
return embed_result
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
|
||||
"""Public method to get embeddings for a list of documents.
|
||||
@@ -145,7 +152,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
return u
|
||||
|
||||
@staticmethod
|
||||
def _get_body(appid: str, text: List[str]) -> Dict[str, Any]:
|
||||
def _get_body(appid: str, text: dict) -> Dict[str, Any]:
|
||||
body = {
|
||||
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
|
||||
"parameter": {"emb": {"feature": {"encoding": "utf8"}}},
|
||||
|
Reference in New Issue
Block a user