community[patch]: Update the default api_url and reqeust_body of sparkllm embedding (#22136)

- **Description:** When I was running the SparkLLMTextEmbeddings,
app_id, api_key and api_secret are all correct, but it cannot run
normally using the current URL.

    ```python
    # example
    from langchain_community.embeddings import SparkLLMTextEmbeddings

    embedding= SparkLLMTextEmbeddings(
        spark_app_id="my-app-id",
        spark_api_key="my-api-key",
        spark_api_secret="my-api-secret"
    )
    embedding= "hello"
    print(spark.embed_query(text1))
    ```

![sparkembedding](https://github.com/langchain-ai/langchain/assets/55082429/11daa853-4f67-45b2-aae2-c95caa14e38c)
   
So I updated the url and request body parameters according to
[Embedding_api](https://www.xfyun.cn/doc/spark/Embedding_api.html), now
it is runnable.
This commit is contained in:
maang-h 2024-06-04 03:38:11 +08:00 committed by GitHub
parent ba0dca46d7
commit 13140dc4ff
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 118 additions and 25 deletions

View File

@ -5,25 +5,20 @@ import json
import logging
from datetime import datetime
from time import mktime
from typing import Any, Dict, List, Optional
from typing import Any, Dict, List, Literal, Optional
from urllib.parse import urlencode
from wsgiref.handlers import format_date_time
import numpy as np
import requests
from langchain_core.embeddings import Embeddings
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
from numpy import ndarray
# Used for document and knowledge embedding
EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27"
# Used for user questions embedding
EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16"
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_api.html
# Developers need to create an application in the console first, use the appid, APIKey,
# and APISecret provided in the application for authentication,
# and generate an authentication URL for handshake.
@ -43,39 +38,89 @@ class Url:
class SparkLLMTextEmbeddings(BaseModel, Embeddings):
"""SparkLLM Text Embedding models."""
"""SparkLLM Text Embedding models.
spark_app_id: SecretStr
spark_api_key: SecretStr
spark_api_secret: SecretStr
To use, you should have the environment variable "SPARK_APP_ID","SPARK_API_KEY"
and "SPARK_API_SECRET" set your APP_ID, API_KEY and API_SECRET or pass it
as a name parameter to the constructor.
Example:
.. code-block:: python
from langchain_community.embeddings import SparkLLMTextEmbeddings
embeddings = SparkLLMTextEmbeddings(
spark_app_id="your-app-id",
spark_api_key="your-api-key",
spark_api_secret="your-api-secret"
)
text = "This is a test query."
query_result = embeddings.embed_query(text)
"""
spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
"""Automatically inferred from env var `SPARK_APP_ID` if not provided."""
spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
"""Automatically inferred from env var `SPARK_API_KEY` if not provided."""
spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
"""Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
"""Base URL path for API requests"""
domain: Literal["para", "query"] = Field(default="para")
"""This parameter is used for which Embedding this time belongs to.
If "para"(default), it belongs to document Embedding.
If "query", it belongs to query Embedding."""
class Config:
"""Configuration for this pydantic object"""
allow_population_by_field_name = True
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that auth token exists in environment."""
cls.spark_app_id = convert_to_secret_str(
values["spark_app_id"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
)
cls.spark_api_key = convert_to_secret_str(
values["spark_api_key"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
)
cls.spark_api_secret = convert_to_secret_str(
values["spark_api_secret"] = convert_to_secret_str(
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
)
return values
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
"""Internal method to call Spark Embedding API and return embeddings.
Args:
texts: A list of texts to embed.
host: Base URL path for API requests
Returns:
A list of list of floats representing the embeddings,
or list with value None if an error occurs.
"""
app_id = ""
api_key = ""
api_secret = ""
if self.spark_app_id:
app_id = self.spark_app_id.get_secret_value()
if self.spark_api_key:
api_key = self.spark_api_key.get_secret_value()
if self.spark_api_secret:
api_secret = self.spark_api_secret.get_secret_value()
url = self._assemble_ws_auth_url(
request_url=host,
method="POST",
api_key=self.spark_api_key.get_secret_value(),
api_secret=self.spark_api_secret.get_secret_value(),
api_key=api_key,
api_secret=api_secret,
)
embed_result: list = []
for text in texts:
query_context = {"messages": [{"content": text, "role": "user"}]}
content = self._get_body(
self.spark_app_id.get_secret_value(), query_context
)
content = self._get_body(app_id, query_context)
response = requests.post(
url, json=content, headers={"content-type": "application/json"}
).text
@ -95,7 +140,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
Returns:
A list of embeddings, one for each text, or None if an error occurs.
"""
return self._embed(texts, EMBEDDING_P_API_URL)
return self._embed(texts, self.base_url)
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
"""Public method to get embedding for a single query text.
@ -106,7 +151,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
Returns:
Embeddings for the text, or None if an error occurs.
"""
result = self._embed([text], EMBEDDING_Q_API_URL)
result = self._embed([text], self.base_url)
return result[0] if result is not None else None
@staticmethod
@ -151,11 +196,12 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
u = Url(host, path, schema)
return u
@staticmethod
def _get_body(appid: str, text: dict) -> Dict[str, Any]:
def _get_body(self, appid: str, text: dict) -> Dict[str, Any]:
body = {
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
"parameter": {"emb": {"feature": {"encoding": "utf8"}}},
"parameter": {
"emb": {"domain": self.domain, "feature": {"encoding": "utf8"}}
},
"payload": {
"messages": {
"text": base64.b64encode(json.dumps(text).encode("utf-8")).decode()

View File

@ -0,0 +1,47 @@
import os
from typing import cast
import pytest
from langchain_core.pydantic_v1 import SecretStr, ValidationError
from langchain_community.embeddings import SparkLLMTextEmbeddings
def test_sparkllm_initialization_by_alias() -> None:
# Effective initialization
embeddings = SparkLLMTextEmbeddings(
app_id="your-app-id", # type: ignore[arg-type]
api_key="your-api-key", # type: ignore[arg-type]
api_secret="your-api-secret", # type: ignore[arg-type]
)
assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id"
assert (
cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key"
)
assert (
cast(SecretStr, embeddings.spark_api_secret).get_secret_value()
== "your-api-secret"
)
def test_initialization_parameters_from_env() -> None:
# Setting environment variable
os.environ["SPARK_APP_ID"] = "your-app-id"
os.environ["SPARK_API_KEY"] = "your-api-key"
os.environ["SPARK_API_SECRET"] = "your-api-secret"
# Effective initialization
embeddings = SparkLLMTextEmbeddings()
assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id"
assert (
cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key"
)
assert (
cast(SecretStr, embeddings.spark_api_secret).get_secret_value()
== "your-api-secret"
)
# Environment variable missing
del os.environ["SPARK_APP_ID"]
with pytest.raises(ValidationError):
SparkLLMTextEmbeddings()