From 47b1b7092dac56794c8799d6af149fa00612086e Mon Sep 17 00:00:00 2001
From: Guangdong Liu
Date: Wed, 21 Feb 2024 03:23:47 +0800
Subject: [PATCH] community[minor]: Add SparkLLM to community (#17702)

---
 docs/docs/integrations/llms/sparkllm.ipynb    | 141 +++++++
 .../langchain_community/llms/__init__.py      |  10 +
 .../langchain_community/llms/sparkllm.py      | 383 ++++++++++++++++++
 .../integration_tests/llms/test_sparkllm.py   |  19 +
 .../tests/unit_tests/llms/test_imports.py     |   1 +
 5 files changed, 554 insertions(+)
 create mode 100644 docs/docs/integrations/llms/sparkllm.ipynb
 create mode 100644 libs/community/langchain_community/llms/sparkllm.py
 create mode 100644 libs/community/tests/integration_tests/llms/test_sparkllm.py

diff --git a/docs/docs/integrations/llms/sparkllm.ipynb b/docs/docs/integrations/llms/sparkllm.ipynb
new file mode 100644
index 00000000000..f17c33a36d3
--- /dev/null
+++ b/docs/docs/integrations/llms/sparkllm.ipynb
@@ -0,0 +1,141 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# SparkLLM\n",
+    "[SparkLLM](https://xinghuo.xfyun.cn/spark) is a large-scale cognitive model independently developed by iFLYTEK.\n",
+    "It has cross-domain knowledge and language-understanding capabilities, acquired by learning from large amounts of text, code, and images.\n",
+    "It can understand and perform tasks through natural dialogue."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Prerequisite\n",
+    "- Get SparkLLM's app_id, api_key and api_secret from [iFlyTek SparkLLM API Console](https://console.xfyun.cn/services/bm3) (for more info, see [iFlyTek SparkLLM Intro](https://xinghuo.xfyun.cn/sparkapi)), then set the environment variables `IFLYTEK_SPARK_APP_ID`, `IFLYTEK_SPARK_API_KEY` and `IFLYTEK_SPARK_API_SECRET`, or pass the corresponding parameters when creating `SparkLLM`, as in the demo below."
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Use SparkLLM"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "\n",
+    "os.environ[\"IFLYTEK_SPARK_APP_ID\"] = \"app_id\"\n",
+    "os.environ[\"IFLYTEK_SPARK_API_KEY\"] = \"api_key\"\n",
+    "os.environ[\"IFLYTEK_SPARK_API_SECRET\"] = \"api_secret\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "My name is iFLYTEK Spark. How can I assist you today?\n"
+     ]
+    }
+   ],
+   "source": [
+    "from langchain_community.llms import SparkLLM\n",
+    "\n",
+    "# Load the model\n",
+    "llm = SparkLLM()\n",
+    "\n",
+    "res = llm.invoke(\"What's your name?\")\n",
+    "print(res)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-02-18T13:04:29.305856Z",
+     "start_time": "2024-02-18T13:04:28.085715Z"
+    }
+   },
+   "outputs": [
+    {
+     "data": {
+      "text/plain": "LLMResult(generations=[[Generation(text='Hello! How can I assist you today?')]], llm_output=None, run=[RunInfo(run_id=UUID('d8cdcd41-a698-4cbf-a28d-e74f9cd2037b'))])"
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "res = llm.generate(prompts=[\"hello!\"])\n",
+    "res"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "metadata": {
+    "ExecuteTime": {
+     "end_time": "2024-02-18T13:05:44.640035Z",
+     "start_time": "2024-02-18T13:05:43.244126Z"
+    }
+   },
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Hello! How can I assist you today?\n"
+     ]
+    }
+   ],
+   "source": [
+    "for res in llm.stream(\"foo:\"):\n",
+    "    print(res)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "collapsed": false
+   },
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "name": "python"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py
index 9adaf0af1be..42cd3f0c273 100644
--- a/libs/community/langchain_community/llms/__init__.py
+++ b/libs/community/langchain_community/llms/__init__.py
@@ -582,6 +582,12 @@ def _import_volcengine_maas() -> Any:
     return VolcEngineMaasLLM
 
 
+def _import_sparkllm() -> Any:
+    from langchain_community.llms.sparkllm import SparkLLM
+
+    return SparkLLM
+
+
 def __getattr__(name: str) -> Any:
     if name == "AI21":
         return _import_ai21()
@@ -769,6 +775,8 @@ def __getattr__(name: str) -> Any:
             k: v() for k, v in get_type_to_cls_dict().items()
         }
         return type_to_cls_dict
+    elif name == "SparkLLM":
+        return _import_sparkllm()
     else:
         raise AttributeError(f"Could not find: {name}")
 
@@ -861,6 +869,7 @@ __all__ = [
     "YandexGPT",
     "Yuan2",
     "VolcEngineMaasLLM",
+    "SparkLLM",
 ]
 
 
@@ -950,4 +959,5 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
     "yandex_gpt": _import_yandex_gpt,
     "yuan2": _import_yuan2,
     "VolcEngineMaasLLM": _import_volcengine_maas,
+    "SparkLLM": _import_sparkllm,
 }
diff --git a/libs/community/langchain_community/llms/sparkllm.py b/libs/community/langchain_community/llms/sparkllm.py
new file mode 100644
index 00000000000..0f49a356b9d
--- /dev/null
+++ b/libs/community/langchain_community/llms/sparkllm.py
@@ -0,0 +1,383 @@
+from __future__ import annotations
+
+import base64
+import hashlib
+import hmac
+import json
+import logging
+import queue
+import threading
+from datetime import datetime
+from queue import Queue
+from time import mktime
+from typing import Any, Dict, Generator, Iterator, List, Optional
+from urllib.parse import urlencode, urlparse, urlunparse
+from wsgiref.handlers import format_date_time
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.outputs import GenerationChunk
+from langchain_core.pydantic_v1 import Field, root_validator
+from langchain_core.utils import get_from_dict_or_env
+
+logger = logging.getLogger(__name__)
+
+
+class SparkLLM(LLM):
+    """Wrapper around iFlyTek's Spark large language model.
+
+    To use, pass ``spark_app_id``, ``spark_api_key`` and ``spark_api_secret``
+    as named parameters to the constructor, OR set the environment
+    variables ``IFLYTEK_SPARK_APP_ID``, ``IFLYTEK_SPARK_API_KEY`` and
+    ``IFLYTEK_SPARK_API_SECRET``.
+
+    Example:
+        .. code-block:: python
+
+            client = SparkLLM(
+                spark_app_id="<app_id>",
+                spark_api_key="<api_key>",
+                spark_api_secret="<api_secret>"
+            )
+    """
+
+    client: Any = None  #: :meta private:
+    spark_app_id: Optional[str] = None
+    spark_api_key: Optional[str] = None
+    spark_api_secret: Optional[str] = None
+    spark_api_url: Optional[str] = None
+    spark_llm_domain: Optional[str] = None
+    spark_user_id: str = "lc_user"
+    streaming: bool = False
+    request_timeout: int = 30
+    temperature: float = 0.5
+    top_k: int = 4
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        values["spark_app_id"] = get_from_dict_or_env(
+            values,
+            "spark_app_id",
+            "IFLYTEK_SPARK_APP_ID",
+        )
+        values["spark_api_key"] = get_from_dict_or_env(
+            values,
+            "spark_api_key",
+            "IFLYTEK_SPARK_API_KEY",
+        )
+        values["spark_api_secret"] = get_from_dict_or_env(
+            values,
+            "spark_api_secret",
+            "IFLYTEK_SPARK_API_SECRET",
+        )
+        values["spark_api_url"] = get_from_dict_or_env(
+            values,
+            "spark_api_url",
+            "IFLYTEK_SPARK_API_URL",
+            "wss://spark-api.xf-yun.com/v3.1/chat",
+        )
+        values["spark_llm_domain"] = get_from_dict_or_env(
+            values,
+            "spark_llm_domain",
+            "IFLYTEK_SPARK_LLM_DOMAIN",
+            "generalv3",
+        )
+        # put the generation params into model_kwargs
+        values["model_kwargs"]["temperature"] = values["temperature"]
+        values["model_kwargs"]["top_k"] = values["top_k"]
+
+        values["client"] = _SparkLLMClient(
+            app_id=values["spark_app_id"],
+            api_key=values["spark_api_key"],
+            api_secret=values["spark_api_secret"],
+            api_url=values["spark_api_url"],
+            spark_domain=values["spark_llm_domain"],
+            model_kwargs=values["model_kwargs"],
+        )
+        return values
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "spark-llm-chat"
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Get the default parameters for calling the SparkLLM API."""
+        normal_params = {
+            "spark_llm_domain": self.spark_llm_domain,
+            "stream": self.streaming,
+            "request_timeout": self.request_timeout,
+            "top_k": self.top_k,
+            "temperature": self.temperature,
+        }
+
+        return {**normal_params, **self.model_kwargs}
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to the SparkLLM API for a single generation with a prompt.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the llm.
+
+        Example:
+            .. code-block:: python
+
+                response = client.invoke("Tell me a joke.")
+        """
+        if self.streaming:
+            completion = ""
+            for chunk in self._stream(prompt, stop, run_manager, **kwargs):
+                completion += chunk.text
+            return completion
+        completion = ""
+        self.client.arun(
+            [{"role": "user", "content": prompt}],
+            self.spark_user_id,
+            self.model_kwargs,
+            self.streaming,
+        )
+        for content in self.client.subscribe(timeout=self.request_timeout):
+            if "data" not in content:
+                continue
+            # in blocking mode the client accumulates the full answer, so the
+            # last "data" message carries the complete text
+            completion = content["data"]["content"]
+
+        return completion
+
+    def _stream(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> Iterator[GenerationChunk]:
+        # run the websocket client in a background thread so chunks can be
+        # consumed from the queue as they arrive; always request per-token
+        # frames when streaming
+        self.client.arun(
+            [{"role": "user", "content": prompt}],
+            self.spark_user_id,
+            self.model_kwargs,
+            True,
+        )
+        for content in self.client.subscribe(timeout=self.request_timeout):
+            if "data" not in content:
+                continue
+            delta = content["data"]
+            chunk = GenerationChunk(text=delta["content"])
+            if run_manager:
+                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
+            yield chunk
+
+
+class _SparkLLMClient:
+    """
+    Use websocket-client to call the SparkLLM interface provided by Xfyun,
+    iFlyTek's open platform for AI capabilities.
+    """
+
+    def __init__(
+        self,
+        app_id: str,
+        api_key: str,
+        api_secret: str,
+        api_url: Optional[str] = None,
+        spark_domain: Optional[str] = None,
+        model_kwargs: Optional[dict] = None,
+    ):
+        try:
+            import websocket
+
+            self.websocket_client = websocket
+        except ImportError:
+            raise ImportError(
+                "Could not import the websocket-client python package. "
+                "Please install it with `pip install websocket-client`."
+            )
+
+        self.api_url = (
+            "wss://spark-api.xf-yun.com/v3.1/chat" if not api_url else api_url
+        )
+        self.app_id = app_id
+        self.ws_url = _SparkLLMClient._create_url(
+            self.api_url,
+            api_key,
+            api_secret,
+        )
+        self.model_kwargs = model_kwargs
+        self.spark_domain = spark_domain or "generalv3"
+        self.queue: Queue[Dict] = Queue()
+        self.blocking_message = {"content": "", "role": "assistant"}
+
+    @staticmethod
+    def _create_url(api_url: str, api_key: str, api_secret: str) -> str:
+        """
+        Generate a request url with an api key and an api secret.
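+
+        The url carries iFlyTek's websocket auth scheme: the host, an RFC1123
+        date and the request line are signed with HMAC-SHA256 using the api
+        secret, and the base64-encoded signature is sent back in the
+        ``authorization`` query parameter together with ``date`` and ``host``.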
+ """ + # generate timestamp by RFC1123 + date = format_date_time(mktime(datetime.now().timetuple())) + + # urlparse + parsed_url = urlparse(api_url) + host = parsed_url.netloc + path = parsed_url.path + + signature_origin = f"host: {host}\ndate: {date}\nGET {path} HTTP/1.1" + + # encrypt using hmac-sha256 + signature_sha = hmac.new( + api_secret.encode("utf-8"), + signature_origin.encode("utf-8"), + digestmod=hashlib.sha256, + ).digest() + + signature_sha_base64 = base64.b64encode(signature_sha).decode(encoding="utf-8") + + authorization_origin = f'api_key="{api_key}", algorithm="hmac-sha256", \ + headers="host date request-line", signature="{signature_sha_base64}"' + authorization = base64.b64encode(authorization_origin.encode("utf-8")).decode( + encoding="utf-8" + ) + + # generate url + params_dict = {"authorization": authorization, "date": date, "host": host} + encoded_params = urlencode(params_dict) + url = urlunparse( + ( + parsed_url.scheme, + parsed_url.netloc, + parsed_url.path, + parsed_url.params, + encoded_params, + parsed_url.fragment, + ) + ) + return url + + def run( + self, + messages: List[Dict], + user_id: str, + model_kwargs: Optional[dict] = None, + streaming: bool = False, + ) -> None: + self.websocket_client.enableTrace(False) + ws = self.websocket_client.WebSocketApp( + self.ws_url, + on_message=self.on_message, + on_error=self.on_error, + on_close=self.on_close, + on_open=self.on_open, + ) + ws.messages = messages + ws.user_id = user_id + ws.model_kwargs = self.model_kwargs if model_kwargs is None else model_kwargs + ws.streaming = streaming + ws.run_forever() + + def arun( + self, + messages: List[Dict], + user_id: str, + model_kwargs: Optional[dict] = None, + streaming: bool = False, + ) -> threading.Thread: + ws_thread = threading.Thread( + target=self.run, + args=( + messages, + user_id, + model_kwargs, + streaming, + ), + ) + ws_thread.start() + return ws_thread + + def on_error(self, ws: Any, error: Optional[Any]) -> None: + self.queue.put({"error": error}) + ws.close() + + def on_close(self, ws: Any, close_status_code: int, close_reason: str) -> None: + logger.debug( + { + "log": { + "close_status_code": close_status_code, + "close_reason": close_reason, + } + } + ) + self.queue.put({"done": True}) + + def on_open(self, ws: Any) -> None: + self.blocking_message = {"content": "", "role": "assistant"} + data = json.dumps( + self.gen_params( + messages=ws.messages, user_id=ws.user_id, model_kwargs=ws.model_kwargs + ) + ) + ws.send(data) + + def on_message(self, ws: Any, message: str) -> None: + data = json.loads(message) + code = data["header"]["code"] + if code != 0: + self.queue.put( + {"error": f"Code: {code}, Error: {data['header']['message']}"} + ) + ws.close() + else: + choices = data["payload"]["choices"] + status = choices["status"] + content = choices["text"][0]["content"] + if ws.streaming: + self.queue.put({"data": choices["text"][0]}) + else: + self.blocking_message["content"] += content + if status == 2: + if not ws.streaming: + self.queue.put({"data": self.blocking_message}) + usage_data = ( + data.get("payload", {}).get("usage", {}).get("text", {}) + if data + else {} + ) + self.queue.put({"usage": usage_data}) + ws.close() + + def gen_params( + self, messages: list, user_id: str, model_kwargs: Optional[dict] = None + ) -> dict: + data: Dict = { + "header": {"app_id": self.app_id, "uid": user_id}, + "parameter": {"chat": {"domain": self.spark_domain}}, + "payload": {"message": {"text": messages}}, + } + + if model_kwargs: + 
data["parameter"]["chat"].update(model_kwargs) + logger.debug(f"Spark Request Parameters: {data}") + return data + + def subscribe(self, timeout: Optional[int] = 30) -> Generator[Dict, None, None]: + while True: + try: + content = self.queue.get(timeout=timeout) + except queue.Empty as _: + raise TimeoutError( + f"SparkLLMClient wait LLM api response timeout {timeout} seconds" + ) + if "error" in content: + raise ConnectionError(content["error"]) + if "usage" in content: + yield content + continue + if "done" in content: + break + if "data" not in content: + break + yield content diff --git a/libs/community/tests/integration_tests/llms/test_sparkllm.py b/libs/community/tests/integration_tests/llms/test_sparkllm.py new file mode 100644 index 00000000000..6df9bf7c36f --- /dev/null +++ b/libs/community/tests/integration_tests/llms/test_sparkllm.py @@ -0,0 +1,19 @@ +"""Test SparkLLM.""" +from langchain_core.outputs import LLMResult + +from langchain_community.llms.sparkllm import SparkLLM + + +def test_call() -> None: + """Test valid call to sparkllm.""" + llm = SparkLLM() + output = llm("Say foo:") + assert isinstance(output, str) + + +def test_generate() -> None: + """Test valid call to sparkllm.""" + llm = SparkLLM() + output = llm.generate(["Say foo:"]) + assert isinstance(output, LLMResult) + assert isinstance(output.generations, list) diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py index 6c8c6504e36..1f5489b2b35 100644 --- a/libs/community/tests/unit_tests/llms/test_imports.py +++ b/libs/community/tests/unit_tests/llms/test_imports.py @@ -90,6 +90,7 @@ EXPECT_ALL = [ "Yuan2", "VolcEngineMaasLLM", "WatsonxLLM", + "SparkLLM", ]