community[patch]: Native RAG Support in Prem AI langchain (#22238)

This PR adds native RAG support to the langchain premai package. The docs
have been updated accordingly.
Anindyadeep 2024-06-04 22:49:54 +05:30 committed by GitHub
parent 77ad857934
commit 7a197539aa
3 changed files with 174 additions and 46 deletions


@@ -179,10 +179,69 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "> If you are going to place system prompt here, then it will override your system prompt that was fixed while deploying the application from the platform. \n",
+    "> If you are going to place system prompt here, then it will override your system prompt that was fixed while deploying the application from the platform. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Native RAG Support with Prem Repositories\n",
     "\n",
-    "> Please note that the current version of ChatPremAI does not support parameters: [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop). \n",
+    "Prem Repositories allow users to upload documents (.txt, .pdf, etc.) and connect those repositories to the LLMs. You can think of Prem repositories as native RAG, where each repository can be considered a vector database. You can connect multiple repositories. You can learn more about repositories [here](https://docs.premai.io/get-started/repositories).\n",
     "\n",
+    "Repositories are also supported in langchain premai. Here is how you can do it. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "query = \"what is the diameter of individual Galaxy\"\n",
+    "repository_ids = [\n",
+    "    1991,\n",
+    "]\n",
+    "repositories = dict(ids=repository_ids, similarity_threshold=0.3, limit=3)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "First, we start by defining our repository with some repository ids. Make sure that the ids are valid repository ids. You can learn more about how to get the repository id [here](https://docs.premai.io/get-started/repositories). \n",
+    "\n",
+    "> Please note: Similar to `model_name`, when you invoke the argument `repositories` you are potentially overriding the repositories connected in the launchpad. \n",
+    "\n",
+    "Now, we connect the repository with our chat object to invoke RAG-based generations. "
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "\n",
+    "response = chat.invoke(query, max_tokens=100, repositories=repositories)\n",
+    "\n",
+    "print(response.content)\n",
+    "print(json.dumps(response.response_metadata, indent=4))"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "> Ideally, you do not need to connect repository IDs here to get Retrieval Augmented Generations. You can still get the same result if you have connected the repositories in the Prem platform. "
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
     "### Streaming\n",
     "\n",
     "In this section, let's see how we can stream tokens using langchain and PremAI. Here's how you do it. "


@@ -73,7 +73,76 @@ chat.invoke(
 > If you are going to place system prompt here, then it will override your system prompt that was fixed while deploying the application from the platform.
-> Please note that the current version of ChatPremAI does not support parameters: [n](https://platform.openai.com/docs/api-reference/chat/create#chat-create-n) and [stop](https://platform.openai.com/docs/api-reference/chat/create#chat-create-stop).
+> You can find all the optional parameters [here](https://docs.premai.io/get-started/sdk#optional-parameters). Any parameters other than [these supported parameters](https://docs.premai.io/get-started/sdk#optional-parameters) will be automatically removed before calling the model.
+### Native RAG Support with Prem Repositories
+Prem Repositories allow users to upload documents (.txt, .pdf, etc.) and connect those repositories to the LLMs. You can think of Prem repositories as native RAG, where each repository can be considered a vector database. You can connect multiple repositories. You can learn more about repositories [here](https://docs.premai.io/get-started/repositories).
+Repositories are also supported in langchain premai. Here is how you can do it.
+```python
+query = "what is the diameter of individual Galaxy"
+repository_ids = [1991]
+repositories = dict(
+    ids=repository_ids,
+    similarity_threshold=0.3,
+    limit=3,
+)
+```
+First, we start by defining our repository with some repository ids. Make sure that the ids are valid repository ids. You can learn more about how to get the repository id [here](https://docs.premai.io/get-started/repositories).
+> Please note: Similar to `model_name`, when you invoke the argument `repositories` you are potentially overriding the repositories connected in the launchpad.
+Now, we connect the repository with our chat object to invoke RAG-based generations.
+```python
+import json
+
+response = chat.invoke(query, max_tokens=100, repositories=repositories)
+print(response.content)
+print(json.dumps(response.response_metadata, indent=4))
+```
+Here is what the output looks like:
+```bash
+The diameters of individual galaxies range from 80,000-150,000 light-years.
+{
+    "document_chunks": [
+        {
+            "repository_id": 1991,
+            "document_id": 1307,
+            "chunk_id": 173926,
+            "document_name": "Kegy 202 Chapter 2",
+            "similarity_score": 0.586126983165741,
+            "content": "n thousands\n of light-years. The diameters of individual\n galaxies range from 80,000-150,000 light\n "
+        },
+        {
+            "repository_id": 1991,
+            "document_id": 1307,
+            "chunk_id": 173925,
+            "document_name": "Kegy 202 Chapter 2",
+            "similarity_score": 0.4815782308578491,
+            "content": " for development of galaxies. A galaxy contains\n a large number of stars. Galaxies spread over\n vast distances that are measured in thousands\n "
+        },
+        {
+            "repository_id": 1991,
+            "document_id": 1307,
+            "chunk_id": 173916,
+            "document_name": "Kegy 202 Chapter 2",
+            "similarity_score": 0.38112708926200867,
+            "content": " was separated from the from each other as the balloon expands.\n solar surface. As the passing star moved away, Similarly, the distance between the galaxies is\n the material separated from the solar surface\n continued to revolve around the sun and it\n slowly condensed into planets. Sir James Jeans\n and later Sir Harold Jeffrey supported thisnot to be republishedalso found to be increasing and thereby, the\n universe is"
+        }
+    ]
+}
+```
+So, this also means that you do not need to build your own RAG pipeline when using the Prem platform. Prem uses its own RAG technology to deliver best-in-class performance for Retrieval Augmented Generation.
+> Ideally, you do not need to connect repository IDs here to get Retrieval Augmented Generations. You can still get the same result if you have connected the repositories in the Prem platform.
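For downstream use, the retrieved chunks surfaced in `response_metadata` can be mapped onto LangChain `Document` objects. A minimal sketch, assuming the `chat` object, `query`, and `repositories` dict from the snippets above (the metadata keys follow the sample output shown earlier):

```python
from langchain_core.documents import Document

response = chat.invoke(query, max_tokens=100, repositories=repositories)

# Each retrieved chunk carries its source repository, document id and
# similarity score alongside the raw text content.
chunks = response.response_metadata.get("document_chunks") or []
docs = [
    Document(
        page_content=chunk["content"],
        metadata={
            "repository_id": chunk["repository_id"],
            "document_id": chunk["document_id"],
            "similarity_score": chunk["similarity_score"],
        },
    )
    for chunk in chunks
]
print(f"Retrieved {len(docs)} chunks")
```

This keeps Prem's retrieval metadata attached to each chunk, which is handy if you want to post-process the results or cite sources.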
 ### Streaming
@@ -102,6 +171,8 @@ for chunk in chat.stream(
 This will stream tokens one after the other.
+> Please note: As of now, RAG with streaming is not supported. However, we still support it with our API. You can learn more about that [here](https://docs.premai.io/get-started/chat-completion-sse).
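For contrast, a short sketch of streaming without repositories, which works as usual (the prompt and parameter values are illustrative and reuse the `chat` object from above):

```python
# Tokens stream as usual as long as no repositories argument is passed.
for chunk in chat.stream("What is a galaxy?", max_tokens=100):
    print(chunk.content, end="", flush=True)
```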
 ## PremEmbeddings
 In this section we are going to discuss how we can get access to different embedding models using `PremEmbeddings` with LangChain. Let's start by importing our modules and setting our API Key.


@@ -3,6 +3,7 @@
 from __future__ import annotations
 import logging
+import warnings
 from typing import (
     TYPE_CHECKING,
     Any,
@@ -104,7 +105,18 @@ def _response_to_result(
                 text=content, message=ChatMessage(role=role, content=content)
             )
         )
-    return ChatResult(generations=generations)
+    if response.document_chunks is not None:
+        return ChatResult(
+            generations=generations,
+            llm_output={
+                "document_chunks": [
+                    chunk.to_dict() for chunk in response.document_chunks
+                ]
+            },
+        )
+    else:
+        return ChatResult(generations=generations, llm_output={"document_chunks": None})
 def _convert_delta_response_to_message_chunk(
@@ -118,10 +130,6 @@ def _convert_delta_response_to_message_chunk(
     role = _delta.get("role", "")  # type: ignore
     content = _delta.get("content", "")  # type: ignore
     additional_kwargs: Dict = {}
-    if role is None or role == "":
-        raise ChatPremAPIError("Role can not be None. Please check the response")
     finish_reasons: Optional[str] = response.choices[0].finish_reason
     if role == "user" or default_class == HumanMessageChunk:
@@ -185,17 +193,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
     If model name is other than default model then it will override the calls
     from the model deployed from launchpad."""
-    session_id: Optional[str] = None
-    """The ID of the session to use. It helps to track the chat history."""
     temperature: Optional[float] = None
     """Model temperature. Value should be >= 0 and <= 1.0"""
-    top_p: Optional[float] = None
-    """top_p adjusts the number of choices for each predicted tokens based on
-    cumulative probabilities. Value should be ranging between 0.0 and 1.0.
-    """
     max_tokens: Optional[int] = None
     """The maximum number of tokens to generate"""
@@ -209,30 +209,14 @@ class ChatPremAI(BaseChatModel, BaseModel):
     Changing the system prompt would override the default system prompt.
     """
+    repositories: Optional[dict] = None
+    """Add valid repository ids. This will be overriding existing connected
+    repositories (if any) and will use RAG with the connected repos.
+    """
     streaming: Optional[bool] = False
     """Whether to stream the responses or not."""
-    tools: Optional[Dict[str, Any]] = None
-    """A list of tools the model may call. Currently, only functions are
-    supported as a tool"""
-    frequency_penalty: Optional[float] = None
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based"""
-    presence_penalty: Optional[float] = None
-    """Number between -2.0 and 2.0. Positive values penalize new tokens based
-    on whether they appear in the text so far."""
-    logit_bias: Optional[dict] = None
-    """JSON object that maps tokens to an associated bias value from -100 to 100."""
-    stop: Optional[Union[str, List[str]]] = None
-    """Up to 4 sequences where the API will stop generating further tokens."""
-    seed: Optional[int] = None
-    """This feature is in Beta. If specified, our system will make a best effort
-    to sample deterministically."""
     client: Any
     class Config:
@@ -268,21 +252,34 @@ class ChatPremAI(BaseChatModel, BaseModel):
     @property
     def _default_params(self) -> Dict[str, Any]:
-        # FIXME: n and stop is not supported, so hardcoding to current default value
         return {
             "model": self.model,
             "system_prompt": self.system_prompt,
-            "top_p": self.top_p,
             "temperature": self.temperature,
-            "logit_bias": self.logit_bias,
             "max_tokens": self.max_tokens,
-            "presence_penalty": self.presence_penalty,
+            "repositories": self.repositories,
-            "frequency_penalty": self.frequency_penalty,
-            "seed": self.seed,
-            "stop": None,
         }
     def _get_all_kwargs(self, **kwargs: Any) -> Dict[str, Any]:
+        kwargs_to_ignore = [
+            "top_p",
+            "tools",
+            "frequency_penalty",
+            "presence_penalty",
+            "logit_bias",
+            "stop",
+            "seed",
+        ]
+        keys_to_remove = []
+        for key in kwargs:
+            if key in kwargs_to_ignore:
+                warnings.warn(f"WARNING: Parameter {key} is not supported in kwargs.")
+                keys_to_remove.append(key)
+        for key in keys_to_remove:
+            kwargs.pop(key)
         all_kwargs = {**self._default_params, **kwargs}
         for key in list(self._default_params.keys()):
             if all_kwargs.get(key) is None or all_kwargs.get(key) == "":
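As a side note, here is a hedged sketch of how this filtering behaves from the caller's side. It assumes a configured `ChatPremAI` instance with a valid `PREMAI_API_KEY` environment variable; the `project_id` below is a placeholder:

```python
import warnings

from langchain_community.chat_models.premai import ChatPremAI

chat = ChatPremAI(project_id=8)  # placeholder project id

with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    # top_p is in kwargs_to_ignore, so it is warned about and dropped
    # before the request reaches the Prem API.
    response = chat.invoke("Hello", top_p=0.9, max_tokens=50)

print([str(w.message) for w in caught])
print(response.content)
```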
@@ -298,7 +295,6 @@ class ChatPremAI(BaseChatModel, BaseModel):
     ) -> ChatResult:
         system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)  # type: ignore
-        kwargs["stop"] = stop
         if system_prompt is not None and system_prompt != "":
             kwargs["system_prompt"] = system_prompt
@@ -322,7 +318,9 @@ class ChatPremAI(BaseChatModel, BaseModel):
         **kwargs: Any,
     ) -> Iterator[ChatGenerationChunk]:
         system_prompt, messages_to_pass = _messages_to_prompt_dict(messages)
-        kwargs["stop"] = stop
+        if stop is not None:
+            logger.warning("stop is not supported in langchain streaming")
         if "system_prompt" not in kwargs:
             if system_prompt is not None and system_prompt != "":