From dac2e0165a895c19416a55249b1f9f20d51df7de Mon Sep 17 00:00:00 2001
From: Mikelarg
Date: Tue, 26 Mar 2024 02:08:37 +0300
Subject: [PATCH] community[minor]: Added GigaChat Embeddings support + updated
 previous GigaChat integration (#19516)

- **Description:** Added integration with
  [GigaChat](https://developers.sber.ru/portal/products/gigachat) embeddings.
  Also added support for extra fields in the GigaChat LLM and fixed the docs.
---
 docs/docs/integrations/chat/gigachat.ipynb    |  24 ++-
 docs/docs/integrations/llms/gigachat.ipynb    |  20 +-
 .../integrations/providers/salute_devices.mdx |   8 +
 .../text_embedding/gigachat.ipynb             | 116 +++++++++++
 .../chat_models/gigachat.py                   | 151 +++++++++++---
 .../embeddings/__init__.py                    |   1 +
 .../embeddings/gigachat.py                    | 187 ++++++++++++++++++
 .../langchain_community/llms/gigachat.py      | 102 ++++++++--
 .../unit_tests/embeddings/test_imports.py     |   1 +
 9 files changed, 548 insertions(+), 62 deletions(-)
 create mode 100644 docs/docs/integrations/text_embedding/gigachat.ipynb
 create mode 100644 libs/community/langchain_community/embeddings/gigachat.py

diff --git a/docs/docs/integrations/chat/gigachat.ipynb b/docs/docs/integrations/chat/gigachat.ipynb
index e33b0fa2cda..8676c26875d 100644
--- a/docs/docs/integrations/chat/gigachat.ipynb
+++ b/docs/docs/integrations/chat/gigachat.ipynb
@@ -13,9 +13,12 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 8,
+   "execution_count": null,
    "metadata": {
-    "collapsed": true
+    "collapsed": true,
+    "pycharm": {
+     "is_executing": true
+    }
    },
    "outputs": [],
    "source": [
@@ -28,13 +31,14 @@
     "collapsed": false
    },
    "source": [
-    "To get GigaChat credentials you need to [create account](https://developers.sber.ru/studio/login) and [get access to API](https://developers.sber.ru/docs/ru/gigachat/api/integration)\n",
+    "To get GigaChat credentials you need to [create an account](https://developers.sber.ru/studio/login) and [get access to the API](https://developers.sber.ru/docs/ru/gigachat/individuals-quickstart)\n",
+    "\n",
     "## Example"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 9,
+   "execution_count": 2,
    "metadata": {
     "collapsed": false
    },
@@ -48,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 3,
    "metadata": {
     "collapsed": false
    },
@@ -56,12 +60,12 @@
    "source": [
     "from langchain_community.chat_models import GigaChat\n",
-    "chat = GigaChat(verify_ssl_certs=False)"
+    "chat = GigaChat(verify_ssl_certs=False, scope=\"GIGACHAT_API_PERS\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 31,
+   "execution_count": 8,
    "metadata": {
     "collapsed": false
    },
@@ -70,7 +74,7 @@
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "What do you get when you cross a goat and a skunk? A smelly goat!\n"
+      "The capital of Russia is Moscow.\n"
      ]
     }
    ],
@@ -81,10 +85,10 @@
     "    SystemMessage(\n",
     "        content=\"You are a helpful AI that shares everything you know. Talk in English.\"\n",
     "    ),\n",
-    "    HumanMessage(content=\"Tell me a joke\"),\n",
+    "    HumanMessage(content=\"What is the capital of Russia?\"),\n",
     "]\n",
     "\n",
-    "print(chat(messages).content)"
+    "print(chat.invoke(messages).content)"
    ]
   }
  ],
diff --git a/docs/docs/integrations/llms/gigachat.ipynb b/docs/docs/integrations/llms/gigachat.ipynb
index 8e1e4a43d07..7e92a38aa73 100644
--- a/docs/docs/integrations/llms/gigachat.ipynb
+++ b/docs/docs/integrations/llms/gigachat.ipynb
@@ -15,7 +15,10 @@
    "cell_type": "code",
    "execution_count": null,
    "metadata": {
-    "collapsed": true
+    "collapsed": true,
+    "pycharm": {
+     "is_executing": true
+    }
    },
    "outputs": [],
    "source": [
@@ -28,13 +31,14 @@
     "collapsed": false
    },
    "source": [
-    "To get GigaChat credentials you need to [create account](https://developers.sber.ru/studio/login) and [get access to API](https://developers.sber.ru/docs/ru/gigachat/api/integration)\n",
+    "To get GigaChat credentials you need to [create an account](https://developers.sber.ru/studio/login) and [get access to the API](https://developers.sber.ru/docs/ru/gigachat/individuals-quickstart)\n",
+    "\n",
     "## Example"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 1,
+   "execution_count": 2,
    "metadata": {
     "collapsed": false
    },
@@ -48,7 +52,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 2,
+   "execution_count": 3,
    "metadata": {
     "collapsed": false
    },
@@ -56,12 +60,12 @@
    "source": [
     "from langchain_community.llms import GigaChat\n",
-    "llm = GigaChat(verify_ssl_certs=False)"
+    "llm = GigaChat(verify_ssl_certs=False, scope=\"GIGACHAT_API_PERS\")"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 3,
+   "execution_count": 9,
    "metadata": {
     "collapsed": false
    },
@@ -84,8 +88,8 @@
     "\n",
     "llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
     "\n",
-    "generated = llm_chain.run(country=\"Russia\")\n",
-    "print(generated)"
+    "generated = llm_chain.invoke(input={\"country\": \"Russia\"})\n",
+    "print(generated[\"text\"])"
    ]
   }
  ],
diff --git a/docs/docs/integrations/providers/salute_devices.mdx b/docs/docs/integrations/providers/salute_devices.mdx
index b35adf02449..2651090acc0 100644
--- a/docs/docs/integrations/providers/salute_devices.mdx
+++ b/docs/docs/integrations/providers/salute_devices.mdx
@@ -26,4 +26,12 @@ See a [usage example](/docs/integrations/chat/gigachat).
 
 ```python
 from langchain_community.chat_models import GigaChat
+```
+
+## Embeddings
+
+See a [usage example](/docs/integrations/text_embedding/gigachat).
+
+```python
+from langchain_community.embeddings import GigaChatEmbeddings
 ```
\ No newline at end of file
diff --git a/docs/docs/integrations/text_embedding/gigachat.ipynb b/docs/docs/integrations/text_embedding/gigachat.ipynb
new file mode 100644
index 00000000000..d50c2a5fda6
--- /dev/null
+++ b/docs/docs/integrations/text_embedding/gigachat.ipynb
@@ -0,0 +1,116 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "source": [
+    "# GigaChat\n",
+    "This notebook shows how to use LangChain with [GigaChat embeddings](https://developers.sber.ru/portal/products/gigachat).\n",
+    "To use it, you need to install the ```gigachat``` python package."
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "outputs": [],
+   "source": [
+    "%pip install --upgrade --quiet gigachat"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "markdown",
+   "source": [
+    "To get GigaChat credentials you need to [create an account](https://developers.sber.ru/studio/login) and [get access to the API](https://developers.sber.ru/docs/ru/gigachat/individuals-quickstart)\n",
+    "\n",
+    "## Example"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "metadata": {
+    "collapsed": true
+   },
+   "outputs": [],
+   "source": [
+    "import os\n",
+    "from getpass import getpass\n",
+    "\n",
+    "os.environ[\"GIGACHAT_CREDENTIALS\"] = getpass()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "outputs": [],
+   "source": [
+    "from langchain_community.embeddings import GigaChatEmbeddings\n",
+    "\n",
+    "embeddings = GigaChatEmbeddings(verify_ssl_certs=False, scope=\"GIGACHAT_API_PERS\")"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 7,
+   "outputs": [],
+   "source": [
+    "query_result = embeddings.embed_query(\"The quick brown fox jumps over the lazy dog\")"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "outputs": [
+    {
+     "data": {
+      "text/plain": "[0.8398333191871643,\n -0.14180311560630798,\n -0.6161925792694092,\n -0.17103666067123413,\n 1.2884578704833984]"
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "query_result[:5]"
+   ],
+   "metadata": {
+    "collapsed": false
+   }
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 2
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython2",
+   "version": "2.7.6"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 0
+}
diff --git a/libs/community/langchain_community/chat_models/gigachat.py b/libs/community/langchain_community/chat_models/gigachat.py
index ca02400ea0d..7345aa4ad9c 100644
--- a/libs/community/langchain_community/chat_models/gigachat.py
+++ b/libs/community/langchain_community/chat_models/gigachat.py
@@ -1,5 +1,17 @@
+from __future__ import annotations
+
 import logging
-from typing import Any, AsyncIterator, Iterator, List, Optional
+from typing import (
+    TYPE_CHECKING,
+    Any,
+    AsyncIterator,
+    Dict,
+    Iterator,
+    List,
+    Mapping,
+    Optional,
+    Type,
+)
 
 from langchain_core.callbacks import (
     AsyncCallbackManagerForLLMRun,
@@ -14,31 +26,47 @@ from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
     BaseMessage,
+    BaseMessageChunk,
     ChatMessage,
+    ChatMessageChunk,
+    FunctionMessage,
+    FunctionMessageChunk,
     HumanMessage,
+    HumanMessageChunk,
     SystemMessage,
+    SystemMessageChunk,
 )
 from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
 
 from langchain_community.llms.gigachat import _BaseGigaChat
 
+if TYPE_CHECKING:
+    import gigachat.models as gm
+
 logger = logging.getLogger(__name__)
 
 
-def _convert_dict_to_message(message: Any) -> BaseMessage:
-    from gigachat.models import MessagesRole
+def _convert_dict_to_message(message: gm.Messages) -> BaseMessage:
+    from gigachat.models import FunctionCall, MessagesRole
+
+    additional_kwargs: Dict = {}
+    if function_call := message.function_call:
+        if isinstance(function_call, FunctionCall):
+            additional_kwargs["function_call"] = dict(function_call)
+        elif isinstance(function_call, dict):
+            additional_kwargs["function_call"] = function_call
 
     if message.role == MessagesRole.SYSTEM:
         return SystemMessage(content=message.content)
     elif message.role == MessagesRole.USER:
         return HumanMessage(content=message.content)
     elif message.role == MessagesRole.ASSISTANT:
-        return AIMessage(content=message.content)
+        return AIMessage(content=message.content, additional_kwargs=additional_kwargs)
     else:
        raise TypeError(f"Got unknown role {message.role} {message}")
 
 
-def _convert_message_to_dict(message: BaseMessage) -> Any:
+def _convert_message_to_dict(message: BaseMessage) -> gm.Messages:
     from gigachat.models import Messages, MessagesRole
 
     if isinstance(message, SystemMessage):
@@ -46,13 +74,45 @@ def _convert_message_to_dict(message: BaseMessage) -> Any:
     elif isinstance(message, HumanMessage):
         return Messages(role=MessagesRole.USER, content=message.content)
     elif isinstance(message, AIMessage):
-        return Messages(role=MessagesRole.ASSISTANT, content=message.content)
+        return Messages(
+            role=MessagesRole.ASSISTANT,
+            content=message.content,
+            function_call=message.additional_kwargs.get("function_call", None),
+        )
     elif isinstance(message, ChatMessage):
         return Messages(role=MessagesRole(message.role), content=message.content)
+    elif isinstance(message, FunctionMessage):
+        return Messages(role=MessagesRole.FUNCTION, content=message.content)
     else:
         raise TypeError(f"Got unknown type {message}")
 
 
+def _convert_delta_to_message_chunk(
+    _dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
+) -> BaseMessageChunk:
+    role = _dict.get("role")
+    content = _dict.get("content") or ""
+    additional_kwargs: Dict = {}
+    if _dict.get("function_call"):
+        function_call = dict(_dict["function_call"])
+        if "name" in function_call and function_call["name"] is None:
+            function_call["name"] = ""
+        additional_kwargs["function_call"] = function_call
+
+    if role == "user" or default_class == HumanMessageChunk:
+        return HumanMessageChunk(content=content)
+    elif role == "assistant" or default_class == AIMessageChunk:
+        return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
+    elif role == "system" or default_class == SystemMessageChunk:
+        return SystemMessageChunk(content=content)
+    elif role == "function" or default_class == FunctionMessageChunk:
+        return FunctionMessageChunk(content=content, name=_dict["name"])
+    elif role or default_class == ChatMessageChunk:
+        return ChatMessageChunk(content=content, role=role)
+    else:
+        return default_class(content=content)
+
+
 class GigaChat(_BaseGigaChat, BaseChatModel):
     """`GigaChat` large language models API.
 
@@ -62,23 +122,33 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
     .. code-block:: python
 
         from langchain_community.chat_models import GigaChat
-        giga = GigaChat(credentials=..., verify_ssl_certs=False)
+        giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=False)
     """
 
-    def _build_payload(self, messages: List[BaseMessage]) -> Any:
+    def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat:
         from gigachat.models import Chat
 
         payload = Chat(
             messages=[_convert_message_to_dict(m) for m in messages],
-            profanity_check=self.profanity,
         )
+
+        payload.functions = kwargs.get("functions", None)
+
+        if self.profanity_check is not None:
+            payload.profanity_check = self.profanity_check
         if self.temperature is not None:
             payload.temperature = self.temperature
+        if self.top_p is not None:
+            payload.top_p = self.top_p
         if self.max_tokens is not None:
             payload.max_tokens = self.max_tokens
+        if self.repetition_penalty is not None:
+            payload.repetition_penalty = self.repetition_penalty
+        if self.update_interval is not None:
+            payload.update_interval = self.update_interval
 
         if self.verbose:
-            logger.info("Giga request: %s", payload.dict())
+            logger.warning("Giga request: %s", payload.dict())
 
         return payload
 
@@ -98,7 +168,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
             finish_reason,
         )
         if self.verbose:
-            logger.info("Giga response: %s", message.content)
+            logger.warning("Giga response: %s", message.content)
 
         llm_output = {"token_usage": response.usage, "model_name": response.model}
         return ChatResult(generations=generations, llm_output=llm_output)
@@ -117,7 +187,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
             )
             return generate_from_stream(stream_iter)
 
-        payload = self._build_payload(messages)
+        payload = self._build_payload(messages, **kwargs)
         response = self._client.chat(payload)
 
         return self._create_chat_result(response)
@@ -137,7 +207,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
 
-        payload = self._build_payload(messages)
+        payload = self._build_payload(messages, **kwargs)
         response = await self._client.achat(payload)
         return self._create_chat_result(response)
 
@@ -149,15 +219,28 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
-        payload = self._build_payload(messages)
+        payload = self._build_payload(messages, **kwargs)
 
         for chunk in self._client.stream(payload):
-            if chunk.choices:
-                content = chunk.choices[0].delta.content
-                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                if run_manager:
-                    run_manager.on_llm_new_token(content, chunk=cg_chunk)
-                yield cg_chunk
+            # The client may yield model objects instead of dicts;
+            # normalize so the delta handling below is uniform.
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+
+            choice = chunk["choices"][0]
+            content = choice.get("delta", {}).get("content", "")
+            message_chunk = _convert_delta_to_message_chunk(
+                choice["delta"], AIMessageChunk
+            )
+
+            finish_reason = choice.get("finish_reason")
+            generation_info = (
+                dict(finish_reason=finish_reason) if finish_reason is not None else None
+            )
+            cg_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info
+            )
+            if run_manager:
+                run_manager.on_llm_new_token(content, chunk=cg_chunk)
+            yield cg_chunk
 
     async def _astream(
         self,
@@ -166,16 +249,24 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> AsyncIterator[ChatGenerationChunk]:
-        payload = self._build_payload(messages)
+        payload = self._build_payload(messages, **kwargs)
 
         async for chunk in self._client.astream(payload):
-            if chunk.choices:
-                content = chunk.choices[0].delta.content
-                cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
-                if run_manager:
-                    await run_manager.on_llm_new_token(content, chunk=cg_chunk)
-                yield cg_chunk
+            # The client may yield model objects instead of dicts;
+            # normalize so the delta handling below is uniform.
+            if not isinstance(chunk, dict):
+                chunk = chunk.dict()
+            if len(chunk["choices"]) == 0:
+                continue
+
+            choice = chunk["choices"][0]
+            content = choice.get("delta", {}).get("content", "")
+            message_chunk = _convert_delta_to_message_chunk(
+                choice["delta"], AIMessageChunk
+            )
+
+            finish_reason = choice.get("finish_reason")
+            generation_info = (
+                dict(finish_reason=finish_reason) if finish_reason is not None else None
+            )
+            cg_chunk = ChatGenerationChunk(
+                message=message_chunk, generation_info=generation_info
+            )
+            # Fire the callback before handing back the chunk,
+            # mirroring the synchronous _stream implementation.
+            if run_manager:
+                await run_manager.on_llm_new_token(content, chunk=cg_chunk)
+            yield cg_chunk
 
-    def get_num_tokens(self, text: str) -> int:
-        """Count approximate number of tokens"""
-        return round(len(text) / 4.6)
diff --git a/libs/community/langchain_community/embeddings/__init__.py b/libs/community/langchain_community/embeddings/__init__.py
index 9cc358011c3..33cf7fce17b 100644
--- a/libs/community/langchain_community/embeddings/__init__.py
+++ b/libs/community/langchain_community/embeddings/__init__.py
@@ -38,6 +38,7 @@ _module_lookup = {
     "GPT4AllEmbeddings": "langchain_community.embeddings.gpt4all",
     "GooglePalmEmbeddings": "langchain_community.embeddings.google_palm",
     "GradientEmbeddings": "langchain_community.embeddings.gradient_ai",
+    "GigaChatEmbeddings": "langchain_community.embeddings.gigachat",
     "HuggingFaceBgeEmbeddings": "langchain_community.embeddings.huggingface",
     "HuggingFaceEmbeddings": "langchain_community.embeddings.huggingface",
     "HuggingFaceHubEmbeddings": "langchain_community.embeddings.huggingface_hub",
diff --git a/libs/community/langchain_community/embeddings/gigachat.py b/libs/community/langchain_community/embeddings/gigachat.py
new file mode 100644
index 00000000000..8fb6f233fc6
--- /dev/null
+++ b/libs/community/langchain_community/embeddings/gigachat.py
@@ -0,0 +1,187 @@
+from __future__ import annotations
+
+import logging
+from functools import cached_property
+from typing import Any, Dict, List, Optional
+
+from langchain_core.embeddings import Embeddings
+from langchain_core.pydantic_v1 import BaseModel, root_validator
+
+logger = logging.getLogger(__name__)
+
+MAX_BATCH_SIZE_CHARS = 1000000
+MAX_BATCH_SIZE_PARTS = 90
+
+
+class GigaChatEmbeddings(BaseModel, Embeddings):
+    """GigaChat Embeddings models.
+
+    Example:
+        .. code-block:: python
+
+            from langchain_community.embeddings.gigachat import GigaChatEmbeddings
+
+            embeddings = GigaChatEmbeddings(
+                credentials=..., scope=..., verify_ssl_certs=False
+            )
+    """
+
+    base_url: Optional[str] = None
+    """ Base API URL """
+    auth_url: Optional[str] = None
+    """ Auth URL """
+    credentials: Optional[str] = None
+    """ Auth Token """
+    scope: Optional[str] = None
+    """ Permission scope for access token """
+
+    access_token: Optional[str] = None
+    """ Access token for GigaChat """
+
+    model: Optional[str] = None
+    """Model name to use."""
+    user: Optional[str] = None
+    """ Username for authentication """
+    password: Optional[str] = None
+    """ Password for authentication """
+
+    timeout: Optional[float] = 600
+    """ Timeout for a request. The default is large enough for long-running requests. """
""" + verify_ssl_certs: Optional[bool] = None + """ Check certificates for all requests """ + + ca_bundle_file: Optional[str] = None + cert_file: Optional[str] = None + key_file: Optional[str] = None + key_file_password: Optional[str] = None + # Support for connection to GigaChat through SSL certificates + + @cached_property + def _client(self) -> Any: + """Returns GigaChat API client""" + import gigachat + + return gigachat.GigaChat( + base_url=self.base_url, + auth_url=self.auth_url, + credentials=self.credentials, + scope=self.scope, + access_token=self.access_token, + model=self.model, + user=self.user, + password=self.password, + timeout=self.timeout, + verify_ssl_certs=self.verify_ssl_certs, + ca_bundle_file=self.ca_bundle_file, + cert_file=self.cert_file, + key_file=self.key_file, + key_file_password=self.key_file_password, + ) + + @root_validator() + def validate_environment(cls, values: Dict) -> Dict: + """Validate authenticate data in environment and python package is installed.""" + try: + import gigachat # noqa: F401 + except ImportError: + raise ImportError( + "Could not import gigachat python package. " + "Please install it with `pip install gigachat`." + ) + fields = set(cls.__fields__.keys()) + diff = set(values.keys()) - fields + if diff: + logger.warning(f"Extra fields {diff} in GigaChat class") + return values + + def embed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a GigaChat embeddings models. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + result: List[List[float]] = [] + size = 0 + local_texts = [] + embed_kwargs = {} + if self.model is not None: + embed_kwargs["model"] = self.model + for text in texts: + local_texts.append(text) + size += len(text) + if size > MAX_BATCH_SIZE_CHARS or len(local_texts) > MAX_BATCH_SIZE_PARTS: + for embedding in self._client.embeddings( + texts=local_texts, **embed_kwargs + ).data: + result.append(embedding.embedding) + size = 0 + local_texts = [] + # Call for last iteration + if local_texts: + for embedding in self._client.embeddings( + texts=local_texts, **embed_kwargs + ).data: + result.append(embedding.embedding) + + return result + + async def aembed_documents(self, texts: List[str]) -> List[List[float]]: + """Embed documents using a GigaChat embeddings models. + + Args: + texts: The list of texts to embed. + + Returns: + List of embeddings, one for each text. + """ + result: List[List[float]] = [] + size = 0 + local_texts = [] + embed_kwargs = {} + if self.model is not None: + embed_kwargs["model"] = self.model + for text in texts: + local_texts.append(text) + size += len(text) + if size > MAX_BATCH_SIZE_CHARS or len(local_texts) > MAX_BATCH_SIZE_PARTS: + embeddings = await self._client.aembeddings( + texts=local_texts, **embed_kwargs + ) + for embedding in embeddings.data: + result.append(embedding.embedding) + size = 0 + local_texts = [] + # Call for last iteration + if local_texts: + embeddings = await self._client.aembeddings( + texts=local_texts, **embed_kwargs + ) + for embedding in embeddings.data: + result.append(embedding.embedding) + + return result + + def embed_query(self, text: str) -> List[float]: + """Embed a query using a GigaChat embeddings models. + + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + return self.embed_documents(texts=[text])[0] + + async def aembed_query(self, text: str) -> List[float]: + """Embed a query using a GigaChat embeddings models. 
+ + Args: + text: The text to embed. + + Returns: + Embeddings for the text. + """ + docs = await self.aembed_documents(texts=[text]) + return docs[0] diff --git a/libs/community/langchain_community/llms/gigachat.py b/libs/community/langchain_community/llms/gigachat.py index 61f0893980a..67b604b2d20 100644 --- a/libs/community/langchain_community/llms/gigachat.py +++ b/libs/community/langchain_community/llms/gigachat.py @@ -2,7 +2,7 @@ from __future__ import annotations import logging from functools import cached_property -from typing import Any, AsyncIterator, Dict, Iterator, List, Optional +from typing import TYPE_CHECKING, Any, AsyncIterator, Dict, Iterator, List, Optional from langchain_core.callbacks import ( AsyncCallbackManagerForLLMRun, @@ -13,6 +13,10 @@ from langchain_core.load.serializable import Serializable from langchain_core.outputs import Generation, GenerationChunk, LLMResult from langchain_core.pydantic_v1 import root_validator +if TYPE_CHECKING: + import gigachat + import gigachat.models as gm + logger = logging.getLogger(__name__) @@ -48,13 +52,25 @@ class _BaseGigaChat(Serializable): # Support for connection to GigaChat through SSL certificates profanity: bool = True + """ DEPRECATED: Check for profanity """ + profanity_check: Optional[bool] = None """ Check for profanity """ streaming: bool = False """ Whether to stream the results or not. """ temperature: Optional[float] = None - """What sampling temperature to use.""" + """ What sampling temperature to use. """ max_tokens: Optional[int] = None """ Maximum number of tokens to generate """ + use_api_for_tokens: bool = False + """ Use GigaChat API for tokens count """ + verbose: bool = False + """ Verbose logging """ + top_p: Optional[float] = None + """ top_p value to use for nucleus sampling. Must be between 0.0 and 1.0 """ + repetition_penalty: Optional[float] = None + """ The penalty applied to repeated tokens """ + update_interval: Optional[float] = None + """ Minimum interval in seconds that elapses between sending tokens """ @property def _llm_type(self) -> str: @@ -74,7 +90,7 @@ class _BaseGigaChat(Serializable): return True @cached_property - def _client(self) -> Any: + def _client(self) -> gigachat.GigaChat: """Returns GigaChat API client""" import gigachat @@ -85,6 +101,7 @@ class _BaseGigaChat(Serializable): scope=self.scope, access_token=self.access_token, model=self.model, + profanity_check=self.profanity_check, user=self.user, password=self.password, timeout=self.timeout, @@ -93,6 +110,7 @@ class _BaseGigaChat(Serializable): cert_file=self.cert_file, key_file=self.key_file, key_file_password=self.key_file_password, + verbose=self.verbose, ) @root_validator() @@ -105,6 +123,16 @@ class _BaseGigaChat(Serializable): "Could not import gigachat python package. " "Please install it with `pip install gigachat`." ) + fields = set(cls.__fields__.keys()) + diff = set(values.keys()) - fields + if diff: + logger.warning(f"Extra fields {diff} in GigaChat class") + if "profanity" in fields and values.get("profanity") is False: + logger.warning( + "'profanity' field is deprecated. Use 'profanity_check' instead." 
+            )
+            if values.get("profanity_check") is None:
+                values["profanity_check"] = values.get("profanity")
         return values
 
     @property
@@ -113,11 +141,48 @@ class _BaseGigaChat(Serializable):
         return {
             "temperature": self.temperature,
             "model": self.model,
-            "profanity": self.profanity,
+            "profanity": self.profanity_check,
             "streaming": self.streaming,
             "max_tokens": self.max_tokens,
+            "top_p": self.top_p,
+            "repetition_penalty": self.repetition_penalty,
         }
 
+    def tokens_count(
+        self, input_: List[str], model: Optional[str] = None
+    ) -> List[gm.TokensCount]:
+        """Count tokens for a list of strings"""
+        return self._client.tokens_count(input_, model)
+
+    async def atokens_count(
+        self, input_: List[str], model: Optional[str] = None
+    ) -> List[gm.TokensCount]:
+        """Count tokens for a list of strings (async)"""
+        return await self._client.atokens_count(input_, model)
+
+    def get_models(self) -> gm.Models:
+        """Get the available GigaChat models"""
+        return self._client.get_models()
+
+    async def aget_models(self) -> gm.Models:
+        """Get the available GigaChat models (async)"""
+        return await self._client.aget_models()
+
+    def get_model(self, model: str) -> gm.Model:
+        """Get info about a model"""
+        return self._client.get_model(model)
+
+    async def aget_model(self, model: str) -> gm.Model:
+        """Get info about a model (async)"""
+        return await self._client.aget_model(model)
+
+    def get_num_tokens(self, text: str) -> int:
+        """Count the approximate number of tokens"""
+        if self.use_api_for_tokens:
+            return self.tokens_count([text])[0].tokens  # type: ignore
+        else:
+            # Rough heuristic: one token covers about 4.6 characters on average
+            return round(len(text) / 4.6)
+
 
 class GigaChat(_BaseGigaChat, BaseLLM):
     """`GigaChat` large language models API.
@@ -128,20 +193,29 @@ class GigaChat(_BaseGigaChat, BaseLLM):
     .. code-block:: python
 
         from langchain_community.llms import GigaChat
-        giga = GigaChat(credentials=..., verify_ssl_certs=False)
+        giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=False)
     """
 
+    payload_role: str = "user"
+
     def _build_payload(self, messages: List[str]) -> Dict[str, Any]:
         payload: Dict[str, Any] = {
-            "messages": [{"role": "user", "content": m} for m in messages],
-            "profanity_check": self.profanity,
+            "messages": [{"role": self.payload_role, "content": m} for m in messages],
         }
-        if self.temperature is not None:
-            payload["temperature"] = self.temperature
-        if self.max_tokens is not None:
-            payload["max_tokens"] = self.max_tokens
         if self.model:
             payload["model"] = self.model
+        if self.profanity_check is not None:
+            payload["profanity_check"] = self.profanity_check
+        if self.temperature is not None:
+            payload["temperature"] = self.temperature
+        if self.top_p is not None:
+            payload["top_p"] = self.top_p
+        if self.max_tokens is not None:
+            payload["max_tokens"] = self.max_tokens
+        if self.repetition_penalty is not None:
+            payload["repetition_penalty"] = self.repetition_penalty
+        if self.update_interval is not None:
+            payload["update_interval"] = self.update_interval
 
         if self.verbose:
             logger.info("Giga request: %s", payload)
@@ -164,6 +238,7 @@ class GigaChat(_BaseGigaChat, BaseLLM):
         )
         if self.verbose:
             logger.info("Giga response: %s", res.message.content)
+
         token_usage = response.usage
         llm_output = {"token_usage": token_usage, "model_name": response.model}
         return LLMResult(generations=generations, llm_output=llm_output)
@@ -254,6 +329,5 @@ class GigaChat(_BaseGigaChat, BaseLLM):
         if run_manager:
             await run_manager.on_llm_new_token(content)
 
-    def get_num_tokens(self, text: str) -> int:
-        """Count approximate number of tokens"""
-        return round(len(text) / 4.6)
+    class Config:
+        extra = "allow"
diff --git a/libs/community/tests/unit_tests/embeddings/test_imports.py b/libs/community/tests/unit_tests/embeddings/test_imports.py
index 5ca203d9344..3e3ed33d9f0 100644
--- a/libs/community/tests/unit_tests/embeddings/test_imports.py
+++ b/libs/community/tests/unit_tests/embeddings/test_imports.py
@@ -48,6 +48,7 @@ EXPECTED_ALL = [
     "SpacyEmbeddings",
     "NLPCloudEmbeddings",
     "GPT4AllEmbeddings",
+    "GigaChatEmbeddings",
     "XinferenceEmbeddings",
     "LocalAIEmbeddings",
     "AwaEmbeddings",