Updated API for OCR component

Saurab-Shrestha 2024-06-02 09:45:53 +05:45
parent fbd298212f
commit 175b4e29ac
12 changed files with 2537 additions and 2901 deletions

poetry.lock (generated): 5136 changed lines

File diff suppressed because it is too large

View File

@@ -11,8 +11,7 @@ from typing import Any
from llama_index.core.data_structs import IndexDict
from llama_index.core.embeddings.utils import EmbedType
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage, SimpleKeywordTableIndex
from private_gpt.utils.vector_store import VectorStoreIndex1
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
from llama_index.core.indices.base import BaseIndex
from llama_index.core.ingestion import run_transformations
from llama_index.core.schema import BaseNode, Document, TransformComponent
@@ -84,7 +83,7 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
except ValueError:
# There is no index in the storage context; create a new one
logger.info("Creating a new vector store index")
index = VectorStoreIndex1.from_documents(
index = VectorStoreIndex.from_documents(
[],
storage_context=self.storage_context,
store_nodes_override=True, # Force store nodes in index and document stores
@@ -93,17 +92,6 @@
transformations=self.transformations,
)
index.storage_context.persist(persist_dir=local_data_path)
keyword_index = SimpleKeywordTableIndex.from_documents(
[],
storage_context=self.storage_context,
store_nodes_override=True, # Force store nodes in index and document stores
show_progress=self.show_progress,
transformations=self.transformations,
llm=
)
# Store the keyword index in the vector store
index.keyword_index = keyword_index
return index
def _save_index(self) -> None:
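
For context, a minimal sketch of the simplified bootstrap this change leaves behind: load an existing index from storage, or create and persist an empty VectorStoreIndex. The persist directory below is hypothetical; the component uses its own storage_context and transformations.

from llama_index.core import StorageContext
from llama_index.core.indices import VectorStoreIndex, load_index_from_storage

storage_context = StorageContext.from_defaults(persist_dir="local_data")  # hypothetical dir
try:
    index = load_index_from_storage(storage_context)
except ValueError:
    # No existing index in the storage context: create an empty one and persist it.
    index = VectorStoreIndex.from_documents(
        [],
        storage_context=storage_context,
        store_nodes_override=True,  # force-store nodes in the index and document stores
        show_progress=True,
    )
    index.storage_context.persist(persist_dir="local_data")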

View File

@@ -1,57 +1,58 @@
# from paddleocr import PaddleOCR
import io
from typing import Union
import cv2
import torch
from doctr.models import ocr_predictor
from doctr.io import DocumentFile
from injector import singleton, inject
device = "cuda" if torch.cuda.is_available() else "cpu"
from doctr.models import ocr_predictor
from injector import inject, singleton
from pdf2image import convert_from_bytes
# device = "cuda" if torch.cuda.is_available() else "cpu"
device = "cpu"
@singleton
class GetOCRText:
@inject
def __init__(self) -> None:
self._image = None
# self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True).to(device)
def _preprocess_image(self, img):
resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)
resized_image = cv2.resize(img, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC)
gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
_, binary = cv2.threshold(gray_image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
_, binary = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
return binary
## paddleOCR
# def extract_text(self, cell_image):
# text = ""
# self._image = cell_image
# preprocessed_image = self._preprocess_image(self._image)
# results = self.ocr.ocr(preprocessed_image, cls=True)
# print(results)
# if len(results) > 0:
# for result in results[0]:
# text += f"{result[-1][0]} "
# else:
# text = ""
# return text
## docTR OCR
def extract_text(self, cell_image=None, image_file=False, file_path=None):
def extract_text(self, cell_image: Union[None, bytes] = None, image_file: bool = False, file_path: Union[None, str] = None):
text = ""
if image_file:
if file_path is None:
raise ValueError("file_path must be provided when image_file is True.")
pdf_file = DocumentFile.from_images(file_path)
result = self.doctr(pdf_file)
output = result.export()
else:
self._image = cell_image
preprocessd_image = self._preprocess_image(self._image)
result = self.doctr([self._image])
output = result.export()
if cell_image is None:
raise ValueError("cell_image must be provided when image_file is False.")
if isinstance(cell_image, bytes):
images = convert_from_bytes(cell_image)
pdf_file = DocumentFile.from_images(images)
result = self.doctr(pdf_file)
else:
self._image = cell_image
preprocessed_image = self._preprocess_image(self._image)
result = self.doctr([preprocessed_image])
output = result.export()
for obj1 in output['pages'][0]["blocks"]:
for obj2 in obj1["lines"]:
for obj3 in obj2["words"]:
text += (f"{obj3['value']} ").replace("\n", "")
text = text + "\n"
text += "\n"
text += "\n"
if text:
return text
return text.strip()
return " "

View File

@@ -34,10 +34,9 @@ async def save_uploaded_file(file: UploadFile, upload_dir: str):
async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
doc = Document()
ocr = request.state.injector.get(GetOCRText)
img_tab = request.state.injector.get(ImageToTable)
# img_tab = request.state.injector.get(ImageToTable)
pdf_writer = fitz.open()
pdf_doc = fitz.open(pdf_path)
for page_index in range(len(pdf_doc)):
@@ -55,18 +54,19 @@ async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
pix = fitz.Pixmap(fitz.csRGB, pix)
img = Image.frombytes("RGB", [pix.width, pix.height], pix.samples)
image_path = f"page_{page_index}-image_{image_index}.png"
image_path = f"temp_page_{page_index}_image_{image_index}.png"
pix.save(image_path)
extracted_text = ocr.extract_text(
image_file=True, file_path=image_path)
doc.add_paragraph(extracted_text)
table_data = img_tab.table_to_csv(image_path)
doc.add_paragraph(table_data)
# Create a new page with the same dimensions as the original page
pdf_page = pdf_writer.new_page(width=page.rect.width, height=page.rect.height)
pdf_page.insert_text((10, 10), extracted_text, fontsize=9)
os.remove(image_path)
save_path = os.path.join(
upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}_ocr.docx")
doc.save(save_path)
save_path = os.path.join(upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}.pdf")
pdf_writer.save(save_path)
pdf_writer.close()
return save_path
@@ -82,6 +82,7 @@ async def process_pdf_ocr(
try:
pdf_path = await save_uploaded_file(file, UPLOAD_DIR)
ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
print("FILE PATH:", ocr_doc_path)
ingested_documents = await common_ingest_logic(
request=request, db=db, ocr_file=ocr_doc_path, current_user=current_user, original_file=None, log_audit=log_audit, departments=departments
)
@@ -111,26 +112,6 @@ async def process_ocr(
detail=f"There was an error processing OCR: {e}"
)
async def process_both_ocr(
request: Request,
pdf_path: str
):
UPLOAD_DIR = OCR_UPLOAD
try:
ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
ingested_ocr_documents = await ingest(request=request, file_path=ocr_doc_path) # ingest ocr
ingested_documents = await ingest(request=request, file_path=pdf_path) # ingest pdf
return ingested_documents
except Exception as e:
print(traceback.print_exc())
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail=f"There was an error processing OCR: {e}"
)
async def process_both(
request: Request,
db: Session,
@@ -170,18 +151,3 @@ async def get_pdf_ocr_wrapper(
)
):
return await process_pdf_ocr(request, db, file, current_user, log_audit, departments)
@pdf_router.post("/both")
async def get_both_wrapper(
request: Request,
departments: schemas.DocumentDepartmentList = Depends(),
db: Session = Depends(deps.get_db),
log_audit: models.Audit = Depends(deps.get_audit_logger),
file: UploadFile = File(...),
current_user: models.User = Security(
deps.get_current_user,
scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
)
):
return await process_both(request, db, file, current_user, log_audit, departments)
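
A minimal sketch of the new OCR-to-PDF path that replaced the .docx output, assuming PyMuPDF (fitz) and a hypothetical input file:

import fitz  # PyMuPDF

src = fitz.open("input.pdf")  # hypothetical source
out = fitz.open()             # empty output document
for page in src:
    # Mirror the source page geometry, then write the OCR text onto the new page.
    new_page = out.new_page(width=page.rect.width, height=page.rect.height)
    new_page.insert_text((10, 10), "OCR text for this page", fontsize=9)
out.save("input_ocr.pdf")
out.close()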

View File

@@ -1,55 +0,0 @@
# import QueryBundle
from llama_index.core import QueryBundle
# import NodeWithScore
from llama_index.core.schema import NodeWithScore
# Retrievers
from llama_index.core.retrievers import (
BaseRetriever,
VectorIndexRetriever,
KeywordTableSimpleRetriever,
)
from typing import List
class CustomRetriever(BaseRetriever):
"""Custom retriever that performs both semantic search and hybrid search."""
def __init__(
self,
vector_retriever: VectorIndexRetriever,
keyword_retriever: KeywordTableSimpleRetriever,
mode: str = "AND",
) -> None:
"""Init params."""
self._vector_retriever = vector_retriever
self._keyword_retriever = keyword_retriever
if mode not in ("AND", "OR"):
raise ValueError("Invalid mode.")
self._mode = mode
super().__init__()
def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
"""Retrieve nodes given query."""
vector_nodes = self._vector_retriever.retrieve(query_bundle)
keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
vector_ids = {n.node.node_id for n in vector_nodes}
keyword_ids = {n.node.node_id for n in keyword_nodes}
combined_dict = {n.node.node_id: n for n in vector_nodes}
combined_dict.update({n.node.node_id: n for n in keyword_nodes})
if self._mode == "AND":
retrieve_ids = vector_ids.intersection(keyword_ids)
else:
retrieve_ids = vector_ids.union(keyword_ids)
retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
return retrieve_nodes

View File

@@ -3,14 +3,12 @@ import typing
from injector import inject, singleton
from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
from llama_index.core.indices import SimpleKeywordTableIndex
from llama_index.core.vector_stores.types import (
FilterCondition,
MetadataFilter,
MetadataFilters,
VectorStore,
)
from private_gpt.utils.vector_store import VectorStoreIndex1
from private_gpt.open_ai.extensions.context_filter import ContextFilter
from private_gpt.paths import local_data_path
@@ -132,21 +130,12 @@ class VectorStoreComponent:
def get_retriever(
self,
index: VectorStoreIndex1,
index: VectorStoreIndex,
context_filter: ContextFilter | None = None,
similarity_top_k: int = 2,
) -> VectorIndexRetriever:
# This way we support qdrant (using doc_ids) and the rest (using filters)
# from llama_index.core import get_response_synthesizer
# from llama_index.core.query_engine import RetrieverQueryEngine
from .retriever import CustomRetriever
from llama_index.core.retrievers import (
VectorIndexRetriever,
KeywordTableSimpleRetriever,
)
vector_retriever = VectorIndexRetriever(
return VectorIndexRetriever(
index=index,
similarity_top_k=similarity_top_k,
doc_ids=context_filter.docs_ids if context_filter else None,
@@ -156,19 +145,6 @@
else None
),
)
keyword_retriever = KeywordTableSimpleRetriever(index=index.keyword_index)
custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)
# define response synthesizer
# response_synthesizer = get_response_synthesizer()
# # assemble query engine
# custom_query_engine = RetrieverQueryEngine(
# retriever=custom_retriever,
# response_synthesizer=response_synthesizer,
# )
return custom_retriever
def close(self) -> None:
if hasattr(self.vector_store.client, "close"):
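
With the hybrid CustomRetriever removed, get_retriever now returns a plain VectorIndexRetriever. A hedged sketch of equivalent direct usage, assuming index is an already-built VectorStoreIndex:

from llama_index.core.retrievers import VectorIndexRetriever

retriever = VectorIndexRetriever(
    index=index,          # an existing VectorStoreIndex
    similarity_top_k=2,   # matches the component's default
)
nodes = retriever.retrieve("example query")  # returns a list of NodeWithScore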

View File

@@ -5,7 +5,7 @@ from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
from llama_index.core.chat_engine.types import (
BaseChatEngine,
)
from llama_index.core.indices import VectorStoreIndex, SimpleKeywordTableIndex
from llama_index.core.indices import VectorStoreIndex
from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
from llama_index.core.llms import ChatMessage, MessageRole
from llama_index.core.postprocessor import (
@@ -99,12 +99,6 @@ class ChatService:
embed_model=embedding_component.embedding_model,
show_progress=True,
)
self.keyword_index = SimpleKeywordTableIndex.from_documents(
vector_store_component.vector_store,
storage_context=self.storage_context,
embed_model=embedding_component.embedding_model,
show_progress=True,
)
def _chat_engine(
self,
@@ -116,7 +110,6 @@
if use_context:
vector_index_retriever = self.vector_store_component.get_retriever(
index=self.index,
keyword_index=self.keyword_index,
context_filter=context_filter,
similarity_top_k=self.settings.rag.similarity_top_k,
)
@@ -195,17 +188,17 @@
)
system_prompt = (
"""
You are QuickGPT, a helpful assistant by Quickfox Consulting.
You are a helpful assistant named QuickGPT by Quickfox Consulting.
Your responses must be strictly and exclusively based on the context documents provided.
Responses should be based on the context documents provided
and should be relevant, informative, and easy to understand.
You should aim to deliver high-quality responses that are
respectful and helpful, using clear and concise language.
Avoid providing information outside of the context documents unless
it is necessary for clarity or completeness. Focus on providing
accurate and reliable answers based on the given context.
If answer is not in the context documents, just say I don't have answer
in respectful way.
You are not allowed to use any information, knowledge, or external sources outside of the given context documents.
If the answer to a query is not present in the context documents,
you should respond with "I do not have enough information in the provided context to answer this question."
Your responses should be relevant, informative, and easy to understand.
Aim to deliver high-quality answers that are respectful and helpful, using clear and concise language.
Focus on providing accurate and reliable answers based solely on the given context.
Do not make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the context documents.
"""
)
chat_history = (

View File

@@ -253,34 +253,6 @@ async def common_ingest_logic(
f.write(file.read())
file.seek(0)
ingested_documents = service.ingest_bin_data(file_name, file)
# Handling Original File
if original_file:
try:
print("ORIGINAL PDF FILE PATH IS :: ", original_file)
file_name = Path(original_file).name
upload_path = Path(f"{UPLOAD_DIR}/{file_name}")
if file_name is None:
raise HTTPException(
status_code=status.HTTP_400_BAD_REQUEST,
detail="No file name provided",
)
await create_documents(db, file_name, current_user, departments, log_audit)
with open(upload_path, "wb") as f:
with open(original_file, "rb") as original_file_reader:
f.write(original_file_reader.read())
with open(upload_path, "rb") as f:
ingested_documents = service.ingest_bin_data(file_name, f)
except Exception as e:
print(traceback.print_exc())
raise HTTPException(
status_code=500,
detail="Internal Server Error: Unable to ingest file.",
)
logger.info(
f"{file_name} was uploaded by {current_user.username}.")

View File

@@ -17,7 +17,7 @@ from private_gpt.users.core.config import settings
from private_gpt.users import crud, models, schemas
from private_gpt.server.ingest.ingest_router import create_documents, ingest
from private_gpt.users.models.document import MakerCheckerActionType, MakerCheckerStatus
from private_gpt.components.ocr_components.table_ocr_api import process_both_ocr, process_ocr
from private_gpt.components.ocr_components.table_ocr_api import process_ocr
logger = logging.getLogger(__name__)
router = APIRouter(prefix='/documents', tags=['Documents'])
@@ -385,8 +385,6 @@ async def verify_documents(
if document.doc_type_id == 2: # For OCR
return await process_ocr(request, unchecked_path)
elif document.doc_type_id == 3: # For BOTH
return await process_both_ocr(request, unchecked_path)
else:
return await ingest(request, unchecked_path) # For pdf

View File

@@ -1,13 +0,0 @@
from llama_index.core.indices.vector_store import VectorStoreIndex
from llama_index.core.indices import SimpleKeywordTableIndex
class VectorStoreIndex1(VectorStoreIndex):
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.keyword_index = None
def set_keyword_index(self, keyword_index: SimpleKeywordTableIndex):
self.keyword_index = keyword_index
def get_keyword_index(self) -> SimpleKeywordTableIndex:
return self.keyword_index

View File

@@ -68,6 +68,8 @@ openpyxl = "^3.1.2"
pandas = "^2.2.2"
fastapi-pagination = "^0.12.23"
xlsxwriter = "^3.2.0"
pdf2image = "^1.17.0"
pymupdf = "^1.24.4"
[tool.poetry.extras]
ui = ["gradio"]

View File

@@ -57,8 +57,8 @@ rag:
llamacpp:
# llm_hf_repo_id: bartowski/Meta-Llama-3-8B-Instruct-GGUF
# llm_hf_model_file: Meta-Llama-3-8B-Instruct-Q6_K.gguf
llm_hf_repo_id: NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF
llm_hf_model_file: Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q6_K.gguf
llm_hf_repo_id: qwp4w3hyb/Hermes-2-Pro-Llama-3-8B-iMat-GGUF
llm_hf_model_file: hermes-2-pro-llama-3-8b-imat-Q6_K.gguf
tfs_z: 1.0 # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
top_k: 40 # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
top_p: 0.9 # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)