Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-07 22:11:51 +00:00
community[patch]: Native RAG Support in Prem AI langchain (#22238)

This PR adds native RAG support to the langchain premai package. The same has been added to the docs as well.
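For context, here is a minimal end-to-end sketch of what the new parameter enables, assuming a Prem project with a repository already connected; the project id, model name, repository id, and retrieval settings below are placeholders, not values from this PR.

from langchain_community.chat_models import ChatPremAI

chat = ChatPremAI(project_id=8, model="gpt-4o")  # placeholder project id and model

# The new `repositories` parameter: ids of the Prem repositories to retrieve
# from, plus retrieval knobs (treat the key names as indicative).
repositories = dict(ids=[1985], similarity_threshold=0.3, limit=3)

response = chat.invoke(
    "What is the diameter of an individual galaxy?",
    max_tokens=100,
    repositories=repositories,
)
print(response.content)

As the diff below shows, retrieved chunks are also attached to the resulting ChatResult under llm_output["document_chunks"].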
@@ -3,6 +3,7 @@
from __future__ import annotations

import logging
import warnings
from typing import (
    TYPE_CHECKING,
    Any,
@@ -104,7 +105,18 @@ def _response_to_result(
                text=content, message=ChatMessage(role=role, content=content)
            )
        )
    return ChatResult(generations=generations)

    if response.document_chunks is not None:
        return ChatResult(
            generations=generations,
            llm_output={
                "document_chunks": [
                    chunk.to_dict() for chunk in response.document_chunks
                ]
            },
        )
    else:
        return ChatResult(generations=generations, llm_output={"document_chunks": None})


def _convert_delta_response_to_message_chunk(
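To make the new return shape concrete, here is an illustrative ChatResult as _response_to_result would now build it when chunks were retrieved; the keys inside each chunk dict come from chunk.to_dict(), and the specific field names shown are assumptions, not values from this PR.

from langchain_core.messages import AIMessage
from langchain_core.outputs import ChatGeneration, ChatResult

result = ChatResult(
    generations=[ChatGeneration(message=AIMessage(content="Galaxies range from ..."))],
    llm_output={
        "document_chunks": [
            {
                "repository_id": 1985,      # assumed field names and values,
                "document_id": 1306,        # for illustration only
                "chunk_id": 173899,
                "document_name": "galaxies.pdf",
                "similarity_score": 0.47,
                "content": "Most galaxies are between ...",
            }
        ]
    },
)

When no repositories are connected, the same code path returns llm_output={"document_chunks": None}, so downstream code can always key on "document_chunks".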
@@ -118,10 +130,6 @@ def _convert_delta_response_to_message_chunk(
    role = _delta.get("role", "")  # type: ignore
    content = _delta.get("content", "")  # type: ignore
    additional_kwargs: Dict = {}

    if role is None or role == "":
        raise ChatPremAPIError("Role can not be None. Please check the response")

    finish_reasons: Optional[str] = response.choices[0].finish_reason

    if role == "user" or default_class == HumanMessageChunk:
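A toy illustration of the delta-to-chunk mapping handled above: a streamed delta whose role is "user" becomes a HumanMessageChunk, anything else an AIMessageChunk (the Prem response object is simplified to a plain dict here).

from langchain_core.messages import AIMessageChunk, HumanMessageChunk

delta = {"role": "assistant", "content": "Hel"}  # simplified stand-in for response.choices[0].delta
role = delta.get("role", "")
content = delta.get("content", "")

chunk = HumanMessageChunk(content=content) if role == "user" else AIMessageChunk(content=content)
print(type(chunk).__name__, repr(chunk.content))  # AIMessageChunk 'Hel'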
@@ -185,17 +193,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
    If the model name is other than the default model, it will override the calls
    from the model deployed from the launchpad."""

    session_id: Optional[str] = None
    """The ID of the session to use. It helps to track the chat history."""

    temperature: Optional[float] = None
    """Model temperature. Value should be >= 0 and <= 1.0"""

    top_p: Optional[float] = None
    """top_p adjusts the number of choices for each predicted token based on
    cumulative probabilities. Value should range between 0.0 and 1.0.
    """

    max_tokens: Optional[int] = None
    """The maximum number of tokens to generate"""

@@ -209,30 +209,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
    Changing the system prompt would override the default system prompt.
    """

    repositories: Optional[dict] = None
    """Add valid repository ids. This will override any existing connected
    repositories (if any) and use RAG with the connected repos.
    """

    streaming: Optional[bool] = False
    """Whether to stream the responses or not."""

    tools: Optional[Dict[str, Any]] = None
    """A list of tools the model may call. Currently, only functions are
    supported as a tool."""

    frequency_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based
    on their existing frequency in the text so far."""

    presence_penalty: Optional[float] = None
    """Number between -2.0 and 2.0. Positive values penalize new tokens based
    on whether they appear in the text so far."""

    logit_bias: Optional[dict] = None
    """JSON object that maps tokens to an associated bias value from -100 to 100."""

    stop: Optional[Union[str, List[str]]] = None
    """Up to 4 sequences where the API will stop generating further tokens."""

    seed: Optional[int] = None
    """This feature is in Beta. If specified, our system will make a best effort
    to sample deterministically."""

    client: Any

    class Config:
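Because repositories is a regular model field, RAG can also be pinned at construction time instead of per call; a hedged sketch with placeholder ids:

from langchain_community.chat_models import ChatPremAI

# Values set here flow through _default_params into every request, overriding
# whatever repositories are connected in the Prem launchpad.
chat = ChatPremAI(
    project_id=8,  # placeholder
    temperature=0.7,
    max_tokens=256,
    repositories={"ids": [1985], "similarity_threshold": 0.3, "limit": 3},
)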
@@ -268,21 +252,34 @@ class ChatPremAI(BaseChatModel, BaseModel):

    @property
    def _default_params(self) -> Dict[str, Any]:
        # FIXME: n and stop are not supported, so hardcoding to current default value
        return {
            "model": self.model,
            "system_prompt": self.system_prompt,
            "top_p": self.top_p,
            "temperature": self.temperature,
            "logit_bias": self.logit_bias,
            "max_tokens": self.max_tokens,
            "presence_penalty": self.presence_penalty,
            "frequency_penalty": self.frequency_penalty,
            "seed": self.seed,
            "stop": None,
            "repositories": self.repositories,
        }

    def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
        kwargs_to_ignore = [
            "top_p",
            "tools",
            "frequency_penalty",
            "presence_penalty",
            "logit_bias",
            "stop",
            "seed",
        ]
        keys_to_remove = []

        for key in kwargs:
            if key in kwargs_to_ignore:
                warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
                keys_to_remove.append(key)

        for key in keys_to_remove:
            kwargs.pop(key)

        all_kwargs = {**self._default_params, **kwargs}
        for key in list(self._default_params.keys()):
            if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
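The kwarg plumbing above is easiest to see in isolation. Here is a standalone mimic (not the actual method) of the filter-then-merge behaviour: unsupported keys are dropped with a warning, then defaults and surviving kwargs are merged and empty values pruned.

import warnings
from typing import Any, Dict

TOY_DEFAULTS = {"model": "gpt-4o", "stop": None, "repositories": None}  # toy stand-in for _default_params
UNSUPPORTED = {"top_p", "tools", "frequency_penalty", "presence_penalty",
               "logit_bias", "stop", "seed"}

def get_all_kwargs(**kwargs: Any) -> Dict[str, Any]:
    # Warn about and drop unsupported kwargs, mirroring kwargs_to_ignore above.
    for key in [k for k in kwargs if k in UNSUPPORTED]:
        warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
        kwargs.pop(key)
    all_kwargs = {**TOY_DEFAULTS, **kwargs}
    # Prune None/empty values, mirroring the loop at the end of the hunk.
    return {k: v for k, v in all_kwargs.items() if v is not None and v != ""}

print(get_all_kwargs(seed=42, repositories={"ids": [1985]}))
# warns about `seed`; prints {'model': 'gpt-4o', 'repositories': {'ids': [1985]}}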
@@ -298,7 +295,6 @@ class ChatPremAI(BaseChatModel, BaseModel):
    ) -> ChatResult:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore

        kwargs["stop"] = stop
        if system_prompt is not None and system_prompt != "":
            kwargs["system_prompt"] = system_prompt

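For reference, the system-prompt handling above splits a leading SystemMessage out of the conversation; conceptually (the exact payload format _messages_to_prompt_dict produces is not shown in this diff):

from langchain_core.messages import HumanMessage, SystemMessage

messages = [
    SystemMessage(content="You are a helpful assistant."),
    HumanMessage(content="Who are you?"),
]
# Inside _generate, roughly:
#   system_prompt    -> "You are a helpful assistant."  (forwarded as kwargs["system_prompt"])
#   messages_to_pass -> the remaining, non-system messages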
@@ -322,7 +318,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
        kwargs["stop"] = stop

        if stop is not None:
            logger.warning("stop is not supported in langchain streaming")

        if "system_prompt" not in kwargs:
            if system_prompt is not None and system_prompt != "":
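Finally, a short streaming sketch matching the hunk above; note that stop is ignored with a warning when streaming through this integration (the project id is a placeholder).

import sys
from langchain_community.chat_models import ChatPremAI

chat = ChatPremAI(project_id=8)  # placeholder project id

for chunk in chat.stream("hello how are you"):
    sys.stdout.write(chunk.content)
    sys.stdout.flush()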