From f8f2649f121f75de6d274c141e4f392a309b26af Mon Sep 17 00:00:00 2001 From: baichuan-assistant <139942740+baichuan-assistant@users.noreply.github.com> Date: Tue, 30 Jan 2024 12:08:24 +0800 Subject: [PATCH] community: Add Baichuan LLM to community (#16724) Replace this entire comment with: - **Description:** Add Baichuan LLM to integration/llm, also updated related docs. Co-authored-by: BaiChuanHelper --- docs/docs/integrations/chat/baichuan.ipynb | 16 ++- docs/docs/integrations/llms/baichuan.ipynb | 97 +++++++++++++++++++ docs/docs/integrations/providers/baichuan.mdx | 5 +- .../text_embedding/baichuan.ipynb | 81 ++++++++++------ .../langchain_community/llms/__init__.py | 8 ++ .../langchain_community/llms/baichuan.py | 95 ++++++++++++++++++ .../integration_tests/llms/test_baichuan.py | 19 ++++ 7 files changed, 289 insertions(+), 32 deletions(-) create mode 100644 docs/docs/integrations/llms/baichuan.ipynb create mode 100644 libs/community/langchain_community/llms/baichuan.py create mode 100644 libs/community/tests/integration_tests/llms/test_baichuan.py diff --git a/docs/docs/integrations/chat/baichuan.ipynb b/docs/docs/integrations/chat/baichuan.ipynb index 14f2d0d4c98..3b184b953fa 100644 --- a/docs/docs/integrations/chat/baichuan.ipynb +++ b/docs/docs/integrations/chat/baichuan.ipynb @@ -51,10 +51,18 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "or you can set `api_key` in your environment variables\n", - "```bash\n", - "export BAICHUAN_API_KEY=YOUR_API_KEY\n", - "```" + "Alternatively, you can set your API key with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"BAICHUAN_API_KEY\"] = \"YOUR_API_KEY\"" ] }, { diff --git a/docs/docs/integrations/llms/baichuan.ipynb b/docs/docs/integrations/llms/baichuan.ipynb new file mode 100644 index 00000000000..7c92d17717e --- /dev/null +++ b/docs/docs/integrations/llms/baichuan.ipynb @@ -0,0 +1,97 @@ +{ + 
"cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Baichuan LLM\n", + "Baichuan Inc. (https://www.baichuan-ai.com/) is a Chinese startup in the era of AGI, dedicated to addressing fundamental human needs: Efficiency, Health, and Happiness." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisite\n", + "An API key is required to access Baichuan LLM API. Visit https://platform.baichuan-ai.com/ to get your API key." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Use Baichuan LLM" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "\n", + "os.environ[\"BAICHUAN_API_KEY\"] = \"YOUR_API_KEY\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain_community.llms import BaichuanLLM\n", + "\n", + "# Load the model\n", + "llm = BaichuanLLM()\n", + "\n", + "res = llm(\"What's your name?\")\n", + "print(res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "res = llm.generate(prompts=[\"你好!\"])\n", + "res" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "for res in llm.stream(\"Who won the second world war?\"):\n", + " print(res)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import asyncio\n", + "\n", + "\n", + "async def run_aio_stream():\n", + " async for res in llm.astream(\"Write a poem about the sun.\"):\n", + " print(res)\n", + "\n", + "\n", + "asyncio.run(run_aio_stream())" + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/docs/docs/integrations/providers/baichuan.mdx b/docs/docs/integrations/providers/baichuan.mdx index b73e74b4579..ddac4cf65ea 
100644 --- a/docs/docs/integrations/providers/baichuan.mdx +++ b/docs/docs/integrations/providers/baichuan.mdx @@ -6,8 +6,11 @@ Visit us at https://www.baichuan-ai.com/. Register and get an API key if you are trying out our APIs. +## Baichuan LLM Endpoint +An example is available at [example](/docs/integrations/llms/baichuan). + ## Baichuan Chat Model An example is available at [example](/docs/integrations/chat/baichuan). ## Baichuan Text Embedding Model -An example is available at [example] (/docs/integrations/text_embedding/baichuan) +An example is available at [example](/docs/integrations/text_embedding/baichuan). diff --git a/docs/docs/integrations/text_embedding/baichuan.ipynb b/docs/docs/integrations/text_embedding/baichuan.ipynb index 8b5d57a2ddb..3aa2320448b 100644 --- a/docs/docs/integrations/text_embedding/baichuan.ipynb +++ b/docs/docs/integrations/text_embedding/baichuan.ipynb @@ -6,46 +6,77 @@ "source": [ "# Baichuan Text Embeddings\n", "\n", - "As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB (Chinese Multi-Task Embedding Benchmark) leaderboard.\n", - "\n", - "Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard\n", - "\n", + "As of today (Jan 25th, 2024) BaichuanTextEmbeddings ranks #1 in C-MTEB (Chinese Multi-Task Embedding Benchmark) leaderboard.\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Leaderboard (Under Overall -> Chinese section): https://huggingface.co/spaces/mteb/leaderboard" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ "Official Website: https://platform.baichuan-ai.com/docs/text-Embedding\n", - "An API-key is required to use this embedding model. You can get one by registering at https://platform.baichuan-ai.com/docs/text-Embedding.\n", - "BaichuanTextEmbeddings support 512 token window and preduces vectors with 1024 dimensions. 
\n", "\n", - "Please NOTE that BaichuanTextEmbeddings only supports Chinese text embedding. Multi-language support is coming soon.\n" + "An API key is required to use this embedding model. You can get one by registering at https://platform.baichuan-ai.com/docs/text-Embedding." + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "BaichuanTextEmbeddings supports a 512-token window and produces vectors with 1024 dimensions. " + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Please NOTE that BaichuanTextEmbeddings only supports Chinese text embedding. Multi-language support is coming soon." ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, + "metadata": {}, "outputs": [], "source": [ "from langchain_community.embeddings import BaichuanTextEmbeddings\n", "\n", - "# Place your Baichuan API-key here.\n", - "embeddings = BaichuanTextEmbeddings(baichuan_api_key=\"sk-*\")\n", - "\n", - "text_1 = \"今天天气不错\"\n", - "text_2 = \"今天阳光很好\"" + "embeddings = BaichuanTextEmbeddings(baichuan_api_key=\"sk-*\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Alternatively, you can set the API key this way:" ] }, { "cell_type": "code", "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, + "metadata": {}, "outputs": [], "source": [ + "import os\n", + "\n", + "os.environ[\"BAICHUAN_API_KEY\"] = \"YOUR_API_KEY\"" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "text_1 = \"今天天气不错\"\n", + "text_2 = \"今天阳光很好\"\n", + "\n", "query_result = embeddings.embed_query(text_1)\n", "query_result" ] @@ -53,11 +84,7 @@ { "cell_type": "code", "execution_count": null, - "metadata": { - "vscode": { - "languageId": "plaintext" - } - }, + "metadata": {}, "outputs": [], "source": [ "doc_result = embeddings.embed_documents([text_1, text_2])\n", diff --git 
a/libs/community/langchain_community/llms/__init__.py b/libs/community/langchain_community/llms/__init__.py
index 5a08670333e..2ccf4199c41 100644
--- a/libs/community/langchain_community/llms/__init__.py
+++ b/libs/community/langchain_community/llms/__init__.py
@@ -76,6 +76,12 @@ def _import_azureml_endpoint() -> Any:
     return AzureMLOnlineEndpoint
 
 
+def _import_baichuan() -> Any:
+    from langchain_community.llms.baichuan import BaichuanLLM
+
+    return BaichuanLLM
+
+
 def _import_baidu_qianfan_endpoint() -> Any:
     from langchain_community.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
 
@@ -589,6 +595,8 @@ def __getattr__(name: str) -> Any:
         return _import_aviary()
     elif name == "AzureMLOnlineEndpoint":
         return _import_azureml_endpoint()
+    elif name == "BaichuanLLM":
+        return _import_baichuan()
     elif name == "QianfanLLMEndpoint":
         return _import_baidu_qianfan_endpoint()
     elif name == "Banana":
diff --git a/libs/community/langchain_community/llms/baichuan.py b/libs/community/langchain_community/llms/baichuan.py
new file mode 100644
index 00000000000..2627b81bd24
--- /dev/null
+++ b/libs/community/langchain_community/llms/baichuan.py
@@ -0,0 +1,122 @@
+from __future__ import annotations
+
+import json
+import logging
+from typing import Any, Dict, List, Optional
+
+import requests
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
+from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+
+
+class BaichuanLLM(LLM):
+    """Wrapper around Baichuan large language models.
+
+    The API key is read from the ``baichuan_api_key`` argument or the
+    ``BAICHUAN_API_KEY`` environment variable; the endpoint URL from
+    ``baichuan_api_host`` or ``BAICHUAN_API_HOST``.
+    """
+
+    # TODO: Add streaming support.
+
+    model: str = "Baichuan2-Turbo-192k"
+    """
+    Other models are available at https://platform.baichuan-ai.com/docs/api.
+    """
+    temperature: float = 0.3
+    """Sent as ``temperature`` in every request."""
+    top_p: float = 0.95
+    """Sent as ``top_p`` in every request."""
+    timeout: int = 60
+    """Timeout passed to ``requests.post``."""
+    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
+    """Extra request parameters merged over the defaults above."""
+
+    baichuan_api_host: Optional[str] = None
+    """Endpoint URL; resolved in ``validate_environment``."""
+    baichuan_api_key: Optional[SecretStr] = None
+    """API key; resolved in ``validate_environment``."""
+
+    @root_validator()
+    def validate_environment(cls, values: Dict) -> Dict:
+        """Resolve the API key and endpoint from ``values`` or the environment."""
+        values["baichuan_api_key"] = convert_to_secret_str(
+            get_from_dict_or_env(values, "baichuan_api_key", "BAICHUAN_API_KEY")
+        )
+        values["baichuan_api_host"] = get_from_dict_or_env(
+            values,
+            "baichuan_api_host",
+            "BAICHUAN_API_HOST",
+            default="https://api.baichuan-ai.com/v1/chat/completions",
+        )
+        return values
+
+    @property
+    def _default_params(self) -> Dict[str, Any]:
+        """Return the request parameters sent with every call."""
+        return {
+            "model": self.model,
+            "temperature": self.temperature,
+            "top_p": self.top_p,
+            **self.model_kwargs,
+        }
+
+    def _post(self, request: Any) -> Any:
+        """POST ``request`` to the endpoint and return the reply text.
+
+        Raises:
+            ValueError: If the request fails, the status is not 200, or the
+                response body does not have the expected shape.
+        """
+        headers = {
+            "Content-Type": "application/json",
+            "Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",
+        }
+        try:
+            response = requests.post(
+                self.baichuan_api_host,
+                headers=headers,
+                json=request,
+                timeout=self.timeout,
+            )
+
+            if response.status_code != 200:
+                # raise_for_status() only raises for 4xx/5xx; reject any other
+                # non-200 status explicitly instead of silently returning None.
+                response.raise_for_status()
+                raise ValueError(f"Unexpected status code: {response.status_code}")
+            parsed_json = json.loads(response.text)
+            return parsed_json["choices"][0]["message"]["content"]
+        except Exception as e:
+            raise ValueError(f"An error has occurred: {e}") from e
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Send ``prompt`` as a single user message and return the reply.
+
+        ``kwargs`` override the default request parameters; when ``stop`` is
+        given the reply is passed through ``enforce_stop_tokens``.
+        """
+        request = self._default_params
+        request["messages"] = [{"role": "user", "content": prompt}]
+        request.update(kwargs)
+        text = self._post(request)
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        return text
+
+    @property
+    def _llm_type(self) -> str:
+        """Return type of llm."""
+        return "baichuan-llm"
diff --git a/libs/community/tests/integration_tests/llms/test_baichuan.py b/libs/community/tests/integration_tests/llms/test_baichuan.py
new file mode 100644
index 00000000000..330e9fe8293
--- /dev/null
+++ b/libs/community/tests/integration_tests/llms/test_baichuan.py
@@ -0,0 +1,19 @@
+"""Test Baichuan LLM Endpoint."""
+from langchain_core.outputs import LLMResult
+
+from langchain_community.llms.baichuan import BaichuanLLM
+
+
+def test_call() -> None:
+    """Test valid call to baichuan."""
+    llm = BaichuanLLM()
+    output = llm("Who won the second world war?")
+    assert isinstance(output, str)
+
+
+def test_generate() -> None:
+    """Test valid generate call to baichuan."""
+    llm = BaichuanLLM()
+    output = llm.generate(["Who won the second world war?"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)