community[patch]: Native RAG Support in Prem AI langchain (#22238)

This PR adds native RAG support to the langchain premai package. The same
has been added to the docs as well.
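
For context, here is a minimal usage sketch of what this change enables. The project id, API key, repository id, and the exact shape of the repositories dict are placeholder assumptions based on the Prem AI integration docs, not taken from this diff:

# Minimal sketch: connect document repositories so responses are grounded with RAG.
from langchain_community.chat_models import ChatPremAI

chat = ChatPremAI(project_id=8, premai_api_key="...")  # placeholder credentials

# Assumed dict shape: repository ids plus retrieval settings.
repositories = dict(ids=[1985], similarity_threshold=0.3, limit=3)

response = chat.invoke(
    "What do the connected documents say about pricing?",
    max_tokens=100,
    repositories=repositories,
)
print(response.content)

Per the _response_to_result change below, the retrieved chunks come back attached to the ChatResult as llm_output["document_chunks"].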
This commit is contained in:
Anindyadeep
2024-06-04 22:49:54 +05:30
committed by GitHub
parent 77ad857934
commit 7a197539aa
3 changed files with 174 additions and 46 deletions

@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import warnings
from typing import (
TYPE_CHECKING,
Any,
@@ -104,7 +105,18 @@ def _response_to_result(
text=content, message=ChatMessage(role=role, content=content)
)
)
return ChatResult(generations=generations)
if response.document_chunks is not None:
return ChatResult(
generations=generations,
llm_output={
"document_chunks": [
chunk.to_dict() for chunk in response.document_chunks
]
},
)
else:
return ChatResult(generations=generations, llm_output={"document_chunks": None})
def _convert_delta_response_to_message_chunk(
@@ -118,10 +130,6 @@ def _convert_delta_response_to_message_chunk(
role = _delta.get("role", "") # type: ignore
content = _delta.get("content", "") # type: ignore
additional_kwargs: Dict = {}
if role is None or role == "":
raise ChatPremAPIError("Role can not be None. Please check the response")
finish_reasons: Optional[str] = response.choices[0].finish_reason
if role == "user" or default_class == HumanMessageChunk:
@@ -185,17 +193,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
If the model name differs from the default model, it will override calls
to the model deployed from the launchpad."""
session_id: Optional[str] = None
"""The ID of the session to use. It helps to track the chat history."""
temperature: Optional[float] = None
"""Model temperature. Value should be >= 0 and <= 1.0"""
top_p: Optional[float] = None
"""top_p adjusts the number of choices for each predicted tokens based on
cumulative probabilities. Value should be ranging between 0.0 and 1.0.
"""
max_tokens: Optional[int] = None
"""The maximum number of tokens to generate"""
@@ -209,30 +209,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
Changing the system prompt would override the default system prompt.
"""
repositories: Optional[dict] = None
"""Add valid repository ids. This will be overriding existing connected
repositories (if any) and will use RAG with the connected repos.
"""
streaming: Optional[bool] = False
"""Whether to stream the responses or not."""
tools: Optional[Dict[str, Any]] = None
"""A list of tools the model may call. Currently, only functions are
supported as a tool."""
frequency_penalty: Optional[float] = None
"""Number between -2.0 and 2.0. Positive values penalize new tokens based"""
presence_penalty: Optional[float] = None
"""Number between -2.0 and 2.0. Positive values penalize new tokens based
on whether they appear in the text so far."""
logit_bias: Optional[dict] = None
"""JSON object that maps tokens to an associated bias value from -100 to 100."""
stop: Optional[Union[str, List[str]]] = None
"""Up to 4 sequences where the API will stop generating further tokens."""
seed: Optional[int] = None
"""This feature is in Beta. If specified, our system will make a best effort
to sample deterministically."""
client: Any
class Config:
@@ -268,21 +252,34 @@ class ChatPremAI(BaseChatModel, BaseModel):
@property
def _default_params(self) -> Dict[str, Any]:
# FIXME: n and stop are not supported, so hardcoding to the current default values
return {
"model": self.model,
"system_prompt": self.system_prompt,
"top_p": self.top_p,
"temperature": self.temperature,
"logit_bias": self.logit_bias,
"max_tokens": self.max_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"seed": self.seed,
"stop": None,
"repositories": self.repositories,
}
def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
kwargs_to_ignore = [
"top_p",
"tools",
"frequency_penalty",
"presence_penalty",
"logit_bias",
"stop",
"seed",
]
keys_to_remove = []
for key in kwargs:
if key in kwargs_to_ignore:
warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
keys_to_remove.append(key)
for key in keys_to_remove:
kwargs.pop(key)
all_kwargs = {**self._default_params, **kwargs}
for key in list(self._default_params.keys()):
if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
@@ -298,7 +295,6 @@ class ChatPremAI(BaseChatModel, BaseModel):
) -> ChatResult:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages) # type: ignore
kwargs["stop"] = stop
if system_prompt is not None and system_prompt != "":
kwargs["system_prompt"] = system_prompt
@@ -322,7 +318,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
**kwargs: Any,
) -> Iterator[ChatGenerationChunk]:
system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
kwargs["stop"] = stop
if stop is not None:
logger.warning("stop is not supported in langchain streaming")
if "system_prompt" not in kwargs:
if system_prompt is not None and system_prompt != "":
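
Streaming continues to work through the standard LangChain interface; as the warning above notes, any stop argument is ignored on the streaming path. A minimal sketch, reusing the hypothetical chat instance from the example at the top:

# Streaming sketch: chunks arrive incrementally; `stop` is not honored here.
for chunk in chat.stream("Summarize the connected documents."):
    print(chunk.content, end="", flush=True)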