diff --git a/docs/docs/integrations/llms/chatglm.ipynb b/docs/docs/integrations/llms/chatglm.ipynb
index 53153a184ac..12de26dacfe 100644
--- a/docs/docs/integrations/llms/chatglm.ipynb
+++ b/docs/docs/integrations/llms/chatglm.ipynb
@@ -11,7 +11,102 @@
     "\n",
     "[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing new features such as better performance, longer context, and more efficient inference.\n",
     "\n",
-    "This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
+    "[ChatGLM3](https://github.com/THUDM/ChatGLM3) is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Install required dependencies\n",
+    "\n",
+    "%pip install -qU langchain langchain-community"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## ChatGLM3\n",
+    "\n",
+    "This example goes over how to use LangChain to interact with ChatGLM3-6B Inference for text completion."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.chains import LLMChain\n",
+    "from langchain.prompts import PromptTemplate\n",
+    "from langchain.schema.messages import AIMessage\n",
+    "from langchain_community.llms.chatglm3 import ChatGLM3"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "template = \"\"\"{question}\"\"\"\n",
+    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n",
+    "\n",
+    "messages = [\n",
+    "    AIMessage(content=\"我将从美国到中国来旅游，出行前希望了解中国的城市\"),\n",
+    "    AIMessage(content=\"欢迎问我任何问题。\"),\n",
+    "]\n",
+    "\n",
+    "llm = ChatGLM3(\n",
+    "    endpoint_url=endpoint_url,\n",
+    "    max_tokens=80000,\n",
+    "    prefix_messages=messages,\n",
+    "    top_p=0.9,\n",
+    ")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'北京和上海是中国两个不同的城市，它们在很多方面都有所不同。\\n\\n北京是中国的首都，也是历史悠久的城市之一。它有着丰富的历史文化遗产，如故宫、颐和园等，这些景点吸引着众多游客前来观光。北京也是一个政治、文化和教育中心，有很多政府机构和学术机构总部设在北京。\\n\\n上海则是一个现代化的城市，它是中国的经济中心之一。上海拥有许多高楼大厦和国际化的金融机构，是中国最国际化的城市之一。上海也是一个美食和购物天堂，有许多著名的餐厅和购物中心。\\n\\n北京和上海的气候也不同。北京属于温带大陆性气候，冬季寒冷干燥，夏季炎热多风；而上海属于亚热带季风气候，四季分明，春秋宜人。\\n\\n北京和上海有很多不同之处，但都是中国非常重要的城市，每个城市都有自己独特的魅力和特色。'"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
+    "question = \"北京和上海两座城市有什么不同？\"\n",
+    "\n",
+    "llm_chain.run(question)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## ChatGLM and ChatGLM2\n",
+    "\n",
+    "The following example shows how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
     "ChatGLM-6B and ChatGLM2-6B have the same API specs, so this example should work with both."
    ]
   },
@@ -106,7 +201,7 @@
  ],
  "metadata": {
   "kernelspec": {
-   "display_name": "langchain-dev",
+   "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -120,9 +215,9 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.10.12"
+   "version": "3.9.1"
   }
  },
  "nbformat": 4,
- "nbformat_minor": 2
+ "nbformat_minor": 4
 }
diff --git a/libs/community/langchain_community/llms/chatglm3.py b/libs/community/langchain_community/llms/chatglm3.py
new file mode 100644
index 00000000000..0582fc58f08
--- /dev/null
+++ b/libs/community/langchain_community/llms/chatglm3.py
@@ -0,0 +1,151 @@
+import json
+import logging
+from typing import Any, List, Optional, Union
+
+from langchain_core.callbacks import CallbackManagerForLLMRun
+from langchain_core.language_models.llms import LLM
+from langchain_core.messages import (
+    AIMessage,
+    BaseMessage,
+    FunctionMessage,
+    HumanMessage,
+    SystemMessage,
+)
+from langchain_core.pydantic_v1 import Field
+
+from langchain_community.llms.utils import enforce_stop_tokens
+
+logger = logging.getLogger(__name__)
+HEADERS = {"Content-Type": "application/json"}
+DEFAULT_TIMEOUT = 30
+
+
+def _convert_message_to_dict(message: BaseMessage) -> dict:
+    if isinstance(message, HumanMessage):
+        message_dict = {"role": "user", "content": message.content}
+    elif isinstance(message, AIMessage):
+        message_dict = {"role": "assistant", "content": message.content}
+    elif isinstance(message, SystemMessage):
+        message_dict = {"role": "system", "content": message.content}
+    elif isinstance(message, FunctionMessage):
+        message_dict = {"role": "function", "content": message.content}
+    else:
+        raise ValueError(f"Got unknown message type: {message}")
+    return message_dict
+
+
+class ChatGLM3(LLM):
+    """ChatGLM3 LLM service."""
+
+    model_name: str = Field(default="chatglm3-6b", alias="model")
+    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
+    """Endpoint URL to use."""
+    model_kwargs: Optional[dict] = None
+    """Keyword arguments to pass to the model."""
+    max_tokens: int = 20000
+    """Maximum number of tokens to pass to the model."""
+    temperature: float = 0.1
+    """LLM model temperature from 0 to 10."""
+    top_p: float = 0.7
+    """Top P for nucleus sampling from 0 to 1."""
+    prefix_messages: List[BaseMessage] = Field(default_factory=list)
+    """Series of messages for Chat input."""
+    streaming: bool = False
+    """Whether to stream the results or not."""
+    http_client: Union[Any, None] = None
+    timeout: int = DEFAULT_TIMEOUT
+
+    @property
+    def _llm_type(self) -> str:
+        return "chat_glm_3"
+
+    @property
+    def _invocation_params(self) -> dict:
+        """Get the parameters used to invoke the model."""
+        params = {
+            "model": self.model_name,
+            "temperature": self.temperature,
+            "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "stream": self.streaming,
+        }
+        return {**params, **(self.model_kwargs or {})}
+
+    @property
+    def client(self) -> Any:
+        import httpx
+
+        return self.http_client or httpx.Client(timeout=self.timeout)
+
+    def _get_payload(self, prompt: str) -> dict:
+        params = self._invocation_params
+        messages = self.prefix_messages + [HumanMessage(content=prompt)]
+        params.update(
+            {
+                "messages": [_convert_message_to_dict(m) for m in messages],
+            }
+        )
+        return params
+
+    def _call(
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
+    ) -> str:
+        """Call out to a ChatGLM3 LLM inference endpoint.
+
+        Args:
+            prompt: The prompt to pass into the model.
+            stop: Optional list of stop words to use when generating.
+
+        Returns:
+            The string generated by the model.
+
+        Example:
+            .. code-block:: python
+
+                response = chatglm_llm("Who are you?")
+        """
+        import httpx
+
+        payload = self._get_payload(prompt)
+        logger.debug(f"ChatGLM3 payload: {payload}")
+
+        try:
+            response = self.client.post(
+                self.endpoint_url, headers=HEADERS, json=payload
+            )
+        except httpx.NetworkError as e:
+            raise ValueError(f"Error raised by inference endpoint: {e}")
+
+        logger.debug(f"ChatGLM3 response: {response}")
+
+        if response.status_code != 200:
+            raise ValueError(f"Failed with response: {response}")
+
+        try:
+            parsed_response = response.json()
+
+            if isinstance(parsed_response, dict):
+                content_keys = "choices"
+                if content_keys in parsed_response:
+                    choices = parsed_response[content_keys]
+                    if len(choices):
+                        text = choices[0]["message"]["content"]
+                    else:
+                        raise ValueError(f"No content in response: {parsed_response}")
+                else:
+                    raise ValueError(f"Unexpected response type: {parsed_response}")
+
+        except json.JSONDecodeError as e:
+            raise ValueError(
+                f"Error raised during decoding response from inference endpoint: {e}."
+                f"\nResponse: {response.text}"
+            )
+
+        if stop is not None:
+            text = enforce_stop_tokens(text, stop)
+
+        return text
diff --git a/libs/community/poetry.lock b/libs/community/poetry.lock
index ea439ba7d65..f674b1d20a6 100644
--- a/libs/community/poetry.lock
+++ b/libs/community/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
 
 [[package]]
 name = "aenum"
@@ -3433,6 +3433,7 @@ files = [
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
     {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
     {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
+    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
     {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -3943,7 +3944,7 @@ files = [
 
 [[package]]
 name = "langchain-core"
-version = "0.1.16"
+version = "0.1.17"
 description = "Building applications with LLMs through composability"
 optional = false
 python-versions = ">=3.8.1,<4.0"
@@ -6222,7 +6223,6 @@ files = [
     {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
     {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
     {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
-    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
     {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -9247,9 +9247,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
 
 [extras]
 cli = ["typer"]
-extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
+extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
 
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.8.1,<4.0"
-content-hash = "42d012441d7b42d273e11708b7e12308fc56b169d4d56c4c2511e7469743a983"
+content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"
diff --git a/libs/community/pyproject.toml b/libs/community/pyproject.toml
index 6691ca8b471..e048b7b304b 100644
--- a/libs/community/pyproject.toml
+++ b/libs/community/pyproject.toml
@@ -87,6 +87,7 @@ datasets = {version = "^2.15.0", optional = true}
 azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
 oracle-ads = {version = "^2.9.1", optional = true}
 zhipuai = {version = "^1.0.7", optional = true}
+httpx = {version = "^0.24.1", optional = true}
 elasticsearch = {version = "^8.12.0", optional = true}
 hdbcli = {version = "^2.19.21", optional = true}
 oci = {version = "^2.119.1", optional = true}
@@ -253,6 +254,7 @@ extended_testing = [
     "azure-ai-documentintelligence",
     "oracle-ads",
     "zhipuai",
+    "httpx",
     "elasticsearch",
     "hdbcli",
     "oci",
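
For anyone smoke-testing this PR locally, here is a minimal usage sketch of the new wrapper. It assumes a ChatGLM3 OpenAI-style API server is already listening at `http://127.0.0.1:8000/v1/chat/completions` (the class default); the prompt, the `stop` sequence, and the pre-built `httpx.Client` are illustrative choices, not requirements of this change.

```python
# Hypothetical smoke test for the new ChatGLM3 wrapper.
import httpx

from langchain_community.llms.chatglm3 import ChatGLM3

llm = ChatGLM3(
    # Assumes a local ChatGLM3 server; point this at your own deployment.
    endpoint_url="http://127.0.0.1:8000/v1/chat/completions",
    max_tokens=512,
    temperature=0.1,
    # Optional: reuse one connection pool instead of the fresh client the
    # `client` property otherwise creates on each call.
    http_client=httpx.Client(timeout=60),
)

# `stop` is applied client-side via enforce_stop_tokens, truncating the
# completion at the first occurrence of a stop sequence.
print(llm.invoke("Give me one sentence about Beijing.", stop=["\n\n"]))
```

Since `client` falls back to `httpx.Client(timeout=self.timeout)` whenever `http_client` is unset, passing a pre-configured client is only needed for custom timeouts, proxies, or connection reuse.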