mirror of https://github.com/imartinez/privateGPT.git

Updated API for OCR component

commit 175b4e29ac (parent fbd298212f)

poetry.lock (generated, 5136 changed lines): file diff suppressed because it is too large
@@ -11,8 +11,7 @@ from typing import Any

 from llama_index.core.data_structs import IndexDict
 from llama_index.core.embeddings.utils import EmbedType
-from llama_index.core.indices import VectorStoreIndex, load_index_from_storage, SimpleKeywordTableIndex
-from private_gpt.utils.vector_store import VectorStoreIndex1
+from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
 from llama_index.core.indices.base import BaseIndex
 from llama_index.core.ingestion import run_transformations
 from llama_index.core.schema import BaseNode, Document, TransformComponent
@@ -84,7 +83,7 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
         except ValueError:
             # There is no index in the storage context; creating a new one
             logger.info("Creating a new vector store index")
-            index = VectorStoreIndex1.from_documents(
+            index = VectorStoreIndex.from_documents(
                 [],
                 storage_context=self.storage_context,
                 store_nodes_override=True,  # Force store nodes in index and document stores
@@ -93,17 +92,6 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
                 transformations=self.transformations,
             )
             index.storage_context.persist(persist_dir=local_data_path)
-
-            keyword_index = SimpleKeywordTableIndex.from_documents(
-                [],
-                storage_context=self.storage_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=self.show_progress,
-                transformations=self.transformations,
-                llm=
-            )
-            # Store the keyword index in the vector store
-            index.keyword_index = keyword_index
         return index

     def _save_index(self) -> None:
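For reference, the load-or-create flow this hunk returns to can be sketched as follows. This is a minimal illustration, not the repo's exact code; load_or_create_index and its arguments are hypothetical names:

    from llama_index.core.indices import VectorStoreIndex, load_index_from_storage

    def load_or_create_index(storage_context, transformations, show_progress=True):
        try:
            # Reuse the index already persisted in the storage context.
            return load_index_from_storage(storage_context=storage_context)
        except ValueError:
            # No index yet: build an empty VectorStoreIndex and persist it.
            index = VectorStoreIndex.from_documents(
                [],
                storage_context=storage_context,
                store_nodes_override=True,  # force nodes into index/document stores
                show_progress=show_progress,
                transformations=transformations,
            )
            index.storage_context.persist(persist_dir="local_data")  # illustrative path
            return index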
@@ -1,57 +1,58 @@
 # from paddleocr import PaddleOCR
 import io
 from typing import Union

 import cv2
-import torch
-from doctr.models import ocr_predictor
 from doctr.io import DocumentFile
-from injector import singleton, inject
-
-device = "cuda" if torch.cuda.is_available() else "cpu"
+from doctr.models import ocr_predictor
+from injector import inject, singleton
+from pdf2image import convert_from_bytes

+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cpu"


 @singleton
 class GetOCRText:
     @inject
     def __init__(self) -> None:
         self._image = None
         # self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
         self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True).to(device)

     def _preprocess_image(self, img):
-        resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)
+        resized_image = cv2.resize(img, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC)
         gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
-        _, binary = cv2.threshold(gray_image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+        _, binary = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
         return binary

     ## paddleOCR
     # def extract_text(self, cell_image):
     #     text = ""
     #     self._image = cell_image
     #     preprocessd_image = self._preprocess_image(self._image)
     #     results = self.ocr.ocr(preprocessd_image, cls=True)
     #     print(results)
     #     if len(results) > 0:
     #         for result in results[0]:
     #             text += f"{result[-1][0]} "
     #     else:
     #         text = ""
     #     return text

     ## docTR OCR
-    def extract_text(self, cell_image=None, image_file=False, file_path=None):
+    def extract_text(self, cell_image: Union[None, bytes] = None, image_file: bool = False, file_path: Union[None, str] = None):
         text = ""

         if image_file:
+            if file_path is None:
+                raise ValueError("file_path must be provided when image_file is True.")
             pdf_file = DocumentFile.from_images(file_path)
             result = self.doctr(pdf_file)
             output = result.export()
         else:
-            self._image = cell_image
-            preprocessd_image = self._preprocess_image(self._image)
-            result = self.doctr([self._image])
-            output = result.export()
+            if cell_image is None:
+                raise ValueError("cell_image must be provided when image_file is False.")
+
+            if isinstance(cell_image, bytes):
+                images = convert_from_bytes(cell_image)
+                pdf_file = DocumentFile.from_images(images)
+                result = self.doctr(pdf_file)
+            else:
+                self._image = cell_image
+                preprocessed_image = self._preprocess_image(self._image)
+                result = self.doctr([preprocessed_image])
+            output = result.export()

         for obj1 in output['pages'][0]["blocks"]:
             for obj2 in obj1["lines"]:
                 for obj3 in obj2["words"]:
                     text += (f"{obj3['value']} ").replace("\n", "")
-                text = text + "\n"
+                text += "\n"
             text += "\n"
-        if text:
-            return text
-        return " "
+        return text.strip()
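Two notes on the new extract_text. First, with cv2.THRESH_OTSU the threshold argument is ignored and computed automatically from the image histogram, so replacing 128 with 0 is the conventional spelling and does not change behaviour. Second, a minimal usage sketch of the updated API (file names are illustrative):

    ocr = GetOCRText()

    # OCR a single image on disk (e.g. a page rendered to PNG).
    text = ocr.extract_text(image_file=True, file_path="page_0.png")

    # OCR a PDF supplied as raw bytes; pdf2image rasterizes the pages first.
    with open("scan.pdf", "rb") as f:
        text = ocr.extract_text(cell_image=f.read())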
@@ -34,10 +34,9 @@ async def save_uploaded_file(file: UploadFile, upload_dir: str):


 async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
-    doc = Document()
     ocr = request.state.injector.get(GetOCRText)
-    img_tab = request.state.injector.get(ImageToTable)
+    # img_tab = request.state.injector.get(ImageToTable)
+    pdf_writer = fitz.open()
     pdf_doc = fitz.open(pdf_path)

     for page_index in range(len(pdf_doc)):

@@ -55,18 +54,19 @@ async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
                     pix = fitz.Pixmap(fitz.csRGB, pix)
-                img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)

-                image_path = f"page_{page_index}-image_{image_index}.png"
+                image_path = f"temp_page_{page_index}_image_{image_index}.png"
                 pix.save(image_path)

                 extracted_text = ocr.extract_text(
                     image_file=True, file_path=image_path)
-                doc.add_paragraph(extracted_text)
-                table_data = img_tab.table_to_csv(image_path)
-                doc.add_paragraph(table_data)
+                # Create a new page with the same dimensions as the original page
+                pdf_page = pdf_writer.new_page(width=page.rect.width, height=page.rect.height)
+                pdf_page.insert_text((10, 10), extracted_text, fontsize=9)
                 os.remove(image_path)

-    save_path = os.path.join(
-        upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}_ocr.docx")
-    doc.save(save_path)
+    save_path = os.path.join(upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}.pdf")
+    pdf_writer.save(save_path)
+    pdf_writer.close()
     return save_path
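The rewritten helper follows the common PyMuPDF flow: pull each embedded image out of the source PDF, OCR it, and write the recognized text onto a page of a new PDF. A condensed sketch of the same flow, assuming an ocr object exposing the extract_text API above (the function name and arguments here are illustrative):

    import os
    import fitz  # PyMuPDF

    def ocr_images_to_pdf(pdf_path: str, out_path: str, ocr) -> str:
        pdf_writer = fitz.open()        # new, empty output document
        pdf_doc = fitz.open(pdf_path)
        for page_index, page in enumerate(pdf_doc):
            for image_index, img in enumerate(page.get_images(full=True)):
                pix = fitz.Pixmap(pdf_doc, img[0])   # img[0] is the image xref
                if pix.n - pix.alpha > 3:            # CMYK and similar: convert to RGB
                    pix = fitz.Pixmap(fitz.csRGB, pix)
                image_path = f"temp_page_{page_index}_image_{image_index}.png"
                pix.save(image_path)
                text = ocr.extract_text(image_file=True, file_path=image_path)
                out_page = pdf_writer.new_page(width=page.rect.width, height=page.rect.height)
                out_page.insert_text((10, 10), text, fontsize=9)
                os.remove(image_path)
        pdf_writer.save(out_path)
        pdf_writer.close()
        return out_path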
@@ -82,6 +82,7 @@ async def process_pdf_ocr(
     try:
         pdf_path = await save_uploaded_file(file, UPLOAD_DIR)
         ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
+        print("FILE PATH:", ocr_doc_path)
         ingested_documents = await common_ingest_logic(
             request=request, db=db, ocr_file=ocr_doc_path, current_user=current_user, original_file=None, log_audit=log_audit, departments=departments
         )
@@ -111,26 +112,6 @@ async def process_ocr(
             detail=f"There was an error processing OCR: {e}"
         )

-
-async def process_both_ocr(
-    request: Request,
-    pdf_path: str
-):
-    UPLOAD_DIR = OCR_UPLOAD
-    try:
-        ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
-        ingested_ocr_documents = await ingest(request=request, file_path=ocr_doc_path)  # ingest ocr
-        ingested_documents = await ingest(request=request, file_path=pdf_path)  # ingest pdf
-        return ingested_documents
-
-    except Exception as e:
-        print(traceback.print_exc())
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"There was an error processing OCR: {e}"
-        )
-
-

 async def process_both(
     request: Request,
     db: Session,
@@ -170,18 +151,3 @@ async def get_pdf_ocr_wrapper(
     )
 ):
     return await process_pdf_ocr(request, db, file, current_user, log_audit, departments)
-
-
-@pdf_router.post("/both")
-async def get_both_wrapper(
-    request: Request,
-    departments: schemas.DocumentDepartmentList = Depends(),
-    db: Session = Depends(deps.get_db),
-    log_audit: models.Audit = Depends(deps.get_audit_logger),
-    file: UploadFile = File(...),
-    current_user: models.User = Security(
-        deps.get_current_user,
-        scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
-    )
-):
-    return await process_both(request, db, file, current_user, log_audit, departments)
@@ -1,55 +0,0 @@
-# import QueryBundle
-from llama_index.core import QueryBundle
-
-# import NodeWithScore
-from llama_index.core.schema import NodeWithScore
-
-# Retrievers
-from llama_index.core.retrievers import (
-    BaseRetriever,
-    VectorIndexRetriever,
-    KeywordTableSimpleRetriever,
-)
-
-from typing import List
-
-
-class CustomRetriever(BaseRetriever):
-    """Custom retriever that performs both semantic search and hybrid search."""
-
-    def __init__(
-        self,
-        vector_retriever: VectorIndexRetriever,
-        keyword_retriever: KeywordTableSimpleRetriever,
-        mode: str = "AND",
-    ) -> None:
-        """Init params."""
-
-        self._vector_retriever = vector_retriever
-        self._keyword_retriever = keyword_retriever
-        if mode not in ("AND", "OR"):
-            raise ValueError("Invalid mode.")
-        self._mode = mode
-        super().__init__()
-
-    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
-        """Retrieve nodes given query."""
-
-        vector_nodes = self._vector_retriever.retrieve(query_bundle)
-        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
-
-        vector_ids = {n.node.node_id for n in vector_nodes}
-        keyword_ids = {n.node.node_id for n in keyword_nodes}
-
-        combined_dict = {n.node.node_id: n for n in vector_nodes}
-        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
-
-        if self._mode == "AND":
-            retrieve_ids = vector_ids.intersection(keyword_ids)
-        else:
-            retrieve_ids = vector_ids.union(keyword_ids)
-
-        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
-        return retrieve_nodes
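The deleted CustomRetriever is the hybrid-retriever recipe from the llama_index examples: run a vector retriever and a keyword retriever over the same query, then intersect ("AND") or union ("OR") the returned node IDs. Wiring it up looked roughly like this; a sketch, where vector_index and keyword_index are illustrative names:

    from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever

    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=2)
    keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)
    hybrid = CustomRetriever(vector_retriever, keyword_retriever, mode="OR")
    nodes = hybrid.retrieve("What does the policy say about leave encashment?")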
@@ -3,14 +3,12 @@ import typing

 from injector import inject, singleton
 from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
-from llama_index.core.indices import SimpleKeywordTableIndex
 from llama_index.core.vector_stores.types import (
     FilterCondition,
     MetadataFilter,
     MetadataFilters,
     VectorStore,
 )
-from private_gpt.utils.vector_store import VectorStoreIndex1

 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path

@@ -132,21 +130,12 @@ class VectorStoreComponent:

     def get_retriever(
         self,
-        index: VectorStoreIndex1,
+        index: VectorStoreIndex,
         context_filter: ContextFilter | None = None,
         similarity_top_k: int = 2,
     ) -> VectorIndexRetriever:
         # This way we support qdrant (using doc_ids) and the rest (using filters)
-
-        # from llama_index.core import get_response_synthesizer
-        # from llama_index.core.query_engine import RetrieverQueryEngine
-        from .retriever import CustomRetriever
-        from llama_index.core.retrievers import (
-            VectorIndexRetriever,
-            KeywordTableSimpleRetriever,
-        )
-
-        vector_retriever = VectorIndexRetriever(
+        return VectorIndexRetriever(
             index=index,
             similarity_top_k=similarity_top_k,
             doc_ids=context_filter.docs_ids if context_filter else None,

@@ -156,19 +145,6 @@ class VectorStoreComponent:
                 else None
             ),
         )
-        keyword_retriever = KeywordTableSimpleRetriever(index=index.keyword_index)
-        custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)
-
-        # define response synthesizer
-        # response_synthesizer = get_response_synthesizer()
-
-        # # assemble query engine
-        # custom_query_engine = RetrieverQueryEngine(
-        #     retriever=custom_retriever,
-        #     response_synthesizer=response_synthesizer,
-        # )
-
-        return custom_retriever

     def close(self) -> None:
         if hasattr(self.vector_store.client, "close"):
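After this change get_retriever returns a plain VectorIndexRetriever again, scoped by an optional ContextFilter (doc_ids for qdrant, metadata filters for other stores). A usage sketch, assuming the component and index are already constructed and using an illustrative document ID:

    retriever = vector_store_component.get_retriever(
        index=index,
        context_filter=ContextFilter(docs_ids=["doc-123"]),  # illustrative ID
        similarity_top_k=2,
    )
    nodes = retriever.retrieve("termination clause")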
@@ -5,7 +5,7 @@ from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.core.chat_engine.types import (
     BaseChatEngine,
 )
-from llama_index.core.indices import VectorStoreIndex, SimpleKeywordTableIndex
+from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
 from llama_index.core.postprocessor import (

@@ -99,12 +99,6 @@ class ChatService:
             embed_model=embedding_component.embedding_model,
             show_progress=True,
         )
-        self.keyword_index = SimpleKeywordTableIndex.from_documents(
-            vector_store_component.vector_store,
-            storage_context=self.storage_context,
-            embed_model=embedding_component.embedding_model,
-            show_progress=True,
-        )

     def _chat_engine(
         self,

@@ -116,7 +110,6 @@ class ChatService:
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index,
-                keyword_index=self.keyword_index,
                 context_filter=context_filter,
                 similarity_top_k=self.settings.rag.similarity_top_k,
             )

@@ -195,17 +188,17 @@ class ChatService:
         )
         system_prompt = (
             """
-            You are QuickGPT, a helpful assistant by Quickfox Consulting.
-
-            Responses should be based on the context documents provided
-            and should be relevant, informative, and easy to understand.
-            You should aim to deliver high-quality responses that are
-            respectful and helpful, using clear and concise language.
-            Avoid providing information outside of the context documents unless
-            it is necessary for clarity or completeness. Focus on providing
-            accurate and reliable answers based on the given context.
-            If answer is not in the context documents, just say I don't have answer
-            in respectful way.
+            You are a helpful assistant named QuickGPT by Quickfox Consulting.
+            Your responses must be strictly and exclusively based on the context documents provided.
+            You are not allowed to use any information, knowledge, or external sources outside of the given context documents.
+            If the answer to a query is not present in the context documents,
+            you should respond with "I do not have enough information in the provided context to answer this question."
+
+            Your responses should be relevant, informative, and easy to understand.
+            Aim to deliver high-quality answers that are respectful and helpful, using clear and concise language.
+            Focus on providing accurate and reliable answers based solely on the given context.
+            Do not make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the context documents.
             """
         )
         chat_history = (
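For orientation, the retriever and the tightened system prompt meet in the chat engine roughly like this; a minimal sketch using the llama_index ContextChatEngine API, with llm standing in for the configured model component:

    from llama_index.core.chat_engine import ContextChatEngine

    chat_engine = ContextChatEngine.from_defaults(
        retriever=vector_index_retriever,  # from get_retriever(...) above
        llm=llm,
        system_prompt=system_prompt,
    )
    response = chat_engine.chat("Summarize the uploaded contract.")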
@@ -253,34 +253,6 @@ async def common_ingest_logic(
                 f.write(file.read())
                 file.seek(0)
             ingested_documents = service.ingest_bin_data(file_name, file)
-    # Handling Original File
-    if original_file:
-        try:
-            print("ORIGINAL PDF FILE PATH IS :: ", original_file)
-            file_name = Path(original_file).name
-            upload_path = Path(f"{UPLOAD_DIR}/{file_name}")
-
-            if file_name is None:
-                raise HTTPException(
-                    status_code=status.HTTP_400_BAD_REQUEST,
-                    detail="No file name provided",
-                )
-
-            await create_documents(db, file_name, current_user, departments, log_audit)
-
-            with open(upload_path, "wb") as f:
-                with open(original_file, "rb") as original_file_reader:
-                    f.write(original_file_reader.read())
-
-            with open(upload_path, "rb") as f:
-                ingested_documents = service.ingest_bin_data(file_name, f)
-        except Exception as e:
-            print(traceback.print_exc())
-            raise HTTPException(
-                status_code=500,
-                detail="Internal Server Error: Unable to ingest file.",
-            )
-
     logger.info(
         f"{file_name} is uploaded by the {current_user.username}.")
@@ -17,7 +17,7 @@ from private_gpt.users.core.config import settings
 from private_gpt.users import crud, models, schemas
 from private_gpt.server.ingest.ingest_router import create_documents, ingest
 from private_gpt.users.models.document import MakerCheckerActionType, MakerCheckerStatus
-from private_gpt.components.ocr_components.table_ocr_api import process_both_ocr, process_ocr
+from private_gpt.components.ocr_components.table_ocr_api import process_ocr

 logger = logging.getLogger(__name__)
 router = APIRouter(prefix='/documents', tags=['Documents'])

@@ -385,8 +385,6 @@ async def verify_documents(

     if document.doc_type_id == 2:  # For OCR
         return await process_ocr(request, unchecked_path)
-    elif document.doc_type_id == 3:  # For BOTH
-        return await process_both_ocr(request, unchecked_path)
     else:
         return await ingest(request, unchecked_path)  # For pdf
@@ -1,13 +0,0 @@
-from llama_index.core.indices.vector_store import VectorStoreIndex
-from llama_index.core.indices import SimpleKeywordTableIndex
-
-
-class VectorStoreIndex1(VectorStoreIndex):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.keyword_index = None
-
-    def set_keyword_index(self, keyword_index: SimpleKeywordTableIndex):
-        self.keyword_index = keyword_index
-
-    def get_keyword_index(self) -> SimpleKeywordTableIndex:
-        return self.keyword_index
@@ -68,6 +68,8 @@ openpyxl = "^3.1.2"
 pandas = "^2.2.2"
 fastapi-pagination = "^0.12.23"
 xlsxwriter = "^3.2.0"
+pdf2image = "^1.17.0"
+pymupdf = "^1.24.4"

 [tool.poetry.extras]
 ui = ["gradio"]
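One deployment note on the two new dependencies: pdf2image is a thin wrapper around the poppler utilities, so pdftoppm must be available on the host for convert_from_bytes to work, while pymupdf (imported as fitz) ships self-contained wheels and needs no extra system package.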
@@ -57,8 +57,8 @@ rag:
 llamacpp:
   # llm_hf_repo_id: bartowski/Meta-Llama-3-8B-Instruct-GGUF
   # llm_hf_model_file: Meta-Llama-3-8B-Instruct-Q6_K.gguf
-  llm_hf_repo_id: NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF
-  llm_hf_model_file: Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q6_K.gguf
+  llm_hf_repo_id: qwp4w3hyb/Hermes-2-Pro-Llama-3-8B-iMat-GGUF
+  llm_hf_model_file: hermes-2-pro-llama-3-8b-imat-Q6_K.gguf
   tfs_z: 1.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
   top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
   top_p: 0.9 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)