From 46e28b961312a97747b28d512bc1114b0d9427b0 Mon Sep 17 00:00:00 2001 From: Mikelarg Date: Tue, 24 Oct 2023 22:53:51 +0300 Subject: [PATCH] Added GigaChat chat model support (#12201) - **Description:** Added an integration with the [GigaChat](https://developers.sber.ru/portal/products/gigachat) language model. - **Twitter handle:** @dvoshansky --- docs/docs/integrations/chat/gigachat.ipynb | 114 ++++++++ docs/docs/integrations/llms/gigachat.ipynb | 113 ++++++++ .../integrations/providers/salute_devices.mdx | 29 ++ .../langchain/chat_models/__init__.py | 2 + .../langchain/chat_models/gigachat.py | 179 ++++++++++++ libs/langchain/langchain/llms/__init__.py | 10 + libs/langchain/langchain/llms/gigachat.py | 259 ++++++++++++++++++ 7 files changed, 706 insertions(+) create mode 100644 docs/docs/integrations/chat/gigachat.ipynb create mode 100644 docs/docs/integrations/llms/gigachat.ipynb create mode 100644 docs/docs/integrations/providers/salute_devices.mdx create mode 100644 libs/langchain/langchain/chat_models/gigachat.py create mode 100644 libs/langchain/langchain/llms/gigachat.py
diff --git a/docs/docs/integrations/chat/gigachat.ipynb b/docs/docs/integrations/chat/gigachat.ipynb new file mode 100644 index 00000000000..c353fe75d8f --- /dev/null +++ b/docs/docs/integrations/chat/gigachat.ipynb @@ -0,0 +1,114 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# GigaChat\n", + "This notebook shows how to use LangChain with [GigaChat](https://developers.sber.ru/portal/products/gigachat).\n", + "To use it, you need to install the ```gigachat``` python package." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# !pip install gigachat" + ] + }, + { + "cell_type": "markdown", + "source": [ + "To get GigaChat credentials, you need to [create an account](https://developers.sber.ru/studio/login) and [get access to the API](https://developers.sber.ru/docs/ru/gigachat/api/integration).\n", + "## Example" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 9, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "os.environ['GIGACHAT_CREDENTIALS'] = getpass()" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 10, + "outputs": [], + "source": [ + "from langchain.chat_models import GigaChat\n", + "\n", + "chat = GigaChat(verify_ssl_certs=False)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 31, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "What do you get when you cross a goat and a skunk? A smelly goat!\n" + ] + } + ], + "source": [ + "from langchain.schema import SystemMessage, HumanMessage\n", + "\n", + "messages = [\n", + " SystemMessage(\n", + " content=\"You are a helpful AI that shares everything you know. Talk in English.\"\n", + " ),\n", + " HumanMessage(\n", + " content=\"Tell me a joke\"\n", + " ),\n", + "]\n", + "\n", + "print(chat(messages).content)" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
diff --git a/docs/docs/integrations/llms/gigachat.ipynb b/docs/docs/integrations/llms/gigachat.ipynb new file mode 100644 index 00000000000..a11ea9986ff --- /dev/null +++ b/docs/docs/integrations/llms/gigachat.ipynb @@ -0,0 +1,113 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# GigaChat\n", + "This notebook shows how to use LangChain with [GigaChat](https://developers.sber.ru/portal/products/gigachat).\n", + "To use it, you need to install the ```gigachat``` python package." + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": { + "collapsed": true + }, + "outputs": [], + "source": [ + "# !pip install gigachat" + ] + }, + { + "cell_type": "markdown", + "source": [ + "To get GigaChat credentials, you need to [create an account](https://developers.sber.ru/studio/login) and [get access to the API](https://developers.sber.ru/docs/ru/gigachat/api/integration).\n", + "## Example" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 1, + "outputs": [], + "source": [ + "import os\n", + "from getpass import getpass\n", + "\n", + "os.environ['GIGACHAT_CREDENTIALS'] = getpass()" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 2, + "outputs": [], + "source": [ + "from langchain.llms import GigaChat\n", + "\n", + "llm = GigaChat(verify_ssl_certs=False)" + ], + "metadata": { + "collapsed": false + } + }, + { + "cell_type": "code", + "execution_count": 3, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "The capital of Russia is Moscow.\n" + ] + } + ], + "source": [ + "from langchain.prompts import PromptTemplate\n", + "from langchain.chains import LLMChain\n", + "\n", + "template = \"What is the capital of {country}?\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"country\"])\n", + "\n", + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "\n", + "generated = llm_chain.run(country=\"Russia\")\n", + "print(generated)" + ], + "metadata": { + "collapsed": false + } + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 2 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython2", + "version": "2.7.6" + } + }, + "nbformat": 4, + "nbformat_minor": 0 +}
diff --git a/docs/docs/integrations/providers/salute_devices.mdx b/docs/docs/integrations/providers/salute_devices.mdx new file mode 100644 index 00000000000..2ab03926722 --- /dev/null +++ b/docs/docs/integrations/providers/salute_devices.mdx @@ -0,0 +1,29 @@ +# Salute Devices + +Salute Devices provides the GigaChat large language models. + +For more information on how to get access to GigaChat, [follow this link](https://developers.sber.ru/docs/ru/gigachat/api/integration). + +## Installation and Setup + +The GigaChat package can be installed via pip from PyPI: + +```bash +pip install gigachat +``` + +## LLMs + +See a [usage example](/docs/integrations/llms/gigachat). + +```python +from langchain.llms import GigaChat +``` + +## Chat models + +See a [usage example](/docs/integrations/chat/gigachat). + +```python +from langchain.chat_models import GigaChat +``` \ No newline at end of file
diff --git a/libs/langchain/langchain/chat_models/__init__.py b/libs/langchain/langchain/chat_models/__init__.py index 9dc30fe9217..2471fe1c3f2 100644 --- a/libs/langchain/langchain/chat_models/__init__.py +++ b/libs/langchain/langchain/chat_models/__init__.py @@ -28,6 +28,7 @@ from langchain.chat_models.ernie import ErnieBotChat from langchain.chat_models.everlyai import ChatEverlyAI from langchain.chat_models.fake import FakeListChatModel from langchain.chat_models.fireworks import ChatFireworks +from langchain.chat_models.gigachat import GigaChat from langchain.chat_models.google_palm import ChatGooglePalm from langchain.chat_models.human import HumanInputChatModel from langchain.chat_models.hunyuan import ChatHunyuan @@ -71,4 +72,5 @@ __all__ = [ "ChatYandexGPT", "ChatBaichuan", "ChatHunyuan", + "GigaChat", ]
diff --git a/libs/langchain/langchain/chat_models/gigachat.py b/libs/langchain/langchain/chat_models/gigachat.py new file mode 100644 index 00000000000..2b8ac7f2133 --- /dev/null +++ b/libs/langchain/langchain/chat_models/gigachat.py @@ -0,0 +1,179 @@ +import logging +from typing import Any, AsyncIterator, Iterator, List, Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.chat_models.base import ( + BaseChatModel, + _agenerate_from_stream, + _generate_from_stream, +) +from langchain.llms.gigachat import _BaseGigaChat +from langchain.schema import ChatResult +from langchain.schema.messages import ( + AIMessage, + AIMessageChunk, + BaseMessage, + ChatMessage, + HumanMessage, + SystemMessage, +) +from langchain.schema.output import ChatGeneration, ChatGenerationChunk + +logger = logging.getLogger(__name__) + + +def _convert_dict_to_message(message: Any) -> BaseMessage: + from gigachat.models import MessagesRole + + if message.role == MessagesRole.SYSTEM: + return SystemMessage(content=message.content) + elif message.role == MessagesRole.USER: + return HumanMessage(content=message.content) + elif message.role == MessagesRole.ASSISTANT: + return AIMessage(content=message.content) + else: + raise TypeError(f"Got unknown role {message.role} {message}") + + +def _convert_message_to_dict(message: BaseMessage) -> Any: + from gigachat.models import Messages, MessagesRole + + if isinstance(message, SystemMessage): + return Messages(role=MessagesRole.SYSTEM, content=message.content) + elif isinstance(message, HumanMessage): + return Messages(role=MessagesRole.USER, content=message.content) + elif isinstance(message, AIMessage): + return Messages(role=MessagesRole.ASSISTANT, content=message.content) + elif isinstance(message, ChatMessage): + return Messages(role=MessagesRole(message.role), content=message.content) + else: + raise TypeError(f"Got unknown type {message}") + + +class GigaChat(_BaseGigaChat, BaseChatModel): + """`GigaChat` large language models API. + + To use, you should pass a login and password to access the GigaChat API, or use a token. + + Example: + ..
code-block:: python + + from langchain.chat_models import GigaChat + giga = GigaChat(credentials=..., verify_ssl_certs=False) + """ + + def _build_payload(self, messages: List[BaseMessage]) -> Any: + from gigachat.models import Chat + + payload = Chat( + messages=[_convert_message_to_dict(m) for m in messages], + profanity_check=self.profanity, + ) + if self.temperature is not None: + payload.temperature = self.temperature + if self.max_tokens is not None: + payload.max_tokens = self.max_tokens + + if self.verbose: + logger.info("Giga request: %s", payload.dict()) + + return payload + + def _create_chat_result(self, response: Any) -> ChatResult: + generations = [] + for res in response.choices: + message = _convert_dict_to_message(res.message) + finish_reason = res.finish_reason + gen = ChatGeneration( + message=message, + generation_info={"finish_reason": finish_reason}, + ) + generations.append(gen) + if finish_reason != "stop": + logger.warning( + "Giga generation stopped with reason: %s", + finish_reason, + ) + if self.verbose: + logger.info("Giga response: %s", message.content) + llm_output = {"token_usage": response.usage, "model_name": response.model} + return ChatResult(generations=generations, llm_output=llm_output) + + def _generate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._stream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return _generate_from_stream(stream_iter) + + payload = self._build_payload(messages) + response = self._client.chat(payload) + + return self._create_chat_result(response) + + async def _agenerate( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> ChatResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + stream_iter = self._astream( + messages, stop=stop, run_manager=run_manager, **kwargs + ) + return await _agenerate_from_stream(stream_iter) + + payload = self._build_payload(messages) + response = await self._client.achat(payload) + + return self._create_chat_result(response) + + def _stream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[ChatGenerationChunk]: + payload = self._build_payload(messages) + + for chunk in self._client.stream(payload): + if chunk.choices: + content = chunk.choices[0].delta.content + yield ChatGenerationChunk(message=AIMessageChunk(content=content)) + if run_manager: + run_manager.on_llm_new_token(content) + + async def _astream( + self, + messages: List[BaseMessage], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[ChatGenerationChunk]: + payload = self._build_payload(messages) + + async for chunk in self._client.astream(payload): + if chunk.choices: + content = chunk.choices[0].delta.content + yield ChatGenerationChunk(message=AIMessageChunk(content=content)) + if run_manager: + await run_manager.on_llm_new_token(content) + + def get_num_tokens(self, text: str) -> int: + """Count approximate number of tokens""" + return round(len(text) / 4.6) diff --git 
a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index e44255cb6c0..88b4e79d6df 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -184,6 +184,12 @@ def _import_forefrontai() -> Any: return ForefrontAI + +def _import_gigachat() -> Any: + from langchain.llms.gigachat import GigaChat + + return GigaChat + + def _import_google_palm() -> Any: from langchain.llms.google_palm import GooglePalm @@ -547,6 +553,8 @@ def __getattr__(name: str) -> Any: return _import_fireworks() elif name == "ForefrontAI": return _import_forefrontai() + elif name == "GigaChat": + return _import_gigachat() elif name == "GooglePalm": return _import_google_palm() elif name == "GooseAI": @@ -686,6 +694,7 @@ __all__ = [ "FakeListLLM", "Fireworks", "ForefrontAI", + "GigaChat", "GPT4All", "GooglePalm", "GooseAI", @@ -767,6 +776,7 @@ def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: "edenai": _import_edenai, "fake-list": _import_fake, "forefrontai": _import_forefrontai, + "giga-chat-model": _import_gigachat, "google_palm": _import_google_palm, "gooseai": _import_gooseai, "gradient": _import_gradient_ai,
diff --git a/libs/langchain/langchain/llms/gigachat.py b/libs/langchain/langchain/llms/gigachat.py new file mode 100644 index 00000000000..b484d8bb415 --- /dev/null +++ b/libs/langchain/langchain/llms/gigachat.py @@ -0,0 +1,259 @@ +from __future__ import annotations + +import logging +from functools import cached_property +from typing import Any, AsyncIterator, Dict, Iterator, List, Optional + +from langchain.callbacks.manager import ( + AsyncCallbackManagerForLLMRun, + CallbackManagerForLLMRun, +) +from langchain.llms.base import BaseLLM +from langchain.load.serializable import Serializable +from langchain.pydantic_v1 import root_validator +from langchain.schema.output import Generation, GenerationChunk, LLMResult + +logger = logging.getLogger(__name__) + + +class _BaseGigaChat(Serializable): + base_url: Optional[str] = None + """ Base API URL """ + auth_url: Optional[str] = None + """ Auth URL """ + credentials: Optional[str] = None + """ Auth Token """ + scope: Optional[str] = None + """ Permission scope for access token """ + + access_token: Optional[str] = None + """ Access token for GigaChat """ + + model: Optional[str] = None + """Model name to use.""" + user: Optional[str] = None + """ Username for authentication """ + password: Optional[str] = None + """ Password for authentication """ + + timeout: Optional[float] = None + """ Timeout for request """ + verify_ssl_certs: Optional[bool] = None + """ Check certificates for all requests """ + + ca_bundle_file: Optional[str] = None + cert_file: Optional[str] = None + key_file: Optional[str] = None + key_file_password: Optional[str] = None + # Support for connection to GigaChat through SSL certificates + + profanity: bool = True + """ Check for profanity """ + streaming: bool = False + """ Whether to stream the results or not. """ + temperature: Optional[float] = None + """What sampling temperature to use.""" + max_tokens: Optional[int] = None + """ Maximum number of tokens to generate """ + + @property + def _llm_type(self) -> str: + return "giga-chat-model" + + @property + def lc_secrets(self) -> Dict[str, str]: + return { + "credentials": "GIGACHAT_CREDENTIALS", + "access_token": "GIGACHAT_ACCESS_TOKEN", + "password": "GIGACHAT_PASSWORD", + "key_file_password": "GIGACHAT_KEY_FILE_PASSWORD", + } + + @property + def lc_serializable(self) -> bool: + return True + + @cached_property + def _client(self) -> Any: + """Returns GigaChat API client""" + import gigachat + + return gigachat.GigaChat( + base_url=self.base_url, + auth_url=self.auth_url, + credentials=self.credentials, + scope=self.scope, + access_token=self.access_token, + model=self.model, + user=self.user, + password=self.password, + timeout=self.timeout, + verify_ssl_certs=self.verify_ssl_certs, + ca_bundle_file=self.ca_bundle_file, + cert_file=self.cert_file, + key_file=self.key_file, + key_file_password=self.key_file_password, + ) + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate that authentication data is set in the environment and that the gigachat python package is installed.""" + try: + import gigachat # noqa: F401 + except ImportError: + raise ImportError( + "Could not import gigachat python package. " + "Please install it with `pip install gigachat`." + ) + return values + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Get the identifying parameters.""" + return { + "temperature": self.temperature, + "model": self.model, + "profanity": self.profanity, + "streaming": self.streaming, + "max_tokens": self.max_tokens, + } + + +class GigaChat(_BaseGigaChat, BaseLLM): + """`GigaChat` large language models API. + + To use, you should pass a login and password to access the GigaChat API, or use a token. + + Example: + ..
code-block:: python + + from langchain.llms import GigaChat + giga = GigaChat(credentials=..., verify_ssl_certs=False) + """ + + def _build_payload(self, messages: List[str]) -> Dict[str, Any]: + payload: Dict[str, Any] = { + "messages": [{"role": "user", "content": m} for m in messages], + "profanity_check": self.profanity, + } + if self.temperature is not None: + payload["temperature"] = self.temperature + if self.max_tokens is not None: + payload["max_tokens"] = self.max_tokens + if self.model: + payload["model"] = self.model + + if self.verbose: + logger.info("Giga request: %s", payload) + + return payload + + def _create_llm_result(self, response: Any) -> LLMResult: + generations = [] + for res in response.choices: + finish_reason = res.finish_reason + gen = Generation( + text=res.message.content, + generation_info={"finish_reason": finish_reason}, + ) + generations.append([gen]) + if finish_reason != "stop": + logger.warning( + "Giga generation stopped with reason: %s", + finish_reason, + ) + if self.verbose: + logger.info("Giga response: %s", res.message.content) + token_usage = response.usage + llm_output = {"token_usage": token_usage, "model_name": response.model} + return LLMResult(generations=generations, llm_output=llm_output) + + def _generate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> LLMResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + generation: Optional[GenerationChunk] = None + stream_iter = self._stream( + prompts[0], stop=stop, run_manager=run_manager, **kwargs + ) + for chunk in stream_iter: + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return LLMResult(generations=[[generation]]) + + payload = self._build_payload(prompts) + response = self._client.chat(payload) + + return self._create_llm_result(response) + + async def _agenerate( + self, + prompts: List[str], + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + stream: Optional[bool] = None, + **kwargs: Any, + ) -> LLMResult: + should_stream = stream if stream is not None else self.streaming + if should_stream: + generation: Optional[GenerationChunk] = None + stream_iter = self._astream( + prompts[0], stop=stop, run_manager=run_manager, **kwargs + ) + async for chunk in stream_iter: + if generation is None: + generation = chunk + else: + generation += chunk + assert generation is not None + return LLMResult(generations=[[generation]]) + + payload = self._build_payload(prompts) + response = await self._client.achat(payload) + + return self._create_llm_result(response) + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + payload = self._build_payload([prompt]) + + for chunk in self._client.stream(payload): + if chunk.choices: + content = chunk.choices[0].delta.content + yield GenerationChunk(text=content) + if run_manager: + run_manager.on_llm_new_token(content) + + async def _astream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> AsyncIterator[GenerationChunk]: + payload = self._build_payload([prompt]) + + async for chunk in self._client.astream(payload): + if chunk.choices: + content = 
chunk.choices[0].delta.content + yield GenerationChunk(text=content) + if run_manager: + await run_manager.on_llm_new_token(content) + + def get_num_tokens(self, text: str) -> int: + """Count approximate number of tokens""" + return round(len(text) / 4.6)
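
For reference, here is a minimal end-to-end usage sketch distilled from the notebooks added in this patch. It assumes the `gigachat` package is installed and that the `GIGACHAT_CREDENTIALS` environment variable holds valid GigaChat credentials; all other settings mirror the notebook examples.

```python
# Minimal usage sketch for the GigaChat integration added in this patch.
# Assumes `pip install gigachat` and that GIGACHAT_CREDENTIALS is set in the
# environment (as in the notebooks above).
from langchain.chat_models import GigaChat
from langchain.schema import HumanMessage, SystemMessage

# verify_ssl_certs=False mirrors the notebook examples; credentials are picked
# up from the environment by the underlying gigachat client.
chat = GigaChat(verify_ssl_certs=False)

messages = [
    SystemMessage(
        content="You are a helpful AI that shares everything you know. Talk in English."
    ),
    HumanMessage(content="Tell me a joke"),
]

print(chat(messages).content)
```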