community: Add ChatGLM3 (#15265)

Add [ChatGLM3](https://github.com/THUDM/ChatGLM3) and updated
[chatglm.ipynb](https://python.langchain.com/docs/integrations/llms/chatglm)

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
Bob Lin 2024-01-30 12:30:52 +08:00 committed by GitHub
parent a1ce7ab672
commit 546b757303
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 257 additions and 9 deletions

View File

@ -11,7 +11,102 @@
"\n",
"[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing the new features like better performance, longer context and more efficient inference.\n",
"\n",
"This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
"[ChatGLM3](https://github.com/THUDM/ChatGLM3) is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install required dependencies\n",
"\n",
"%pip install -qU langchain langchain-community"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChatGLM3\n",
"\n",
"This examples goes over how to use LangChain to interact with ChatGLM3-6B Inference for text completion."
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"from langchain.chains import LLMChain\n",
"from langchain.prompts import PromptTemplate\n",
"from langchain.schema.messages import AIMessage\n",
"from langchain_community.llms.chatglm3 import ChatGLM3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"{question}\"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [],
"source": [
"endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n",
"\n",
"messages = [\n",
" AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"),\n",
" AIMessage(content=\"欢迎问我任何问题。\"),\n",
"]\n",
"\n",
"llm = ChatGLM3(\n",
" endpoint_url=endpoint_url,\n",
" max_tokens=80000,\n",
" prefix_messages=messages,\n",
" top_p=0.9,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'北京和上海是中国两个不同的城市,它们在很多方面都有所不同。\\n\\n北京是中国的首都,也是历史悠久的城市之一。它有着丰富的历史文化遗产,如故宫、颐和园等,这些景点吸引着众多游客前来观光。北京也是一个政治、文化和教育中心,有很多政府机构和学术机构总部设在北京。\\n\\n上海则是一个现代化的城市,它是中国的经济中心之一。上海拥有许多高楼大厦和国际化的金融机构,是中国最国际化的城市之一。上海也是一个美食和购物天堂,有许多著名的餐厅和购物中心。\\n\\n北京和上海的气候也不同。北京属于温带大陆性气候,冬季寒冷干燥,夏季炎热多风;而上海属于亚热带季风气候,四季分明,春秋宜人。\\n\\n北京和上海有很多不同之处,但都是中国非常重要的城市,每个城市都有自己独特的魅力和特色。'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"question = \"北京和上海两座城市有什么不同?\"\n",
"\n",
"llm_chain.run(question)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## ChatGLM and ChatGLM2\n",
"\n",
"The following example shows how to use LangChain to interact with the ChatGLM2-6B Inference to complete text.\n",
"ChatGLM-6B and ChatGLM2-6B has the same api specs, so this example should work with both."
]
},
@ -106,7 +201,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "langchain-dev",
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
@ -120,9 +215,9 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.12"
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 2
"nbformat_minor": 4
}

View File

@ -0,0 +1,151 @@
import json
import logging
from typing import Any, List, Optional, Union
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
AIMessage,
BaseMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
)
from langchain_core.pydantic_v1 import Field
from langchain_community.llms.utils import enforce_stop_tokens
logger = logging.getLogger(__name__)
HEADERS = {"Content-Type": "application/json"}
DEFAULT_TIMEOUT = 30
def _convert_message_to_dict(message: BaseMessage) -> dict:
if isinstance(message, HumanMessage):
message_dict = {"role": "user", "content": message.content}
elif isinstance(message, AIMessage):
message_dict = {"role": "assistant", "content": message.content}
elif isinstance(message, SystemMessage):
message_dict = {"role": "system", "content": message.content}
elif isinstance(message, FunctionMessage):
message_dict = {"role": "function", "content": message.content}
else:
raise ValueError(f"Got unknown type {message}")
return message_dict
class ChatGLM3(LLM):
"""ChatGLM3 LLM service."""
model_name: str = Field(default="chatglm3-6b", alias="model")
endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
"""Endpoint URL to use."""
model_kwargs: Optional[dict] = None
"""Keyword arguments to pass to the model."""
max_tokens: int = 20000
"""Max token allowed to pass to the model."""
temperature: float = 0.1
"""LLM model temperature from 0 to 10."""
top_p: float = 0.7
"""Top P for nucleus sampling from 0 to 1"""
prefix_messages: List[BaseMessage] = Field(default_factory=list)
"""Series of messages for Chat input."""
streaming: bool = False
"""Whether to stream the results or not."""
http_client: Union[Any, None] = None
timeout: int = DEFAULT_TIMEOUT
@property
def _llm_type(self) -> str:
return "chat_glm_3"
@property
def _invocation_params(self) -> dict:
"""Get the parameters used to invoke the model."""
params = {
"model": self.model_name,
"temperature": self.temperature,
"max_tokens": self.max_tokens,
"top_p": self.top_p,
"stream": self.streaming,
}
return {**params, **(self.model_kwargs or {})}
@property
def client(self) -> Any:
import httpx
return self.http_client or httpx.Client(timeout=self.timeout)
def _get_payload(self, prompt: str) -> dict:
params = self._invocation_params
messages = self.prefix_messages + [HumanMessage(content=prompt)]
params.update(
{
"messages": [_convert_message_to_dict(m) for m in messages],
}
)
return params
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to a ChatGLM3 LLM inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
Example:
.. code-block:: python
response = chatglm_llm("Who are you?")
"""
import httpx
payload = self._get_payload(prompt)
logger.debug(f"ChatGLM3 payload: {payload}")
try:
response = self.client.post(
self.endpoint_url, headers=HEADERS, json=payload
)
except httpx.NetworkError as e:
raise ValueError(f"Error raised by inference endpoint: {e}")
logger.debug(f"ChatGLM3 response: {response}")
if response.status_code != 200:
raise ValueError(f"Failed with response: {response}")
try:
parsed_response = response.json()
if isinstance(parsed_response, dict):
content_keys = "choices"
if content_keys in parsed_response:
choices = parsed_response[content_keys]
if len(choices):
text = choices[0]["message"]["content"]
else:
raise ValueError(f"No content in response : {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")
except json.JSONDecodeError as e:
raise ValueError(
f"Error raised during decoding response from inference endpoint: {e}."
f"\nResponse: {response.text}"
)
if stop is not None:
text = enforce_stop_tokens(text, stop)
return text

View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
[[package]]
name = "aenum"
@ -3433,6 +3433,7 @@ files = [
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
{file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
{file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
{file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
{file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@ -3943,7 +3944,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.16"
version = "0.1.17"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -6222,7 +6223,6 @@ files = [
{file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
{file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
{file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
{file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
{file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
{file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
{file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@ -9247,9 +9247,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p
[extras]
cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "42d012441d7b42d273e11708b7e12308fc56b169d4d56c4c2511e7469743a983"
content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"

View File

@ -87,6 +87,7 @@ datasets = {version = "^2.15.0", optional = true}
azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
oracle-ads = {version = "^2.9.1", optional = true}
zhipuai = {version = "^1.0.7", optional = true}
httpx = {version = "^0.24.1", optional = true}
elasticsearch = {version = "^8.12.0", optional = true}
hdbcli = {version = "^2.19.21", optional = true}
oci = {version = "^2.119.1", optional = true}
@ -253,6 +254,7 @@ extended_testing = [
"azure-ai-documentintelligence",
"oracle-ads",
"zhipuai",
"httpx",
"elasticsearch",
"hdbcli",
"oci",