From c776cfc59902e508ba662399a63b405f6ba22877 Mon Sep 17 00:00:00 2001
From: wulixuan
Date: Thu, 15 Feb 2024 03:46:20 +0800
Subject: [PATCH] community[minor]: integrate with model Yuan2.0 (#15411)

1. integrate with [`Yuan2.0`](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/README-EN.md)
2. update `langchain.llms`
3. add a new doc for [Yuan2.0 integration](docs/docs/integrations/llms/yuan2.ipynb)

---------

Co-authored-by: Harrison Chase
Co-authored-by: Bagatur
---
 .github/workflows/codespell.yml           |   1 +
 docs/docs/integrations/llms/yuan2.ipynb   | 117 +++++++++++
 .../langchain_community/llms/__init__.py  |  10 +
 .../langchain_community/llms/yuan2.py     | 192 ++++++++++++++++++
 .../integration_tests/llms/test_yuan2.py  |  33 +++
 .../tests/unit_tests/llms/test_imports.py |   1 +
 6 files changed, 354 insertions(+)
 create mode 100644 docs/docs/integrations/llms/yuan2.ipynb
 create mode 100644 libs/community/langchain_community/llms/yuan2.py
 create mode 100644 libs/community/tests/integration_tests/llms/test_yuan2.py

diff --git a/.github/workflows/codespell.yml b/.github/workflows/codespell.yml
index a5a63b996de..738e0643346 100644
--- a/.github/workflows/codespell.yml
+++ b/.github/workflows/codespell.yml
@@ -34,3 +34,4 @@ jobs:
         with:
           skip: guide_imports.json
           ignore_words_list: ${{ steps.extract_ignore_words.outputs.ignore_words_list }}
+          exclude_file: libs/community/langchain_community/llms/yuan2.py
diff --git a/docs/docs/integrations/llms/yuan2.ipynb b/docs/docs/integrations/llms/yuan2.ipynb
new file mode 100644
index 00000000000..06e45df81df
--- /dev/null
+++ b/docs/docs/integrations/llms/yuan2.ipynb
@@ -0,0 +1,117 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {
+    "collapsed": false,
+    "pycharm": {
+     "name": "#%% md\n"
+    }
+   },
+   "source": [
+    "# Yuan2.0\n",
+    "\n",
+    "[Yuan2.0](https://github.com/IEIT-Yuan/Yuan-2.0) is a new-generation fundamental large language model developed by IEIT System. All three models, Yuan 2.0-102B, Yuan 2.0-51B, and Yuan 2.0-2B, have been published, along with scripts for pretraining, fine-tuning, and inference services for other developers. Yuan2.0 builds on Yuan1.0, using a wider range of high-quality pretraining data and instruction fine-tuning datasets to enhance the model's understanding of semantics, mathematics, reasoning, code, knowledge, and other aspects.\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with a `Yuan2.0` (2B/51B/102B) inference server for text generation.\n",
+    "\n",
+    "Yuan2.0 provides an inference service, so users only need to call the inference API to get results; setup is described in [Yuan2.0 Inference-Server](https://github.com/IEIT-Yuan/Yuan-2.0/blob/main/docs/inference_server.md)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "pycharm": {
+     "is_executing": true,
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "from langchain.chains import LLMChain\n",
+    "from langchain_community.llms.yuan2 import Yuan2"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "pycharm": {
+     "is_executing": true,
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "# default infer_api for a locally deployed Yuan2.0 inference server\n",
+    "infer_api = \"http://127.0.0.1:8000\"\n",
+    "\n",
+    "# to access the endpoint directly in a proxied environment:\n",
+    "# import os\n",
+    "# os.environ[\"no_proxy\"] = \"localhost,127.0.0.1,::1\"\n",
+    "\n",
+    "yuan_llm = Yuan2(\n",
+    "    infer_api=infer_api,\n",
+    "    max_tokens=2048,\n",
+    "    temp=1.0,\n",
+    "    top_p=0.9,\n",
+    "    top_k=40,\n",
+    "    use_history=False,\n",
+    ")\n",
+    "\n",
+    "# turn on use_history only when you want Yuan2.0 to keep track of the conversation history\n",
+    "# and send the accumulated context to the backend model API, which makes it stateful. By default it is stateless.\n",
+    "# yuan_llm.use_history = True"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {
+    "pycharm": {
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "question = \"请介绍一下中国。\"  # \"Please introduce China.\""
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {
+    "pycharm": {
+     "is_executing": true,
+     "name": "#%%\n"
+    }
+   },
+   "outputs": [],
+   "source": [
+    "print(yuan_llm(question))"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "langchain-dev",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
\ No newline at end of file
diff --git a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py
index 2ccf4199c41..9adaf0af1be 100644
--- a/libs/community/langchain_community/llms/__init__.py
+++ b/libs/community/langchain_community/llms/__init__.py
@@ -570,6 +570,12 @@ def _import_yandex_gpt() -> Any:
     return YandexGPT
 
 
+def _import_yuan2() -> Any:
+    from langchain_community.llms.yuan2 import Yuan2
+
+    return Yuan2
+
+
 def _import_volcengine_maas() -> Any:
     from langchain_community.llms.volcengine_maas import VolcEngineMaasLLM
@@ -753,6 +759,8 @@ def __getattr__(name: str) -> Any:
         return _import_xinference()
     elif name == "YandexGPT":
         return _import_yandex_gpt()
+    elif name == "Yuan2":
+        return _import_yuan2()
     elif name == "VolcEngineMaasLLM":
         return _import_volcengine_maas()
     elif name == "type_to_cls_dict":
@@ -851,6 +859,7 @@ __all__ = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
 ]
@@ -939,5 +948,6 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
         "javelin-ai-gateway": _import_javelin_ai_gateway,
         "qianfan_endpoint": _import_baidu_qianfan_endpoint,
         "yandex_gpt": _import_yandex_gpt,
+        "yuan2": _import_yuan2,
         "VolcEngineMaasLLM": _import_volcengine_maas,
     }
diff --git a/libs/community/langchain_community/llms/yuan2.py b/libs/community/langchain_community/llms/yuan2.py
new file mode 100644
index 00000000000..360418a0a0f
--- /dev/null
+++ b/libs/community/langchain_community/llms/yuan2.py
@@ -0,0 +1,192 @@
+import json
+import logging
+from typing import Any, Dict, List, Mapping, Optional, Set
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class Yuan2(LLM):
+    """Yuan2.0 language models.
+
+    Example:
+        .. code-block:: python
+
+            yuan_llm = Yuan2(
+                infer_api="http://127.0.0.1:8000/yuan",
+                max_tokens=1024,
+                temp=1.0,
+                top_p=0.9,
+                top_k=40,
+            )
+            print(yuan_llm)
+            print(yuan_llm("你是谁?"))  # "Who are you?"
+    """
+
+    infer_api: str = "http://127.0.0.1:8000/yuan"
+    """Yuan2.0 inference api"""
+
+    max_tokens: int = Field(1024, alias="max_token")
+    """The maximum number of tokens to generate."""
+
+    temp: Optional[float] = 0.7
+    """The temperature to use for sampling."""
+
+    top_p: Optional[float] = 0.9
+    """The top-p value to use for sampling."""
+
+    top_k: Optional[int] = 40
+    """The top-k value to use for sampling."""
+
+    do_sample: bool = False
+    """Whether to use sampling during text generation."""
+
+    echo: Optional[bool] = False
+    """Whether to echo the prompt."""
+
+    stop: Optional[List[str]] = []
+    """A list of strings to stop generation when encountered."""
+
+    repeat_last_n: Optional[int] = 64
+    """Last n tokens to penalize."""
+
+    repeat_penalty: Optional[float] = 1.18
+    """The penalty to apply to repeated tokens."""
+
+    streaming: bool = False
+    """Whether to stream the results or not."""
+
+    history: List[str] = []
+    """History of the conversation."""
+
+    use_history: bool = False
+    """Whether to use history or not."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "Yuan2.0"
+
+    @staticmethod
+    def _model_param_names() -> Set[str]:
+        return {
+            "max_tokens",
+            "temp",
+            "top_k",
+            "top_p",
+            "do_sample",
+        }
+
+    def _default_params(self) -> Dict[str, Any]:
+        return {
+            "infer_api": self.infer_api,
+            "max_tokens": self.max_tokens,
+            "temp": self.temp,
+            "top_k": self.top_k,
+            "top_p": self.top_p,
+            "do_sample": self.do_sample,
+            "use_history": self.use_history,
+        }
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        return {
+            "model": self._llm_type,
+            **self._default_params(),
+            **{
+                k: v for k, v in self.__dict__.items() if k in self._model_param_names()
+            },
+        }
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to a Yuan2.0 LLM inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = yuan_llm("你能做什么?")  # "What can you do?"
+        """
+
+        if self.use_history:
+            self.history.append(prompt)
+            input_text = "".join(self.history)
+        else:
+            input_text = prompt
+
+        headers = {"Content-Type": "application/json"}
+        data = json.dumps(
+            {
+                "ques_list": [{"id": "000", "ques": input_text}],
+                "tokens_to_generate": self.max_tokens,
+                "temperature": self.temp,
+                "top_p": self.top_p,
+                "top_k": self.top_k,
+                "do_sample": self.do_sample,
+            }
+        )
+
+        logger.debug(f"Yuan2.0 prompt: {input_text}")
+
+        # call api
+        try:
+            response = requests.put(self.infer_api, headers=headers, data=data)
+        except requests.exceptions.RequestException as e:
+            raise ValueError(f"Error raised by inference api: {e}")
+
+        logger.debug(f"Yuan2.0 response: {response}")
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed with response: {response}")
+        try:
+            resp = response.json()
+
+            if resp["errCode"] != "0":
+                raise ValueError(
+                    f"Failed with error code [{resp['errCode']}], "
+                    f"error message: [{resp['errMessage']}]"
+                )
+
+            if "resData" in resp:
+                if len(resp["resData"]["output"]) > 0:
+                    generate_text = resp["resData"]["output"][0]["ans"]
+                else:
+                    raise ValueError("No output found in response.")
+            else:
+                raise ValueError("No resData found in response.")
+
+        except requests.exceptions.JSONDecodeError as e:
+            raise ValueError(
+                f"Error raised during decoding response from inference api: {e}."
+                f"\nResponse: {response.text}"
+            )
+
+        if stop is not None:
+            generate_text = enforce_stop_tokens(generate_text, stop)
+
+        # support multi-turn chat
+        if self.use_history:
+            self.history.append(generate_text)
+
+        logger.debug(f"history: {self.history}")
+        return generate_text
diff --git a/libs/community/tests/integration_tests/llms/test_yuan2.py b/libs/community/tests/integration_tests/llms/test_yuan2.py
new file mode 100644
index 00000000000..94667819e66
--- /dev/null
+++ b/libs/community/tests/integration_tests/llms/test_yuan2.py
@@ -0,0 +1,33 @@
+"""Test Yuan2.0 API wrapper."""
+from langchain_core.outputs import LLMResult
+
+from langchain_community.llms import Yuan2
+
+
+def test_yuan2_call_method() -> None:
+    """Test valid call to Yuan2.0."""
+    llm = Yuan2(
+        infer_api="http://127.0.0.1:8000/yuan",
+        max_tokens=1024,
+        temp=1.0,
+        top_p=0.9,
+        top_k=40,
+        use_history=False,
+    )
+    output = llm("写一段快速排序算法。")  # "Write a quicksort algorithm."
+    assert isinstance(output, str)
+
+
+def test_yuan2_generate_method() -> None:
+    """Test valid call to Yuan2.0 inference api."""
+    llm = Yuan2(
+        infer_api="http://127.0.0.1:8000/yuan",
+        max_tokens=1024,
+        temp=1.0,
+        top_p=0.9,
+        top_k=40,
+        use_history=False,
+    )
+    output = llm.generate(["who are you?"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
diff --git a/libs/community/tests/unit_tests/llms/test_imports.py b/libs/community/tests/unit_tests/llms/test_imports.py
index 9c7abb11f83..6c8c6504e36 100644
--- a/libs/community/tests/unit_tests/llms/test_imports.py
+++ b/libs/community/tests/unit_tests/llms/test_imports.py
@@ -87,6 +87,7 @@ EXPECT_ALL = [
     "JavelinAIGateway",
     "QianfanLLMEndpoint",
     "YandexGPT",
+    "Yuan2",
     "VolcEngineMaasLLM",
     "WatsonxLLM",
 ]
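
Below is a minimal usage sketch of the integration added by this patch. It assumes a Yuan2.0 inference server is reachable at `http://127.0.0.1:8000/yuan`; the prompt strings and the `LLMChain` wiring are illustrative only and not part of the patch itself.

```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate

from langchain_community.llms.yuan2 import Yuan2

# Point the wrapper at a running Yuan2.0 inference server (address assumed).
yuan_llm = Yuan2(
    infer_api="http://127.0.0.1:8000/yuan",
    max_tokens=1024,
    temp=1.0,
    top_p=0.9,
    top_k=40,
    use_history=False,  # stateless by default; set to True for multi-turn chat
)

# Direct call: the wrapper PUTs the prompt to the inference API and
# returns the "ans" field from the first entry of resData.output.
print(yuan_llm("Hello, who are you?"))

# The same LLM composed into an LLMChain, matching the notebook's import.
prompt = PromptTemplate.from_template("Answer briefly: {question}")
chain = LLMChain(llm=yuan_llm, prompt=prompt)
print(chain.run(question="What can Yuan2.0 do?"))
```

Note that with `use_history=True` the wrapper concatenates all prior prompts and answers into each request, so a fresh `Yuan2` instance should be created per conversation to avoid leaking context between sessions.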