mirror of
https://github.com/hwchase17/langchain.git
synced 2025-07-03 19:57:51 +00:00
community[patch]: Update the default api_url and request_body of sparkllm embedding (#22136)
- **Description:** When running SparkLLMTextEmbeddings with a correct app_id, api_key and api_secret, requests still could not complete using the current default URL. ```python # example from langchain_community.embeddings import SparkLLMTextEmbeddings embedding = SparkLLMTextEmbeddings( spark_app_id="my-app-id", spark_api_key="my-api-key", spark_api_secret="my-api-secret" ) text1 = "hello" print(embedding.embed_query(text1)) ```  So I updated the URL and request body parameters according to [Embedding_api](https://www.xfyun.cn/doc/spark/Embedding_api.html), and it now runs.
This commit is contained in:
parent
ba0dca46d7
commit
13140dc4ff
@ -5,25 +5,20 @@ import json
|
||||
import logging
|
||||
from datetime import datetime
|
||||
from time import mktime
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
from urllib.parse import urlencode
|
||||
from wsgiref.handlers import format_date_time
|
||||
|
||||
import numpy as np
|
||||
import requests
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator
|
||||
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator
|
||||
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
||||
from numpy import ndarray
|
||||
|
||||
# Used for document and knowledge embedding
|
||||
EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27"
|
||||
# Used for user questions embedding
|
||||
EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16"
|
||||
|
||||
# SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd. (https://iflytek.com/en/).
|
||||
|
||||
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html
|
||||
# Official Website: https://www.xfyun.cn/doc/spark/Embedding_api.html
|
||||
# Developers need to create an application in the console first, use the appid, APIKey,
|
||||
# and APISecret provided in the application for authentication,
|
||||
# and generate an authentication URL for handshake.
|
||||
@ -43,39 +38,89 @@ class Url:
|
||||
|
||||
|
||||
class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
"""SparkLLM Text Embedding models."""
|
||||
"""SparkLLM Text Embedding models.
|
||||
|
||||
spark_app_id: SecretStr
|
||||
spark_api_key: SecretStr
|
||||
spark_api_secret: SecretStr
|
||||
To use, you should have the environment variable "SPARK_APP_ID","SPARK_API_KEY"
|
||||
and "SPARK_API_SECRET" set your APP_ID, API_KEY and API_SECRET or pass it
|
||||
as a name parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||||
|
||||
embeddings = SparkLLMTextEmbeddings(
|
||||
spark_app_id="your-app-id",
|
||||
spark_api_key="your-api-key",
|
||||
spark_api_secret="your-api-secret"
|
||||
)
|
||||
text = "This is a test query."
|
||||
query_result = embeddings.embed_query(text)
|
||||
|
||||
"""
|
||||
|
||||
spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id")
|
||||
"""Automatically inferred from env var `SPARK_APP_ID` if not provided."""
|
||||
spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
|
||||
"""Automatically inferred from env var `SPARK_API_KEY` if not provided."""
|
||||
spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret")
|
||||
"""Automatically inferred from env var `SPARK_API_SECRET` if not provided."""
|
||||
base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/")
|
||||
"""Base URL path for API requests"""
|
||||
domain: Literal["para", "query"] = Field(default="para")
|
||||
"""This parameter is used for which Embedding this time belongs to.
|
||||
If "para"(default), it belongs to document Embedding.
|
||||
If "query", it belongs to query Embedding."""
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object"""
|
||||
|
||||
allow_population_by_field_name = True
|
||||
|
||||
@root_validator(allow_reuse=True)
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that auth token exists in environment."""
|
||||
cls.spark_app_id = convert_to_secret_str(
|
||||
values["spark_app_id"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID")
|
||||
)
|
||||
cls.spark_api_key = convert_to_secret_str(
|
||||
values["spark_api_key"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY")
|
||||
)
|
||||
cls.spark_api_secret = convert_to_secret_str(
|
||||
values["spark_api_secret"] = convert_to_secret_str(
|
||||
get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET")
|
||||
)
|
||||
return values
|
||||
|
||||
def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]:
|
||||
"""Internal method to call Spark Embedding API and return embeddings.
|
||||
|
||||
Args:
|
||||
texts: A list of texts to embed.
|
||||
host: Base URL path for API requests
|
||||
|
||||
Returns:
|
||||
A list of list of floats representing the embeddings,
|
||||
or list with value None if an error occurs.
|
||||
"""
|
||||
app_id = ""
|
||||
api_key = ""
|
||||
api_secret = ""
|
||||
if self.spark_app_id:
|
||||
app_id = self.spark_app_id.get_secret_value()
|
||||
if self.spark_api_key:
|
||||
api_key = self.spark_api_key.get_secret_value()
|
||||
if self.spark_api_secret:
|
||||
api_secret = self.spark_api_secret.get_secret_value()
|
||||
url = self._assemble_ws_auth_url(
|
||||
request_url=host,
|
||||
method="POST",
|
||||
api_key=self.spark_api_key.get_secret_value(),
|
||||
api_secret=self.spark_api_secret.get_secret_value(),
|
||||
api_key=api_key,
|
||||
api_secret=api_secret,
|
||||
)
|
||||
embed_result: list = []
|
||||
for text in texts:
|
||||
query_context = {"messages": [{"content": text, "role": "user"}]}
|
||||
content = self._get_body(
|
||||
self.spark_app_id.get_secret_value(), query_context
|
||||
)
|
||||
content = self._get_body(app_id, query_context)
|
||||
response = requests.post(
|
||||
url, json=content, headers={"content-type": "application/json"}
|
||||
).text
|
||||
@ -95,7 +140,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
A list of embeddings, one for each text, or None if an error occurs.
|
||||
"""
|
||||
return self._embed(texts, EMBEDDING_P_API_URL)
|
||||
return self._embed(texts, self.base_url)
|
||||
|
||||
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
|
||||
"""Public method to get embedding for a single query text.
|
||||
@ -106,7 +151,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
Returns:
|
||||
Embeddings for the text, or None if an error occurs.
|
||||
"""
|
||||
result = self._embed([text], EMBEDDING_Q_API_URL)
|
||||
result = self._embed([text], self.base_url)
|
||||
return result[0] if result is not None else None
|
||||
|
||||
@staticmethod
|
||||
@ -151,11 +196,12 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings):
|
||||
u = Url(host, path, schema)
|
||||
return u
|
||||
|
||||
@staticmethod
|
||||
def _get_body(appid: str, text: dict) -> Dict[str, Any]:
|
||||
def _get_body(self, appid: str, text: dict) -> Dict[str, Any]:
|
||||
body = {
|
||||
"header": {"app_id": appid, "uid": "39769795890", "status": 3},
|
||||
"parameter": {"emb": {"feature": {"encoding": "utf8"}}},
|
||||
"parameter": {
|
||||
"emb": {"domain": self.domain, "feature": {"encoding": "utf8"}}
|
||||
},
|
||||
"payload": {
|
||||
"messages": {
|
||||
"text": base64.b64encode(json.dumps(text).encode("utf-8")).decode()
|
||||
|
47
libs/community/tests/unit_tests/embeddings/test_sparkllm.py
Normal file
47
libs/community/tests/unit_tests/embeddings/test_sparkllm.py
Normal file
@ -0,0 +1,47 @@
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
import pytest
|
||||
from langchain_core.pydantic_v1 import SecretStr, ValidationError
|
||||
|
||||
from langchain_community.embeddings import SparkLLMTextEmbeddings
|
||||
|
||||
|
||||
def test_sparkllm_initialization_by_alias() -> None:
    """Credentials passed through the short aliases end up on the model fields.

    The model is built with ``app_id`` / ``api_key`` / ``api_secret`` and the
    resulting ``spark_*`` attributes are checked for the same secret values.
    """
    # Effective initialization
    embeddings = SparkLLMTextEmbeddings(
        app_id="your-app-id",  # type: ignore[arg-type]
        api_key="your-api-key",  # type: ignore[arg-type]
        api_secret="your-api-secret",  # type: ignore[arg-type]
    )

    # The fields are declared Optional, so narrow them for the type checker
    # before unwrapping the secret values.
    app_id = cast(SecretStr, embeddings.spark_app_id)
    api_key = cast(SecretStr, embeddings.spark_api_key)
    api_secret = cast(SecretStr, embeddings.spark_api_secret)

    assert app_id.get_secret_value() == "your-app-id"
    assert api_key.get_secret_value() == "your-api-key"
    assert api_secret.get_secret_value() == "your-api-secret"
|
||||
|
||||
|
||||
def test_initialization_parameters_from_env() -> None:
    """Credentials are picked up from the environment when not passed in.

    First checks that ``SPARK_APP_ID``, ``SPARK_API_KEY`` and
    ``SPARK_API_SECRET`` populate the corresponding ``spark_*`` fields, then
    checks that constructing the model without ``SPARK_APP_ID`` raises
    ``ValidationError``.

    The process environment is restored on exit so the values set here do not
    leak into other tests running in the same process.
    """
    env_values = {
        "SPARK_APP_ID": "your-app-id",
        "SPARK_API_KEY": "your-api-key",
        "SPARK_API_SECRET": "your-api-secret",
    }
    # Snapshot whatever was set before (None means "was not set") so the
    # finally-block can put the environment back exactly as it was.
    previous_values = {name: os.environ.get(name) for name in env_values}

    try:
        # Setting environment variable
        os.environ.update(env_values)

        # Effective initialization
        embeddings = SparkLLMTextEmbeddings()
        assert (
            cast(SecretStr, embeddings.spark_app_id).get_secret_value()
            == "your-app-id"
        )
        assert (
            cast(SecretStr, embeddings.spark_api_key).get_secret_value()
            == "your-api-key"
        )
        assert (
            cast(SecretStr, embeddings.spark_api_secret).get_secret_value()
            == "your-api-secret"
        )

        # Environment variable missing
        del os.environ["SPARK_APP_ID"]
        with pytest.raises(ValidationError):
            SparkLLMTextEmbeddings()
    finally:
        for name, previous in previous_values.items():
            if previous is None:
                os.environ.pop(name, None)
            else:
                os.environ[name] = previous
|
Loading…
Reference in New Issue
Block a user