Mirror of https://github.com/hwchase17/langchain.git, synced 2025-09-13 13:36:15 +00:00
community[feat]: Adds LLMLingua as a document compressor (#17711)
**Description**: This PR adds support for using the [LLMLingua project](https://github.com/microsoft/LLMLingua), in particular LongLLMLingua (Enhancing Large Language Model Inference via Prompt Compression), as a document compressor / transformer. LLMLingua is an interesting project that can greatly improve RAG systems by compressing prompts and contexts while keeping their semantic relevance.
**Issue**: https://github.com/microsoft/LLMLingua/issues/31
**Dependencies**: [llmlingua](https://pypi.org/project/llmlingua/)
@baskaryan
---------
Co-authored-by: Ayodeji Ayibiowu <ayodeji.ayibiowu@getinge.com>
Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
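For context, a minimal usage sketch (an editor's illustration, not part of the original PR; `retriever` stands in for any existing LangChain retriever, and the small `openai-community/gpt2` model is chosen only to keep the example light):

```python
from langchain.retrievers import ContextualCompressionRetriever
from langchain_community.document_compressors import LLMLinguaCompressor

# Assumes `retriever` is any already-configured LangChain retriever.
compressor = LLMLinguaCompressor(model_name="openai-community/gpt2", device_map="cpu")
compression_retriever = ContextualCompressionRetriever(
    base_compressor=compressor, base_retriever=retriever
)
compressed_docs = compression_retriever.get_relevant_documents(
    "What did the president say about the economy?"
)
```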
libs/community/langchain_community/document_compressors/__init__.py
@@ -0,0 +1,7 @@
from langchain_community.document_compressors.llmlingua_filter import (
    LLMLinguaCompressor,
)

__all__ = [
    "LLMLinguaCompressor",
]
libs/community/langchain_community/document_compressors/llmlingua_filter.py
@@ -0,0 +1,185 @@
# LLM Lingua Document Compressor

import re
from typing import Any, Dict, List, Optional, Sequence, Tuple

from langchain_core.callbacks import Callbacks
from langchain_core.documents import Document
from langchain_core.documents.compressor import (
    BaseDocumentCompressor,
)
from langchain_core.pydantic_v1 import root_validator

DEFAULT_LLM_LINGUA_INSTRUCTION = (
    "Given these documents, please answer the final question"
)

class LLMLinguaCompressor(BaseDocumentCompressor):
    """
    Compress using LLMLingua Project.

    https://github.com/microsoft/LLMLingua
    """

    # Pattern to match ref tags at the beginning or end of the string,
    # allowing for malformed tags
    _pattern_beginning = re.compile(r"\A(?:<#)?(?:ref)?(\d+)(?:#>?)?")
    _pattern_ending = re.compile(r"(?:<#)?(?:ref)?(\d+)(?:#>?)?\Z")
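    # Illustration (editor's note, not part of the original diff): both
    # patterns tolerate well-formed and malformed tags alike, e.g.
    #   _pattern_beginning.search("<#ref0#> text").group(1)      == "0"
    #   _pattern_ending.search("trailing text ref12").group(1)   == "12"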
    model_name: str = "NousResearch/Llama-2-7b-hf"
    """The Hugging Face model to use"""
    device_map: str = "cuda"
    """The device to use for LLMLingua"""
    target_token: int = 300
    """The target number of compressed tokens"""
    rank_method: str = "longllmlingua"
    """The ranking method to use"""
    model_config: dict = {}
    """Custom configuration for the model"""
    open_api_config: dict = {}
    """open_api configuration"""
    instruction: str = DEFAULT_LLM_LINGUA_INSTRUCTION
    """The instruction for the LLM"""
    additional_compress_kwargs: dict = {
        "condition_compare": True,
        "condition_in_question": "after",
        "context_budget": "+100",
        "reorder_context": "sort",
        "dynamic_context_compression_ratio": 0.4,
    }
    """Extra compression arguments passed through to `compress_prompt`"""
    lingua: Any
    """The instance of the LLMLingua PromptCompressor"""

    @root_validator(pre=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the python package exists in the environment."""
        try:
            from llmlingua import PromptCompressor
        except ImportError:
            raise ImportError(
                "Could not import llmlingua python package. "
                "Please install it with `pip install llmlingua`."
            )
        if not values.get("lingua"):
            # `values` holds only user-supplied kwargs here (pre=True), so
            # fall back to the field defaults rather than empty dicts.
            values["lingua"] = PromptCompressor(
                model_name=values.get("model_name", "NousResearch/Llama-2-7b-hf"),
                device_map=values.get("device_map", "cuda"),
                model_config=values.get("model_config", {}),
                open_api_config=values.get("open_api_config", {}),
            )
        return values
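
    # Note (editor's illustration): a pre-built PromptCompressor can also be
    # injected directly, which skips the model loading above:
    #   LLMLinguaCompressor(lingua=my_prompt_compressor)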

    class Config:
        """Configuration for this pydantic object."""

        extra = "forbid"
        arbitrary_types_allowed = True

    @staticmethod
    def _format_context(docs: Sequence[Document]) -> List[str]:
        """
        Format the output of the retriever by including
        special ref tags for tracking the metadata after compression
        """
        formatted_docs = []
        for i, doc in enumerate(docs):
            content = doc.page_content.replace("\n\n", "\n")
            doc_string = f"\n\n<#ref{i}#> {content} <#ref{i}#>\n\n"
            formatted_docs.append(doc_string)
        return formatted_docs
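
    # For example (editor's note, not in the original diff):
    #   _format_context([Document(page_content="Hello")])
    #   -> ["\n\n<#ref0#> Hello <#ref0#>\n\n"]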

    def extract_ref_id_tuples_and_clean(
        self, contents: List[str]
    ) -> List[Tuple[str, int]]:
        """
        Extracts reference IDs from the contents and cleans up the ref tags.

        This function processes a list of strings, searching for reference ID
        tags at the beginning and end of each string. When a ref tag is found,
        it is removed from the string, and its ID is recorded. If no ref ID is
        found, a generic ID of -1 is assigned.

        The search for ref tags is performed only at the beginning and end of
        the string, with the assumption that there will be at most one ref ID
        per string. Malformed ref tags are handled gracefully.

        Args:
            contents (List[str]): A list of contents to be processed.

        Returns:
            List[Tuple[str, int]]: The cleaned string and the associated ref ID.

        Examples:
            >>> strings_list = [
            ...     '<#ref0#> Example content <#ref0#>',
            ...     'Content with no ref ID.'
            ... ]
            >>> extract_ref_id_tuples_and_clean(strings_list)
            [('Example content', 0), ('Content with no ref ID.', -1)]
        """
        ref_id_tuples = []
        for content in contents:
            clean_string = content.strip()
            if not clean_string:
                continue

            # Search for ref tags at the beginning and the end of the string
            ref_id = None
            for pattern in [self._pattern_beginning, self._pattern_ending]:
                match = pattern.search(clean_string)
                if match:
                    ref_id = match.group(1)
                    clean_string = pattern.sub("", clean_string).strip()
            # Convert the ref ID to an int, or use -1 if it was not found
            ref_id_to_use = int(ref_id) if ref_id and ref_id.isdigit() else -1
            ref_id_tuples.append((clean_string, ref_id_to_use))

        return ref_id_tuples

    def compress_documents(
        self,
        documents: Sequence[Document],
        query: str,
        callbacks: Optional[Callbacks] = None,
    ) -> Sequence[Document]:
        """
        Compress documents using the LLMLingua prompt compressor.

        Args:
            documents: A sequence of documents to compress.
            query: The query to use for compressing the documents.
            callbacks: Callbacks to run during the compression process.

        Returns:
            A sequence of compressed documents.
        """
        if len(documents) == 0:  # to avoid an empty api call
            return []

        compressed_prompt = self.lingua.compress_prompt(
            context=self._format_context(documents),
            instruction=self.instruction,
            question=query,
            target_token=self.target_token,
            rank_method=self.rank_method,
            concate_question=False,
            add_instruction=True,
            **self.additional_compress_kwargs,
        )
        compressed_context = compressed_prompt["compressed_prompt"].split("\n\n")[1:]

        extracted_metadata = self.extract_ref_id_tuples_and_clean(compressed_context)

        compressed_docs: List[Document] = []

        for context, index in extracted_metadata:
            if index == -1 or index >= len(documents):
                doc = Document(page_content=context)
            else:
                doc = Document(page_content=context, metadata=documents[index].metadata)
            compressed_docs.append(doc)

        return compressed_docs
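For reference, a direct-call sketch (an editor's addition, not from the PR; the two documents and the query are made up for illustration, and the first run downloads the model):

```python
from langchain_core.documents import Document
from langchain_community.document_compressors import LLMLinguaCompressor

docs = [
    Document(page_content="The capital of France is Paris.", metadata={"id": 0}),
    Document(page_content="Paris hosted the 1900 Summer Olympics.", metadata={"id": 1}),
]
compressor = LLMLinguaCompressor(device_map="cpu", target_token=100)
compressed = compressor.compress_documents(docs, query="What is the capital of France?")
for doc in compressed:
    # Source metadata survives compression via the <#refN#> tags
    print(doc.metadata, doc.page_content)
```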