From fa0a9e502a526d3a24ae6b76bf210da4d68ecf3c Mon Sep 17 00:00:00 2001
From: Liu Ming <83799887@qq.com>
Date: Mon, 17 Jul 2023 22:27:17 +0800
Subject: [PATCH] Add LLM for ChatGLM(2)-6B API (#7774)

Description: Add an LLM wrapper for the ChatGLM-6B and ChatGLM2-6B APIs.

Related Issues:
- Will the langchain support ChatGLM? #4766
- Add support for selfhost models like ChatGLM or transformer models #1780

Dependencies: No extra libraries need to be installed. This wraps API calls to
a ChatGLM(2)-6B server (started with api.py), so a running API endpoint is
required.

Tag maintainer: @mlot

Any comments on this PR would be appreciated.

---------

Co-authored-by: mlot
Co-authored-by: Bagatur
---
 .../models/llms/integrations/chatglm.ipynb    | 122 +++++++++++++++++
 langchain/llms/__init__.py                    |   3 +
 langchain/llms/chatglm.py                     | 120 ++++++++++++++++
 tests/integration_tests/llms/test_chatglm.py  |  18 +++
 4 files changed, 263 insertions(+)
 create mode 100644 docs/extras/modules/model_io/models/llms/integrations/chatglm.ipynb
 create mode 100644 langchain/llms/chatglm.py
 create mode 100644 tests/integration_tests/llms/test_chatglm.py

diff --git a/docs/extras/modules/model_io/models/llms/integrations/chatglm.ipynb b/docs/extras/modules/model_io/models/llms/integrations/chatglm.ipynb
new file mode 100644
index 00000000000..3341563551e
--- /dev/null
+++ b/docs/extras/modules/model_io/models/llms/integrations/chatglm.ipynb
@@ -0,0 +1,122 @@
+{
+ "cells": [
+  {
+   "attachments": {},
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# ChatGLM\n",
+    "\n",
+    "[ChatGLM-6B](https://github.com/THUDM/ChatGLM-6B) is an open bilingual language model based on the General Language Model (GLM) framework, with 6.2 billion parameters. With quantization, users can deploy it locally on consumer-grade graphics cards (only 6 GB of GPU memory is required at the INT4 quantization level).\n",
+    "\n",
+    "[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing new features such as better performance, longer context, and more efficient inference.\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with a ChatGLM2-6B inference server for text completion.\n",
+    "ChatGLM-6B and ChatGLM2-6B have the same API specs, so this example should work with both."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import ChatGLM\n",
+    "from langchain import PromptTemplate, LLMChain\n",
+    "\n",
+    "# import os  # only needed if you set NO_PROXY below"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 22,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "template = \"\"\"{question}\"\"\"\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# default endpoint_url for a locally deployed ChatGLM API server\n",
+    "endpoint_url = \"http://127.0.0.1:8000\"\n",
+    "\n",
+    "# in a proxied environment, let requests to the local endpoint bypass the proxy\n",
+    "# os.environ['NO_PROXY'] = '127.0.0.1'\n",
+    "\n",
+    "llm = ChatGLM(\n",
+    "    endpoint_url=endpoint_url,\n",
+    "    max_token=80000,\n",
+    "    # one seed turn, in English: [\"I am traveling from the US to China and want to learn about Chinese cities first\", \"Feel free to ask me anything.\"]\n",
+    "    history=[[\"我将从美国到中国来旅游,出行前希望了解中国的城市\", \"欢迎问我任何问题。\"]],\n",
+    "    top_p=0.9,\n",
+    "    model_kwargs={\"sample_model_args\": False},\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 24,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 25,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "ChatGLM payload: {'prompt': '北京和上海两座城市有什么不同?', 'temperature': 0.1, 'history': [['我将从美国到中国来旅游,出行前希望了解中国的城市', '欢迎问我任何问题。']], 'max_length': 80000, 'top_p': 0.9, 'sample_model_args': False}\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "'北京和上海是中国的两个首都,它们在许多方面都有所不同。\\n\\n北京是中国的政治和文化中心,拥有悠久的历史和灿烂的文化。它是中国最重要的古都之一,也是中国历史上最后一个封建王朝的都城。北京有许多著名的古迹和景点,例如紫禁城、天安门广场和长城等。\\n\\n上海是中国最现代化的城市之一,也是中国商业和金融中心。上海拥有许多国际知名的企业和金融机构,同时也有许多著名的景点和美食。上海的外滩是一个历史悠久的商业区,拥有许多欧式建筑和餐馆。\\n\\n除此之外,北京和上海在交通和人口方面也有很大差异。北京是中国的首都,人口众多,交通拥堵问题较为严重。而上海是中国的商业和金融中心,人口密度较低,交通相对较为便利。\\n\\n总的来说,北京和上海是两个拥有独特魅力和特点的城市,可以根据自己的兴趣和时间来选择前往其中一座城市旅游。'"
+      ]
+     },
+     "execution_count": 25,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "question = \"北京和上海两座城市有什么不同?\"  # \"What are the differences between the cities of Beijing and Shanghai?\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "langchain-dev",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.10.12"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/langchain/llms/__init__.py b/langchain/llms/__init__.py
index 57902853ad4..1e32346f26b 100644
--- a/langchain/llms/__init__.py
+++ b/langchain/llms/__init__.py
@@ -14,6 +14,7 @@ from langchain.llms.baseten import Baseten
 from langchain.llms.beam import Beam
 from langchain.llms.bedrock import Bedrock
 from langchain.llms.cerebriumai import CerebriumAI
+from langchain.llms.chatglm import ChatGLM
 from langchain.llms.clarifai import Clarifai
 from langchain.llms.cohere import Cohere
 from langchain.llms.ctransformers import CTransformers
@@ -69,6 +70,7 @@ __all__ = [
     "Bedrock",
     "CTransformers",
     "CerebriumAI",
+    "ChatGLM",
     "Clarifai",
     "Cohere",
     "Databricks",
@@ -125,6 +127,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
     "baseten": Baseten,
     "beam": Beam,
     "cerebriumai": CerebriumAI,
+    "chat_glm": ChatGLM,
     "clarifai": Clarifai,
     "cohere": Cohere,
     "ctransformers": CTransformers,
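[Reviewer note, not part of the patch] The "chat_glm" entry added to type_to_cls_dict above is what lets a serialized LLM config round-trip back into a ChatGLM instance. A minimal sketch of that round trip, assuming the generic save/load_llm helpers behave here as they do for other providers (the file name is arbitrary):

from langchain.llms import ChatGLM
from langchain.llms.loading import load_llm

llm = ChatGLM(endpoint_url="http://127.0.0.1:8000", temperature=0.1)
llm.save("chatglm_llm.yaml")  # serializes the config with _type: chat_glm

restored = load_llm("chatglm_llm.yaml")  # looks up "chat_glm" in type_to_cls_dict
assert isinstance(restored, ChatGLM)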
diff --git a/langchain/llms/chatglm.py b/langchain/llms/chatglm.py
new file mode 100644
index 00000000000..2ee3c01f0c0
--- /dev/null
+++ b/langchain/llms/chatglm.py
@@ -0,0 +1,120 @@
+from typing import Any, List, Mapping, Optional
+
+import requests
+
+from langchain.callbacks.manager import CallbackManagerForLLMRun
+from langchain.llms.base import LLM
+from langchain.llms.utils import enforce_stop_tokens
+
+
+class ChatGLM(LLM):
+    """Wrapper around ChatGLM's LLM inference service.
+
+    Example:
+        .. code-block:: python
+
+            from langchain.llms import ChatGLM
+            endpoint_url = "http://127.0.0.1:8000"
+            chatglm_llm = ChatGLM(endpoint_url=endpoint_url)
+    """
+
+    endpoint_url: str = "http://127.0.0.1:8000/"
+    """Endpoint URL to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the model."""
+    max_token: int = 20000
+    """Maximum number of tokens to pass to the model."""
+    temperature: float = 0.1
+    """LLM model temperature from 0 to 10."""
+    history: List[List] = []
+    """History of the conversation, as [question, answer] pairs."""
+    top_p: float = 0.7
+    """Top P for nucleus sampling, from 0 to 1."""
+
+    @property
+    def _llm_type(self) -> str:
+        return "chat_glm"
+
+    @property
+    def _identifying_params(self) -> Mapping[str, Any]:
+        """Get the identifying parameters."""
+        _model_kwargs = self.model_kwargs or {}
+        return {
+            **{"endpoint_url": self.endpoint_url},
+            **{"model_kwargs": _model_kwargs},
+        }
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to a ChatGLM LLM inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = chatglm_llm("Who are you?")
+        """
+
+        _model_kwargs = self.model_kwargs or {}
+
+        # HTTP headers for the JSON request
+        headers = {"Content-Type": "application/json"}
+
+        payload = {
+            "prompt": prompt,
+            "temperature": self.temperature,
+            "history": self.history,
+            "max_length": self.max_token,
+            "top_p": self.top_p,
+        }
+        payload.update(_model_kwargs)
+        payload.update(kwargs)
+
+        # print("ChatGLM payload:", payload)
+
+        # call the API
+        try:
+            response = requests.post(self.endpoint_url, headers=headers, json=payload)
+        except requests.exceptions.RequestException as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
+
+        # print("ChatGLM resp:", response)
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed with response: {response}")
+
+        try:
+            parsed_response = response.json()
+
+            # Check that the response contains content
+            if isinstance(parsed_response, dict):
+                content_key = "response"
+                if content_key in parsed_response:
+                    text = parsed_response[content_key]
+                else:
+                    raise ValueError(f"No content in response: {parsed_response}")
+            else:
+                raise ValueError(f"Unexpected response type: {parsed_response}")
+
+        except requests.exceptions.JSONDecodeError as e:
+            raise ValueError(
+                f"Error raised during decoding response from inference endpoint: {e}."
+                f"\nResponse: {response.text}"
+            )
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+        # Pair the prompt with the answer so the next call carries the full turn.
+        self.history = self.history + [[prompt, parsed_response["response"]]]
+        return text
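[Reviewer note, not part of the patch] _call above only assumes the endpoint accepts a JSON POST containing a "prompt" key and replies with a JSON object carrying a "response" key, which matches what THUDM's api.py serves. As a review aid, here is a minimal stdlib-only stub server satisfying the same contract, handy for exercising the wrapper and the integration tests below without a GPU; the canned reply text is invented:

# stub_chatglm_server.py - stands in for ChatGLM's api.py (sketch)
import json
from http.server import BaseHTTPRequestHandler, HTTPServer


class StubChatGLMHandler(BaseHTTPRequestHandler):
    def do_POST(self) -> None:
        length = int(self.headers.get("Content-Length", 0))
        payload = json.loads(self.rfile.read(length))
        # Reply in the same {"response": ..., "history": ...} shape as api.py.
        reply = f"stub answer to: {payload.get('prompt', '')}"
        history = payload.get("history", []) + [[payload.get("prompt"), reply]]
        body = json.dumps({"response": reply, "history": history}).encode("utf-8")
        self.send_response(200)
        self.send_header("Content-Type", "application/json")
        self.send_header("Content-Length", str(len(body)))
        self.end_headers()
        self.wfile.write(body)


if __name__ == "__main__":
    # Same default host/port the wrapper targets.
    HTTPServer(("127.0.0.1", 8000), StubChatGLMHandler).serve_forever()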
diff --git a/tests/integration_tests/llms/test_chatglm.py b/tests/integration_tests/llms/test_chatglm.py
new file mode 100644
index 00000000000..a62a76896db
--- /dev/null
+++ b/tests/integration_tests/llms/test_chatglm.py
@@ -0,0 +1,18 @@
+"""Test ChatGLM API wrapper."""
+from langchain.llms.chatglm import ChatGLM
+from langchain.schema import LLMResult
+
+
+def test_chatglm_call() -> None:
+    """Test valid call to chatglm."""
+    llm = ChatGLM()
+    output = llm("北京和上海这两座城市有什么不同?")  # "How do the cities of Beijing and Shanghai differ?"
+    assert isinstance(output, str)
+
+
+def test_chatglm_generate() -> None:
+    """Test valid generate call to chatglm."""
+    llm = ChatGLM()
+    output = llm.generate(["who are you"])
+    assert isinstance(output, LLMResult)
+    assert isinstance(output.generations, list)
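[Reviewer note, not part of the patch] The integration tests above need a live endpoint. Here is a sketch of an offline unit test that stubs requests.post with the stdlib unittest.mock, checking both the text the wrapper returns and the payload it sends; the fake JSON body is invented but mirrors the shape _call parses:

"""Offline test for the ChatGLM wrapper (sketch; no server required)."""
from unittest.mock import MagicMock, patch

from langchain.llms.chatglm import ChatGLM


def test_chatglm_call_mocked() -> None:
    fake_response = MagicMock()
    fake_response.status_code = 200
    fake_response.json.return_value = {
        "response": "I am a stubbed ChatGLM answer.",
        "history": [["who are you", "I am a stubbed ChatGLM answer."]],
    }
    with patch(
        "langchain.llms.chatglm.requests.post", return_value=fake_response
    ) as mock_post:
        llm = ChatGLM()
        output = llm("who are you")
    # The wrapper should hand back the "response" field verbatim ...
    assert output == "I am a stubbed ChatGLM answer."
    # ... and POST the prompt plus the sampling parameters as JSON.
    sent = mock_post.call_args.kwargs["json"]
    assert sent["prompt"] == "who are you"
    assert {"temperature", "history", "max_length", "top_p"} <= set(sent)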