mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-05 04:38:26 +00:00
community[patch]: Update the default api_url and request_body of sparkllm embedding (#22136)
- **Description:** When running `SparkLLMTextEmbeddings` with a correct app_id, api_key and api_secret, requests still failed against the current default URL.

```python
# example
from langchain_community.embeddings import SparkLLMTextEmbeddings

embedding = SparkLLMTextEmbeddings(
    spark_app_id="my-app-id",
    spark_api_key="my-api-key",
    spark_api_secret="my-api-secret",
)
text1 = "hello"
print(embedding.embed_query(text1))
```

So I updated the URL and request body parameters according to [Embedding_api](https://www.xfyun.cn/doc/spark/Embedding_api.html); with these changes the example runs.
This commit is contained in:
parent
ba0dca46d7
commit
13140dc4ff
@ -5,25 +5,20 @@ import json
|
|||||||
import logging
|
import logging
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from time import mktime
|
from time import mktime
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Dict, List, Literal, Optional
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
from wsgiref.handlers import format_date_time
|
from wsgiref.handlers import format_date_time
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import requests
|
import requests
|
||||||
from langchain_core.embeddings import Embeddings
|
from langchain_core.embeddings import Embeddings
|
||||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||||
from numpy import ndarray
|
from numpy import ndarray
|
||||||
|
|
||||||
# Used for document and knowledge embedding
|
|
||||||
EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27"
|
|
||||||
# Used for user questions embedding
|
|
||||||
EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16"
|
|
||||||
|
|
||||||
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
|
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/).
|
||||||
|
|
||||||
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html
|
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_api.html
|
||||||
# Developers need to create an application in the console first, use the appid, APIKey,
|
# Developers need to create an application in the console first, use the appid, APIKey,
|
||||||
# and APISecret provided in the application for authentication,
|
# and APISecret provided in the application for authentication,
|
||||||
# and generate an authentication URL for handshake.
|
# and generate an authentication URL for handshake.
|
||||||
@ -43,39 +38,89 @@ class Url:
|
|||||||
|
|
||||||
|
|
||||||
class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||||
"""SparkLLM Text Embedding models."""
|
"""SparkLLM Text Embedding models.
|
||||||
|
|
||||||
spark_app_id: SecretStr
|
To use, you should have the environment variable "SPARK_APP_ID","SPARK_API_KEY"
|
||||||
spark_api_key: SecretStr
|
and "SPARK_API_SECRET" set your APP_ID, API_KEY and API_SECRET or pass it
|
||||||
spark_api_secret: SecretStr
|
as a name parameter to the constructor.
|
||||||
|
|
||||||
|
Example:
|
||||||
|
.. code-block:: python
|
||||||
|
|
||||||
|
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||||||
|
|
||||||
|
embeddings = SparkLLMTextEmbeddings(
|
||||||
|
spark_app_id="your-app-id",
|
||||||
|
spark_api_key="your-api-key",
|
||||||
|
spark_api_secret="your-api-secret"
|
||||||
|
)
|
||||||
|
text = "This is a test query."
|
||||||
|
query_result = embeddings.embed_query(text)
|
||||||
|
|
||||||
|
"""
|
||||||
|
|
||||||
|
spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
|
||||||
|
"""Automatically inferred from env var `SPARK_APP_ID` if not provided."""
|
||||||
|
spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||||
|
"""Automatically inferred from env var `SPARK_API_KEY` if not provided."""
|
||||||
|
spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
|
||||||
|
"""Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
|
||||||
|
base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
|
||||||
|
"""Base URL path for API requests"""
|
||||||
|
domain: Literal["para", "query"] = Field(default="para")
|
||||||
|
"""This parameter is used for which Embedding this time belongs to.
|
||||||
|
If "para"(default), it belongs to document Embedding.
|
||||||
|
If "query", it belongs to query Embedding."""
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
"""Configuration for this pydantic object"""
|
||||||
|
|
||||||
|
allow_population_by_field_name = True
|
||||||
|
|
||||||
@root_validator(allow_reuse=True)
|
@root_validator(allow_reuse=True)
|
||||||
def validate_environment(cls, values: Dict) -> Dict:
|
def validate_environment(cls, values: Dict) -> Dict:
|
||||||
"""Validate that auth token exists in environment."""
|
"""Validate that auth token exists in environment."""
|
||||||
cls.spark_app_id = convert_to_secret_str(
|
values["spark_app_id"] = convert_to_secret_str(
|
||||||
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
|
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
|
||||||
)
|
)
|
||||||
cls.spark_api_key = convert_to_secret_str(
|
values["spark_api_key"] = convert_to_secret_str(
|
||||||
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
|
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
|
||||||
)
|
)
|
||||||
cls.spark_api_secret = convert_to_secret_str(
|
values["spark_api_secret"] = convert_to_secret_str(
|
||||||
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
|
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
|
||||||
)
|
)
|
||||||
return values
|
return values
|
||||||
|
|
||||||
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
|
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
|
||||||
|
"""Internal method to call Spark Embedding API and return embeddings.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
texts: A list of texts to embed.
|
||||||
|
host: Base URL path for API requests
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A list of list of floats representing the embeddings,
|
||||||
|
or list with value None if an error occurs.
|
||||||
|
"""
|
||||||
|
app_id = ""
|
||||||
|
api_key = ""
|
||||||
|
api_secret = ""
|
||||||
|
if self.spark_app_id:
|
||||||
|
app_id = self.spark_app_id.get_secret_value()
|
||||||
|
if self.spark_api_key:
|
||||||
|
api_key = self.spark_api_key.get_secret_value()
|
||||||
|
if self.spark_api_secret:
|
||||||
|
api_secret = self.spark_api_secret.get_secret_value()
|
||||||
url = self._assemble_ws_auth_url(
|
url = self._assemble_ws_auth_url(
|
||||||
request_url=host,
|
request_url=host,
|
||||||
method="POST",
|
method="POST",
|
||||||
api_key=self.spark_api_key.get_secret_value(),
|
api_key=api_key,
|
||||||
api_secret=self.spark_api_secret.get_secret_value(),
|
api_secret=api_secret,
|
||||||
)
|
)
|
||||||
embed_result: list = []
|
embed_result: list = []
|
||||||
for text in texts:
|
for text in texts:
|
||||||
query_context = {"messages": [{"content": text, "role": "user"}]}
|
query_context = {"messages": [{"content": text, "role": "user"}]}
|
||||||
content = self._get_body(
|
content = self._get_body(app_id, query_context)
|
||||||
self.spark_app_id.get_secret_value(), query_context
|
|
||||||
)
|
|
||||||
response = requests.post(
|
response = requests.post(
|
||||||
url, json=content, headers={"content-type": "application/json"}
|
url, json=content, headers={"content-type": "application/json"}
|
||||||
).text
|
).text
|
||||||
@ -95,7 +140,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
A list of embeddings, one for each text, or None if an error occurs.
|
A list of embeddings, one for each text, or None if an error occurs.
|
||||||
"""
|
"""
|
||||||
return self._embed(texts, EMBEDDING_P_API_URL)
|
return self._embed(texts, self.base_url)
|
||||||
|
|
||||||
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
|
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
|
||||||
"""Public method to get embedding for a single query text.
|
"""Public method to get embedding for a single query text.
|
||||||
@ -106,7 +151,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
|||||||
Returns:
|
Returns:
|
||||||
Embeddings for the text, or None if an error occurs.
|
Embeddings for the text, or None if an error occurs.
|
||||||
"""
|
"""
|
||||||
result = self._embed([text], EMBEDDING_Q_API_URL)
|
result = self._embed([text], self.base_url)
|
||||||
return result[0] if result is not None else None
|
return result[0] if result is not None else None
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
@ -151,11 +196,12 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
|||||||
u = Url(host, path, schema)
|
u = Url(host, path, schema)
|
||||||
return u
|
return u
|
||||||
|
|
||||||
@staticmethod
|
def _get_body(self, appid: str, text: dict) -> Dict[str, Any]:
|
||||||
def _get_body(appid: str, text: dict) -> Dict[str, Any]:
|
|
||||||
body = {
|
body = {
|
||||||
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
|
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
|
||||||
"parameter": {"emb": {"feature": {"encoding": "utf8"}}},
|
"parameter": {
|
||||||
|
"emb": {"domain": self.domain, "feature": {"encoding": "utf8"}}
|
||||||
|
},
|
||||||
"payload": {
|
"payload": {
|
||||||
"messages": {
|
"messages": {
|
||||||
"text": base64.b64encode(json.dumps(text).encode("utf-8")).decode()
|
"text": base64.b64encode(json.dumps(text).encode("utf-8")).decode()
|
||||||
|
47
libs/community/tests/unit_tests/embeddings/test_sparkllm.py
Normal file
47
libs/community/tests/unit_tests/embeddings/test_sparkllm.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
import os
|
||||||
|
from typing import cast
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
from langchain_core.pydantic_v1 import SecretStr, ValidationError
|
||||||
|
|
||||||
|
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||||||
|
|
||||||
|
|
||||||
|
def test_sparkllm_initialization_by_alias() -> None:
|
||||||
|
# Effective initialization
|
||||||
|
embeddings = SparkLLMTextEmbeddings(
|
||||||
|
app_id="your-app-id", # type: ignore[arg-type]
|
||||||
|
api_key="your-api-key", # type: ignore[arg-type]
|
||||||
|
api_secret="your-api-secret", # type: ignore[arg-type]
|
||||||
|
)
|
||||||
|
assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id"
|
||||||
|
assert (
|
||||||
|
cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
cast(SecretStr, embeddings.spark_api_secret).get_secret_value()
|
||||||
|
== "your-api-secret"
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_initialization_parameters_from_env() -> None:
|
||||||
|
# Setting environment variable
|
||||||
|
os.environ["SPARK_APP_ID"] = "your-app-id"
|
||||||
|
os.environ["SPARK_API_KEY"] = "your-api-key"
|
||||||
|
os.environ["SPARK_API_SECRET"] = "your-api-secret"
|
||||||
|
|
||||||
|
# Effective initialization
|
||||||
|
embeddings = SparkLLMTextEmbeddings()
|
||||||
|
assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id"
|
||||||
|
assert (
|
||||||
|
cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key"
|
||||||
|
)
|
||||||
|
assert (
|
||||||
|
cast(SecretStr, embeddings.spark_api_secret).get_secret_value()
|
||||||
|
== "your-api-secret"
|
||||||
|
)
|
||||||
|
|
||||||
|
# Environment variable missing
|
||||||
|
del os.environ["SPARK_APP_ID"]
|
||||||
|
with pytest.raises(ValidationError):
|
||||||
|
SparkLLMTextEmbeddings()
|
Loading…
Reference in New Issue
Block a user