diff --git a/libs/community/langchain_community/embeddings/sparkllm.py b/libs/community/langchain_community/embeddings/sparkllm.py index 44a6b9a7fda..fe82f9b3126 100644 --- a/libs/community/langchain_community/embeddings/sparkllm.py +++ b/libs/community/langchain_community/embeddings/sparkllm.py @@ -5,25 +5,20 @@ import json import logging from datetime import datetime from time import mktime -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Literal, Optional from urllib.parse import urlencode from wsgiref.handlers import format_date_time import numpy as np import requests from langchain_core.embeddings import Embeddings -from langchain_core.pydantic_v1 import BaseModel, SecretStr, root_validator +from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr, root_validator from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env from numpy import ndarray -# Used for document and knowledge embedding -EMBEDDING_P_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/sa8a05c27" -# Used for user questions embedding -EMBEDDING_Q_API_URL: str = "https://cn-huabei-1.xf-yun.com/v1/private/s50d55a16" - # SparkLLMTextEmbeddings is an embedding model provided by iFLYTEK Co., Ltd.. (https://iflytek.com/en/). -# Official Website: https://www.xfyun.cn/doc/spark/Embedding_new_api.html +# Official Website: https://www.xfyun.cn/doc/spark/Embedding_api.html # Developers need to create an application in the console first, use the appid, APIKey, # and APISecret provided in the application for authentication, # and generate an authentication URL for handshake. @@ -43,39 +38,89 @@ class Url: class SparkLLMTextEmbeddings(BaseModel, Embeddings): - """SparkLLM Text Embedding models.""" + """SparkLLM Text Embedding models. - spark_app_id: SecretStr - spark_api_key: SecretStr - spark_api_secret: SecretStr + To use, you should have the environment variable "SPARK_APP_ID","SPARK_API_KEY" + and "SPARK_API_SECRET" set your APP_ID, API_KEY and API_SECRET or pass it + as a name parameter to the constructor. + + Example: + .. code-block:: python + + from langchain_community.embeddings import SparkLLMTextEmbeddings + + embeddings = SparkLLMTextEmbeddings( + spark_app_id="your-app-id", + spark_api_key="your-api-key", + spark_api_secret="your-api-secret" + ) + text = "This is a test query." + query_result = embeddings.embed_query(text) + + """ + + spark_app_id: Optional[SecretStr] = Field(default=None, alias="app_id") + """Automatically inferred from env var `SPARK_APP_ID` if not provided.""" + spark_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") + """Automatically inferred from env var `SPARK_API_KEY` if not provided.""" + spark_api_secret: Optional[SecretStr] = Field(default=None, alias="api_secret") + """Automatically inferred from env var `SPARK_API_SECRET` if not provided.""" + base_url: str = Field(default="https://emb-cn-huabei-1.xf-yun.com/") + """Base URL path for API requests""" + domain: Literal["para", "query"] = Field(default="para") + """This parameter is used for which Embedding this time belongs to. + If "para"(default), it belongs to document Embedding. + If "query", it belongs to query Embedding.""" + + class Config: + """Configuration for this pydantic object""" + + allow_population_by_field_name = True @root_validator(allow_reuse=True) def validate_environment(cls, values: Dict) -> Dict: """Validate that auth token exists in environment.""" - cls.spark_app_id = convert_to_secret_str( + values["spark_app_id"] = convert_to_secret_str( get_from_dict_or_env(values, "spark_app_id", "SPARK_APP_ID") ) - cls.spark_api_key = convert_to_secret_str( + values["spark_api_key"] = convert_to_secret_str( get_from_dict_or_env(values, "spark_api_key", "SPARK_API_KEY") ) - cls.spark_api_secret = convert_to_secret_str( + values["spark_api_secret"] = convert_to_secret_str( get_from_dict_or_env(values, "spark_api_secret", "SPARK_API_SECRET") ) return values def _embed(self, texts: List[str], host: str) -> Optional[List[List[float]]]: + """Internal method to call Spark Embedding API and return embeddings. + + Args: + texts: A list of texts to embed. + host: Base URL path for API requests + + Returns: + A list of list of floats representing the embeddings, + or list with value None if an error occurs. + """ + app_id = "" + api_key = "" + api_secret = "" + if self.spark_app_id: + app_id = self.spark_app_id.get_secret_value() + if self.spark_api_key: + api_key = self.spark_api_key.get_secret_value() + if self.spark_api_secret: + api_secret = self.spark_api_secret.get_secret_value() url = self._assemble_ws_auth_url( request_url=host, method="POST", - api_key=self.spark_api_key.get_secret_value(), - api_secret=self.spark_api_secret.get_secret_value(), + api_key=api_key, + api_secret=api_secret, ) embed_result: list = [] for text in texts: query_context = {"messages": [{"content": text, "role": "user"}]} - content = self._get_body( - self.spark_app_id.get_secret_value(), query_context - ) + content = self._get_body(app_id, query_context) response = requests.post( url, json=content, headers={"content-type": "application/json"} ).text @@ -95,7 +140,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings): Returns: A list of embeddings, one for each text, or None if an error occurs. """ - return self._embed(texts, EMBEDDING_P_API_URL) + return self._embed(texts, self.base_url) def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override] """Public method to get embedding for a single query text. @@ -106,7 +151,7 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings): Returns: Embeddings for the text, or None if an error occurs. """ - result = self._embed([text], EMBEDDING_Q_API_URL) + result = self._embed([text], self.base_url) return result[0] if result is not None else None @staticmethod @@ -151,11 +196,12 @@ class SparkLLMTextEmbeddings(BaseModel, Embeddings): u = Url(host, path, schema) return u - @staticmethod - def _get_body(appid: str, text: dict) -> Dict[str, Any]: + def _get_body(self, appid: str, text: dict) -> Dict[str, Any]: body = { "header": {"app_id": appid, "uid": "39769795890", "status": 3}, - "parameter": {"emb": {"feature": {"encoding": "utf8"}}}, + "parameter": { + "emb": {"domain": self.domain, "feature": {"encoding": "utf8"}} + }, "payload": { "messages": { "text": base64.b64encode(json.dumps(text).encode("utf-8")).decode() diff --git a/libs/community/tests/unit_tests/embeddings/test_sparkllm.py b/libs/community/tests/unit_tests/embeddings/test_sparkllm.py new file mode 100644 index 00000000000..d318035106e --- /dev/null +++ b/libs/community/tests/unit_tests/embeddings/test_sparkllm.py @@ -0,0 +1,47 @@ +import os +from typing import cast + +import pytest +from langchain_core.pydantic_v1 import SecretStr, ValidationError + +from langchain_community.embeddings import SparkLLMTextEmbeddings + + +def test_sparkllm_initialization_by_alias() -> None: + # Effective initialization + embeddings = SparkLLMTextEmbeddings( + app_id="your-app-id", # type: ignore[arg-type] + api_key="your-api-key", # type: ignore[arg-type] + api_secret="your-api-secret", # type: ignore[arg-type] + ) + assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id" + assert ( + cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key" + ) + assert ( + cast(SecretStr, embeddings.spark_api_secret).get_secret_value() + == "your-api-secret" + ) + + +def test_initialization_parameters_from_env() -> None: + # Setting environment variable + os.environ["SPARK_APP_ID"] = "your-app-id" + os.environ["SPARK_API_KEY"] = "your-api-key" + os.environ["SPARK_API_SECRET"] = "your-api-secret" + + # Effective initialization + embeddings = SparkLLMTextEmbeddings() + assert cast(SecretStr, embeddings.spark_app_id).get_secret_value() == "your-app-id" + assert ( + cast(SecretStr, embeddings.spark_api_key).get_secret_value() == "your-api-key" + ) + assert ( + cast(SecretStr, embeddings.spark_api_secret).get_secret_value() + == "your-api-secret" + ) + + # Environment variable missing + del os.environ["SPARK_APP_ID"] + with pytest.raises(ValidationError): + SparkLLMTextEmbeddings()