mirror of
https://github.com/hwchase17/langchain.git
synced 2025-09-12 12:59:07 +00:00
community[minor]: Added GigaChat Embeddings support + updated previous GigaChat integration (#19516)
- **Description:** Added integration with [GigaChat](https://developers.sber.ru/portal/products/gigachat) embeddings. Also added support for extra fields in GigaChat LLM and fixed docs.
This commit is contained in:
@@ -1,5 +1,17 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import Any, AsyncIterator, Iterator, List, Optional
|
||||
from typing import (
|
||||
TYPE_CHECKING,
|
||||
Any,
|
||||
AsyncIterator,
|
||||
Dict,
|
||||
Iterator,
|
||||
List,
|
||||
Mapping,
|
||||
Optional,
|
||||
Type,
|
||||
)
|
||||
|
||||
from langchain_core.callbacks import (
|
||||
AsyncCallbackManagerForLLMRun,
|
||||
@@ -14,31 +26,47 @@ from langchain_core.messages import (
|
||||
AIMessage,
|
||||
AIMessageChunk,
|
||||
BaseMessage,
|
||||
BaseMessageChunk,
|
||||
ChatMessage,
|
||||
ChatMessageChunk,
|
||||
FunctionMessage,
|
||||
FunctionMessageChunk,
|
||||
HumanMessage,
|
||||
HumanMessageChunk,
|
||||
SystemMessage,
|
||||
SystemMessageChunk,
|
||||
)
|
||||
from langchain_core.outputs import ChatGeneration, ChatGenerationChunk, ChatResult
|
||||
|
||||
from langchain_community.llms.gigachat import _BaseGigaChat
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import gigachat.models as gm
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _convert_dict_to_message(message: Any) -> BaseMessage:
|
||||
from gigachat.models import MessagesRole
|
||||
def _convert_dict_to_message(message: gm.Messages) -> BaseMessage:
|
||||
from gigachat.models import FunctionCall, MessagesRole
|
||||
|
||||
additional_kwargs: Dict = {}
|
||||
if function_call := message.function_call:
|
||||
if isinstance(function_call, FunctionCall):
|
||||
additional_kwargs["function_call"] = dict(function_call)
|
||||
elif isinstance(function_call, dict):
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
if message.role == MessagesRole.SYSTEM:
|
||||
return SystemMessage(content=message.content)
|
||||
elif message.role == MessagesRole.USER:
|
||||
return HumanMessage(content=message.content)
|
||||
elif message.role == MessagesRole.ASSISTANT:
|
||||
return AIMessage(content=message.content)
|
||||
return AIMessage(content=message.content, additional_kwargs=additional_kwargs)
|
||||
else:
|
||||
raise TypeError(f"Got unknown role {message.role} {message}")
|
||||
|
||||
|
||||
def _convert_message_to_dict(message: BaseMessage) -> Any:
|
||||
def _convert_message_to_dict(message: gm.BaseMessage) -> gm.Messages:
|
||||
from gigachat.models import Messages, MessagesRole
|
||||
|
||||
if isinstance(message, SystemMessage):
|
||||
@@ -46,13 +74,45 @@ def _convert_message_to_dict(message: BaseMessage) -> Any:
|
||||
elif isinstance(message, HumanMessage):
|
||||
return Messages(role=MessagesRole.USER, content=message.content)
|
||||
elif isinstance(message, AIMessage):
|
||||
return Messages(role=MessagesRole.ASSISTANT, content=message.content)
|
||||
return Messages(
|
||||
role=MessagesRole.ASSISTANT,
|
||||
content=message.content,
|
||||
function_call=message.additional_kwargs.get("function_call", None),
|
||||
)
|
||||
elif isinstance(message, ChatMessage):
|
||||
return Messages(role=MessagesRole(message.role), content=message.content)
|
||||
elif isinstance(message, FunctionMessage):
|
||||
return Messages(role=MessagesRole.FUNCTION, content=message.content)
|
||||
else:
|
||||
raise TypeError(f"Got unknown type {message}")
|
||||
|
||||
|
||||
def _convert_delta_to_message_chunk(
|
||||
_dict: Mapping[str, Any], default_class: Type[BaseMessageChunk]
|
||||
) -> BaseMessageChunk:
|
||||
role = _dict.get("role")
|
||||
content = _dict.get("content") or ""
|
||||
additional_kwargs: Dict = {}
|
||||
if _dict.get("function_call"):
|
||||
function_call = dict(_dict["function_call"])
|
||||
if "name" in function_call and function_call["name"] is None:
|
||||
function_call["name"] = ""
|
||||
additional_kwargs["function_call"] = function_call
|
||||
|
||||
if role == "user" or default_class == HumanMessageChunk:
|
||||
return HumanMessageChunk(content=content)
|
||||
elif role == "assistant" or default_class == AIMessageChunk:
|
||||
return AIMessageChunk(content=content, additional_kwargs=additional_kwargs)
|
||||
elif role == "system" or default_class == SystemMessageChunk:
|
||||
return SystemMessageChunk(content=content)
|
||||
elif role == "function" or default_class == FunctionMessageChunk:
|
||||
return FunctionMessageChunk(content=content, name=_dict["name"])
|
||||
elif role or default_class == ChatMessageChunk:
|
||||
return ChatMessageChunk(content=content, role=role)
|
||||
else:
|
||||
return default_class(content=content)
|
||||
|
||||
|
||||
class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
"""`GigaChat` large language models API.
|
||||
|
||||
@@ -62,23 +122,33 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
.. code-block:: python
|
||||
|
||||
from langchain_community.chat_models import GigaChat
|
||||
giga = GigaChat(credentials=..., verify_ssl_certs=False)
|
||||
giga = GigaChat(credentials=..., scope=..., verify_ssl_certs=False)
|
||||
"""
|
||||
|
||||
def _build_payload(self, messages: List[BaseMessage]) -> Any:
|
||||
def _build_payload(self, messages: List[BaseMessage], **kwargs: Any) -> gm.Chat:
|
||||
from gigachat.models import Chat
|
||||
|
||||
payload = Chat(
|
||||
messages=[_convert_message_to_dict(m) for m in messages],
|
||||
profanity_check=self.profanity,
|
||||
)
|
||||
|
||||
payload.functions = kwargs.get("functions", None)
|
||||
|
||||
if self.profanity_check is not None:
|
||||
payload.profanity_check = self.profanity_check
|
||||
if self.temperature is not None:
|
||||
payload.temperature = self.temperature
|
||||
if self.top_p is not None:
|
||||
payload.top_p = self.top_p
|
||||
if self.max_tokens is not None:
|
||||
payload.max_tokens = self.max_tokens
|
||||
if self.repetition_penalty is not None:
|
||||
payload.repetition_penalty = self.repetition_penalty
|
||||
if self.update_interval is not None:
|
||||
payload.update_interval = self.update_interval
|
||||
|
||||
if self.verbose:
|
||||
logger.info("Giga request: %s", payload.dict())
|
||||
logger.warning("Giga request: %s", payload.dict())
|
||||
|
||||
return payload
|
||||
|
||||
@@ -98,7 +168,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
finish_reason,
|
||||
)
|
||||
if self.verbose:
|
||||
logger.info("Giga response: %s", message.content)
|
||||
logger.warning("Giga response: %s", message.content)
|
||||
llm_output = {"token_usage": response.usage, "model_name": response.model}
|
||||
return ChatResult(generations=generations, llm_output=llm_output)
|
||||
|
||||
@@ -117,7 +187,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
)
|
||||
return generate_from_stream(stream_iter)
|
||||
|
||||
payload = self._build_payload(messages)
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
response = self._client.chat(payload)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
@@ -137,7 +207,7 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
)
|
||||
return await agenerate_from_stream(stream_iter)
|
||||
|
||||
payload = self._build_payload(messages)
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
response = await self._client.achat(payload)
|
||||
|
||||
return self._create_chat_result(response)
|
||||
@@ -149,15 +219,28 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> Iterator[ChatGenerationChunk]:
|
||||
payload = self._build_payload(messages)
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
|
||||
for chunk in self._client.stream(payload):
|
||||
if chunk.choices:
|
||||
content = chunk.choices[0].delta.content
|
||||
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
|
||||
choice = chunk["choices"][0]
|
||||
content = choice.get("delta", {}).get("content", {})
|
||||
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
|
||||
if run_manager:
|
||||
run_manager.on_llm_new_token(content)
|
||||
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
|
||||
async def _astream(
|
||||
self,
|
||||
@@ -166,16 +249,24 @@ class GigaChat(_BaseGigaChat, BaseChatModel):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterator[ChatGenerationChunk]:
|
||||
payload = self._build_payload(messages)
|
||||
payload = self._build_payload(messages, **kwargs)
|
||||
|
||||
async for chunk in self._client.astream(payload):
|
||||
if chunk.choices:
|
||||
content = chunk.choices[0].delta.content
|
||||
cg_chunk = ChatGenerationChunk(message=AIMessageChunk(content=content))
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(content, chunk=cg_chunk)
|
||||
yield cg_chunk
|
||||
if not isinstance(chunk, dict):
|
||||
chunk = chunk.dict()
|
||||
if len(chunk["choices"]) == 0:
|
||||
continue
|
||||
|
||||
def get_num_tokens(self, text: str) -> int:
|
||||
"""Count approximate number of tokens"""
|
||||
return round(len(text) / 4.6)
|
||||
choice = chunk["choices"][0]
|
||||
content = choice.get("delta", {}).get("content", {})
|
||||
chunk = _convert_delta_to_message_chunk(choice["delta"], AIMessageChunk)
|
||||
|
||||
finish_reason = choice.get("finish_reason")
|
||||
|
||||
generation_info = (
|
||||
dict(finish_reason=finish_reason) if finish_reason is not None else None
|
||||
)
|
||||
|
||||
yield ChatGenerationChunk(message=chunk, generation_info=generation_info)
|
||||
if run_manager:
|
||||
await run_manager.on_llm_new_token(content)
|
||||
|
Reference in New Issue
Block a user