mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-07 14:03:26 +00:00
community: Add ChatGLM3 (#15265)
Add [ChatGLM3](https://github.com/THUDM/ChatGLM3) and updated [chatglm.ipynb](https://python.langchain.com/docs/integrations/llms/chatglm) --------- Co-authored-by: Bagatur <baskaryan@gmail.com> Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
151
libs/community/langchain_community/llms/chatglm3.py
Normal file
151
libs/community/langchain_community/llms/chatglm3.py
Normal file
@@ -0,0 +1,151 @@
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, List, Optional, Union
|
||||
|
||||
from langchain_core.callbacks import CallbackManagerForLLMRun
|
||||
from langchain_core.language_models.llms import LLM
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
BaseMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
)
|
||||
from langchain_core.pydantic_v1 import Field
|
||||
|
||||
from langchain_community.llms.utils import enforce_stop_tokens
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
HEADERS = {"Content-Type": "application/json"}
|
||||
DEFAULT_TIMEOUT = 30
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> dict:
|
||||
if isinstance(message, HumanMessage):
|
||||
message_dict = {"role": "user", "content": message.content}
|
||||
elif isinstance(message, AIMessage):
|
||||
message_dict = {"role": "assistant", "content": message.content}
|
||||
elif isinstance(message, SystemMessage):
|
||||
message_dict = {"role": "system", "content": message.content}
|
||||
elif isinstance(message, FunctionMessage):
|
||||
message_dict = {"role": "function", "content": message.content}
|
||||
else:
|
||||
raise ValueError(f"Got unknown type {message}")
|
||||
return message_dict
|
||||
|
||||
|
||||
class ChatGLM3(LLM):
|
||||
"""ChatGLM3 LLM service."""
|
||||
|
||||
model_name: str = Field(default="chatglm3-6b", alias="model")
|
||||
endpoint_url: str = "http://127.0.0.1:8000/v1/chat/completions"
|
||||
"""Endpoint URL to use."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Keyword arguments to pass to the model."""
|
||||
max_tokens: int = 20000
|
||||
"""Max token allowed to pass to the model."""
|
||||
temperature: float = 0.1
|
||||
"""LLM model temperature from 0 to 10."""
|
||||
top_p: float = 0.7
|
||||
"""Top P for nucleus sampling from 0 to 1"""
|
||||
prefix_messages: List[BaseMessage] = Field(default_factory=list)
|
||||
"""Series of messages for Chat input."""
|
||||
streaming: bool = False
|
||||
"""Whether to stream the results or not."""
|
||||
http_client: Union[Any, None] = None
|
||||
timeout: int = DEFAULT_TIMEOUT
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
return "chat_glm_3"
|
||||
|
||||
@property
|
||||
def _invocation_params(self) -> dict:
|
||||
"""Get the parameters used to invoke the model."""
|
||||
params = {
|
||||
"model": self.model_name,
|
||||
"temperature": self.temperature,
|
||||
"max_tokens": self.max_tokens,
|
||||
"top_p": self.top_p,
|
||||
"stream": self.streaming,
|
||||
}
|
||||
return {**params, **(self.model_kwargs or {})}
|
||||
|
||||
@property
|
||||
def client(self) -> Any:
|
||||
import httpx
|
||||
|
||||
return self.http_client or httpx.Client(timeout=self.timeout)
|
||||
|
||||
def _get_payload(self, prompt: str) -> dict:
|
||||
params = self._invocation_params
|
||||
messages = self.prefix_messages + [HumanMessage(content=prompt)]
|
||||
params.update(
|
||||
{
|
||||
"messages": [_convert_message_to_dict(m) for m in messages],
|
||||
}
|
||||
)
|
||||
return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Call out to a ChatGLM3 LLM inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = chatglm_llm("Who are you?")
|
||||
"""
|
||||
import httpx
|
||||
|
||||
payload = self._get_payload(prompt)
|
||||
logger.debug(f"ChatGLM3 payload: {payload}")
|
||||
|
||||
try:
|
||||
response = self.client.post(
|
||||
self.endpoint_url, headers=HEADERS, json=payload
|
||||
)
|
||||
except httpx.NetworkError as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
logger.debug(f"ChatGLM3 response: {response}")
|
||||
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed with response: {response}")
|
||||
|
||||
try:
|
||||
parsed_response = response.json()
|
||||
|
||||
if isinstance(parsed_response, dict):
|
||||
content_keys = "choices"
|
||||
if content_keys in parsed_response:
|
||||
choices = parsed_response[content_keys]
|
||||
if len(choices):
|
||||
text = choices[0]["message"]["content"]
|
||||
else:
|
||||
raise ValueError(f"No content in response : {parsed_response}")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response type: {parsed_response}")
|
||||
|
||||
except json.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error raised during decoding response from inference endpoint: {e}."
|
||||
f"\nResponse: {response.text}"
|
||||
)
|
||||
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
||||
return text
|
Reference in New Issue
Block a user