community: Add ChatGLM3 (#15265)

Add [ChatGLM3](https://github.com/THUDM/ChatGLM3) and update the [chatglm.ipynb](https://python.langchain.com/docs/integrations/llms/chatglm) integration notebook.

Co-authored-by: Bagatur <baskaryan@gmail.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
parent a1ce7ab672
commit 546b757303
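At a glance, the new integration is used like any other LangChain LLM. A minimal sketch distilled from the updated notebook below, assuming an OpenAI-style ChatGLM3 server is already listening on the notebook's default local endpoint:

```python
from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain_community.llms.chatglm3 import ChatGLM3

# Assumes a ChatGLM3 server exposing an OpenAI-compatible
# chat-completions route is already running locally.
llm = ChatGLM3(endpoint_url="http://127.0.0.1:8000/v1/chat/completions")

prompt = PromptTemplate(template="{question}", input_variables=["question"])
llm_chain = LLMChain(prompt=prompt, llm=llm)

print(llm_chain.run("Who are you?"))
```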
docs/docs/integrations/llms/chatglm.ipynb
@@ -11,7 +11,102 @@
    "\n",
    "[ChatGLM2-6B](https://github.com/THUDM/ChatGLM2-6B) is the second-generation version of the open-source bilingual (Chinese-English) chat model ChatGLM-6B. It retains the smooth conversation flow and low deployment threshold of the first-generation model, while introducing new features such as better performance, longer context, and more efficient inference.\n",
    "\n",
    "This example goes over how to use LangChain to interact with ChatGLM2-6B Inference for text completion.\n",
    "[ChatGLM3](https://github.com/THUDM/ChatGLM3) is a new generation of pre-trained dialogue models jointly released by Zhipu AI and Tsinghua KEG. ChatGLM3-6B is the open-source model in the ChatGLM3 series."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Install required dependencies\n",
    "\n",
    "%pip install -qU langchain langchain-community"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ChatGLM3\n",
    "\n",
    "This example goes over how to use LangChain to interact with ChatGLM3-6B Inference for text completion."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "from langchain.chains import LLMChain\n",
    "from langchain.prompts import PromptTemplate\n",
    "from langchain.schema.messages import AIMessage\n",
    "from langchain_community.llms.chatglm3 import ChatGLM3"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [],
   "source": [
    "template = \"\"\"{question}\"\"\"\n",
    "prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "endpoint_url = \"http://127.0.0.1:8000/v1/chat/completions\"\n",
    "\n",
    "messages = [\n",
    "    AIMessage(content=\"我将从美国到中国来旅游,出行前希望了解中国的城市\"),\n",
    "    AIMessage(content=\"欢迎问我任何问题。\"),\n",
    "]\n",
    "\n",
    "llm = ChatGLM3(\n",
    "    endpoint_url=endpoint_url,\n",
    "    max_tokens=80000,\n",
    "    prefix_messages=messages,\n",
    "    top_p=0.9,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "'北京和上海是中国两个不同的城市,它们在很多方面都有所不同。\\n\\n北京是中国的首都,也是历史悠久的城市之一。它有着丰富的历史文化遗产,如故宫、颐和园等,这些景点吸引着众多游客前来观光。北京也是一个政治、文化和教育中心,有很多政府机构和学术机构总部设在北京。\\n\\n上海则是一个现代化的城市,它是中国的经济中心之一。上海拥有许多高楼大厦和国际化的金融机构,是中国最国际化的城市之一。上海也是一个美食和购物天堂,有许多著名的餐厅和购物中心。\\n\\n北京和上海的气候也不同。北京属于温带大陆性气候,冬季寒冷干燥,夏季炎热多风;而上海属于亚热带季风气候,四季分明,春秋宜人。\\n\\n北京和上海有很多不同之处,但都是中国非常重要的城市,每个城市都有自己独特的魅力和特色。'"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
    "question = \"北京和上海两座城市有什么不同?\"\n",
    "\n",
    "llm_chain.run(question)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ChatGLM and ChatGLM2\n",
    "\n",
    "The following example shows how to use LangChain to interact with the ChatGLM2-6B Inference to complete text.\n",
    "ChatGLM-6B and ChatGLM2-6B have the same API specs, so this example should work with both."
   ]
  },
@@ -106,7 +201,7 @@
  ],
  "metadata": {
   "kernelspec": {
    "display_name": "langchain-dev",
    "display_name": "Python 3 (ipykernel)",
    "language": "python",
    "name": "python3"
   },
@@ -120,9 +215,9 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.12"
   "version": "3.9.1"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
 "nbformat_minor": 4
}
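The notebook's `llm_chain.run(...)` call ultimately POSTs an OpenAI-style chat-completions payload to the endpoint. A rough sketch of the equivalent raw request, with the payload keys taken from `_invocation_params` / `_get_payload` in the new module below; the concrete values here are only illustrative:

```python
import httpx

# Illustrative payload; mirrors what ChatGLM3._get_payload builds.
payload = {
    "model": "chatglm3-6b",
    "temperature": 0.1,
    "max_tokens": 80000,
    "top_p": 0.9,
    "stream": False,
    "messages": [
        {"role": "assistant", "content": "欢迎问我任何问题。"},
        {"role": "user", "content": "北京和上海两座城市有什么不同?"},
    ],
}

with httpx.Client(timeout=30) as client:
    response = client.post(
        "http://127.0.0.1:8000/v1/chat/completions",
        headers={"Content-Type": "application/json"},
        json=payload,
    )
    # The generated text lives at choices[0].message.content,
    # exactly where ChatGLM3._call reads it.
    print(response.json()["choices"][0]["message"]["content"])
```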
libs/community/langchain_community/llms/chatglm3.py  (new file, 151 lines)
@@ -0,0 +1,151 @@
import json
import logging
from typing import Any, List, Optional, Union

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.messages import (
    AIMessage,
    BaseMessage,
    FunctionMessage,
    HumanMessage,
    SystemMessage,
)
from langchain_core.pydantic_v1 import Field

from langchain_community.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)
HEADERS = {"Content-Type": "application/json"}
DEFAULT_TIMEOUT = 30


def _convert_message_to_dict(message: BaseMessage) -> dict:
    if isinstance(message, HumanMessage):
        message_dict = {"role": "user", "content": message.content}
    elif isinstance(message, AIMessage):
        message_dict = {"role": "assistant", "content": message.content}
    elif isinstance(message, SystemMessage):
        message_dict = {"role": "system", "content": message.content}
    elif isinstance(message, FunctionMessage):
        message_dict = {"role": "function", "content": message.content}
    else:
        raise ValueError(f"Got unknown type {message}")
    return message_dict


class ChatGLM3(LLM):
    """ChatGLM3 LLM service."""

    model_name: str = Field(default="chatglm3-6b", alias="model")
    endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
    """Endpoint URL to use."""
    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""
    max_tokens: int = 20000
    """Max tokens allowed to pass to the model."""
    temperature: float = 0.1
    """LLM model temperature from 0 to 10."""
    top_p: float = 0.7
    """Top P for nucleus sampling from 0 to 1."""
    prefix_messages: List[BaseMessage] = Field(default_factory=list)
    """Series of messages for Chat input."""
    streaming: bool = False
    """Whether to stream the results or not."""
    http_client: Union[Any, None] = None
    timeout: int = DEFAULT_TIMEOUT

    @property
    def _llm_type(self) -> str:
        return "chat_glm_3"

    @property
    def _invocation_params(self) -> dict:
        """Get the parameters used to invoke the model."""
        params = {
            "model": self.model_name,
            "temperature": self.temperature,
            "max_tokens": self.max_tokens,
            "top_p": self.top_p,
            "stream": self.streaming,
        }
        return {**params, **(self.model_kwargs or {})}

    @property
    def client(self) -> Any:
        import httpx

        return self.http_client or httpx.Client(timeout=self.timeout)

    def _get_payload(self, prompt: str) -> dict:
        params = self._invocation_params
        messages = self.prefix_messages + [HumanMessage(content=prompt)]
        params.update(
            {
                "messages": [_convert_message_to_dict(m) for m in messages],
            }
        )
        return params

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to a ChatGLM3 LLM inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = chatglm_llm("Who are you?")
        """
        import httpx

        payload = self._get_payload(prompt)
        logger.debug(f"ChatGLM3 payload: {payload}")

        try:
            response = self.client.post(
                self.endpoint_url, headers=HEADERS, json=payload
            )
        except httpx.NetworkError as e:
            raise ValueError(f"Error raised by inference endpoint: {e}")

        logger.debug(f"ChatGLM3 response: {response}")

        if response.status_code != 200:
            raise ValueError(f"Failed with response: {response}")

        try:
            parsed_response = response.json()

            if isinstance(parsed_response, dict):
                content_keys = "choices"
                if content_keys in parsed_response:
                    choices = parsed_response[content_keys]
                    if len(choices):
                        text = choices[0]["message"]["content"]
                    else:
                        # An empty choices list would otherwise leave `text`
                        # unbound below, so fail loudly here.
                        raise ValueError(
                            f"Empty choices in response: {parsed_response}"
                        )
                else:
                    raise ValueError(f"No content in response: {parsed_response}")
            else:
                raise ValueError(f"Unexpected response type: {parsed_response}")

        except json.JSONDecodeError as e:
            raise ValueError(
                f"Error raised during decoding response from inference endpoint: {e}."
                f"\nResponse: {response.text}"
            )

        if stop is not None:
            text = enforce_stop_tokens(text, stop)

        return text
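To make the message plumbing above concrete, here is a small sketch of `_get_payload` in action (calling the private helper purely for illustration), using one of the notebook's prefix messages; `_convert_message_to_dict` maps each LangChain message type onto an API role:

```python
from langchain_core.messages import AIMessage

from langchain_community.llms.chatglm3 import ChatGLM3

llm = ChatGLM3(prefix_messages=[AIMessage(content="欢迎问我任何问题。")])

# _get_payload appends the prompt as a HumanMessage and converts every
# message to a {"role": ..., "content": ...} dict.
payload = llm._get_payload("北京和上海两座城市有什么不同?")
print(payload["messages"])
# [{'role': 'assistant', 'content': '欢迎问我任何问题。'},
#  {'role': 'user', 'content': '北京和上海两座城市有什么不同?'}]
```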
libs/community/poetry.lock  (generated, 10 lines changed)
@@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.

[[package]]
name = "aenum"
@@ -3433,6 +3433,7 @@ files = [
    {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:227b178b22a7f91ae88525810441791b1ca1fc71c86f03190911793be15cec3d"},
    {file = "jq-1.6.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:780eb6383fbae12afa819ef676fc93e1548ae4b076c004a393af26a04b460742"},
    {file = "jq-1.6.0-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:08ded6467f4ef89fec35b2bf310f210f8cd13fbd9d80e521500889edf8d22441"},
    {file = "jq-1.6.0-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:49e44ed677713f4115bd5bf2dbae23baa4cd503be350e12a1c1f506b0687848f"},
    {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:984f33862af285ad3e41e23179ac4795f1701822473e1a26bf87ff023e5a89ea"},
    {file = "jq-1.6.0-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f42264fafc6166efb5611b5d4cb01058887d050a6c19334f6a3f8a13bb369df5"},
    {file = "jq-1.6.0-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:a67154f150aaf76cc1294032ed588436eb002097dd4fd1e283824bf753a05080"},
@@ -3943,7 +3944,7 @@ files = [

[[package]]
name = "langchain-core"
version = "0.1.16"
version = "0.1.17"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@@ -6222,7 +6223,6 @@ files = [
    {file = "pymongo-4.6.1-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:b8729dbf25eb32ad0dc0b9bd5e6a0d0b7e5c2dc8ec06ad171088e1896b522a74"},
    {file = "pymongo-4.6.1-cp312-cp312-win32.whl", hash = "sha256:3177f783ae7e08aaf7b2802e0df4e4b13903520e8380915e6337cdc7a6ff01d8"},
    {file = "pymongo-4.6.1-cp312-cp312-win_amd64.whl", hash = "sha256:00c199e1c593e2c8b033136d7a08f0c376452bac8a896c923fcd6f419e07bdd2"},
    {file = "pymongo-4.6.1-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:6dcc95f4bb9ed793714b43f4f23a7b0c57e4ef47414162297d6f650213512c19"},
    {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_i686.whl", hash = "sha256:13552ca505366df74e3e2f0a4f27c363928f3dff0eef9f281eb81af7f29bc3c5"},
    {file = "pymongo-4.6.1-cp37-cp37m-manylinux1_x86_64.whl", hash = "sha256:77e0df59b1a4994ad30c6d746992ae887f9756a43fc25dec2db515d94cf0222d"},
    {file = "pymongo-4.6.1-cp37-cp37m-manylinux2014_aarch64.whl", hash = "sha256:3a7f02a58a0c2912734105e05dedbee4f7507e6f1bd132ebad520be0b11d46fd"},
@@ -9247,9 +9247,9 @@ testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "p

[extras]
cli = ["typer"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]
extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "assemblyai", "atlassian-python-api", "azure-ai-documentintelligence", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "cohere", "dashvector", "databricks-vectorsearch", "datasets", "dgml-utils", "elasticsearch", "esprima", "faiss-cpu", "feedparser", "fireworks-ai", "geopandas", "gitpython", "google-cloud-documentai", "gql", "gradientai", "hdbcli", "hologres-vector", "html2text", "httpx", "javelin-sdk", "jinja2", "jq", "jsonschema", "lxml", "markdownify", "motor", "msal", "mwparserfromhell", "mwxml", "newspaper3k", "numexpr", "oci", "openai", "openapi-pydantic", "oracle-ads", "pandas", "pdfminer-six", "pgvector", "praw", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "rapidocr-onnxruntime", "rdflib", "requests-toolbelt", "rspace_client", "scikit-learn", "sqlite-vss", "streamlit", "sympy", "telethon", "timescale-vector", "tqdm", "upstash-redis", "xata", "xmltodict", "zhipuai"]

[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "42d012441d7b42d273e11708b7e12308fc56b169d4d56c4c2511e7469743a983"
content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"

libs/community/pyproject.toml
@@ -87,6 +87,7 @@ datasets = {version = "^2.15.0", optional = true}
azure-ai-documentintelligence = {version = "^1.0.0b1", optional = true}
oracle-ads = {version = "^2.9.1", optional = true}
zhipuai = {version = "^1.0.7", optional = true}
httpx = {version = "^0.24.1", optional = true}
elasticsearch = {version = "^8.12.0", optional = true}
hdbcli = {version = "^2.19.21", optional = true}
oci = {version = "^2.119.1", optional = true}
@@ -253,6 +254,7 @@ extended_testing = [
    "azure-ai-documentintelligence",
    "oracle-ads",
    "zhipuai",
    "httpx",
    "elasticsearch",
    "hdbcli",
    "oci",