Updated API for OCR component

Saurab-Shrestha 2024-06-02 09:45:53 +05:45
parent fbd298212f
commit 175b4e29ac
12 changed files with 2537 additions and 2901 deletions

poetry.lock generated

File diff suppressed because it is too large


@@ -11,8 +11,7 @@ from typing import Any
 from llama_index.core.data_structs import IndexDict
 from llama_index.core.embeddings.utils import EmbedType
-from llama_index.core.indices import VectorStoreIndex, load_index_from_storage, SimpleKeywordTableIndex
-from private_gpt.utils.vector_store import VectorStoreIndex1
+from llama_index.core.indices import VectorStoreIndex, load_index_from_storage
 from llama_index.core.indices.base import BaseIndex
 from llama_index.core.ingestion import run_transformations
 from llama_index.core.schema import BaseNode, Document, TransformComponent
@@ -84,7 +83,7 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
         except ValueError:
             # There are no index in the storage context, creating a new one
             logger.info("Creating a new vector store index")
-            index = VectorStoreIndex1.from_documents(
+            index = VectorStoreIndex.from_documents(
                 [],
                 storage_context=self.storage_context,
                 store_nodes_override=True,  # Force store nodes in index and document stores
@@ -93,17 +92,6 @@ class BaseIngestComponentWithIndex(BaseIngestComponent, abc.ABC):
                 transformations=self.transformations,
             )
             index.storage_context.persist(persist_dir=local_data_path)
-            keyword_index = SimpleKeywordTableIndex.from_documents(
-                [],
-                storage_context=self.storage_context,
-                store_nodes_override=True,  # Force store nodes in index and document stores
-                show_progress=self.show_progress,
-                transformations=self.transformations,
-                llm=
-            )
-            # Store the keyword index in the vector store
-            index.keyword_index = keyword_index
         return index

     def _save_index(self) -> None:
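
Note: the removed keyword-table block ended in a dangling llm= argument, a syntax error, so ingestion here reverts to the stock VectorStoreIndex path. A minimal sketch of what remains, with stand-ins for the component's injected stores and embedding model:

    from llama_index.core import StorageContext
    from llama_index.core.embeddings import MockEmbedding
    from llama_index.core.indices import VectorStoreIndex

    # Stand-ins: the real component injects its own stores and embedder.
    storage_context = StorageContext.from_defaults()

    index = VectorStoreIndex.from_documents(
        [],  # start empty; ingestion inserts nodes afterwards
        storage_context=storage_context,
        store_nodes_override=True,  # force nodes into index and document stores
        show_progress=False,
        embed_model=MockEmbedding(embed_dim=8),  # placeholder for the configured embedder
    )
    index.storage_context.persist(persist_dir="local_data")  # illustrative path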


@@ -1,57 +1,58 @@
-# from paddleocr import PaddleOCR
+import io
+from typing import Union
 import cv2
 import torch
-from doctr.models import ocr_predictor
 from doctr.io import DocumentFile
-from injector import singleton, inject
-device = "cuda" if torch.cuda.is_available() else "cpu"
+from doctr.models import ocr_predictor
+from injector import inject, singleton
+from pdf2image import convert_from_bytes
+
+# device = "cuda" if torch.cuda.is_available() else "cpu"
+device = "cpu"

 @singleton
 class GetOCRText:
     @inject
     def __init__(self) -> None:
         self._image = None
-        # self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
         self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True).to(device)

     def _preprocess_image(self, img):
-        resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)
+        resized_image = cv2.resize(img, None, fx=1.2, fy=1.2, interpolation=cv2.INTER_CUBIC)
         gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
-        _, binary = cv2.threshold(gray_image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
+        _, binary = cv2.threshold(gray_image, 0, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
         return binary

-    ## paddleOCR
-    # def extract_text(self, cell_image):
-    #     text = ""
-    #     self._image = cell_image
-    #     preprocessd_image = self._preprocess_image(self._image)
-    #     results = self.ocr.ocr(preprocessd_image, cls=True)
-    #     print(results)
-    #     if len(results) > 0:
-    #         for result in results[0]:
-    #             text += f"{result[-1][0]} "
-    #     else:
-    #         text = ""
-    #     return text
-
-    ## docTR OCR
-    def extract_text(self, cell_image=None, image_file=False, file_path=None):
+    def extract_text(self, cell_image: Union[None, bytes] = None, image_file: bool = False, file_path: Union[None, str] = None):
         text = ""
         if image_file:
+            if file_path is None:
+                raise ValueError("file_path must be provided when image_file is True.")
             pdf_file = DocumentFile.from_images(file_path)
             result = self.doctr(pdf_file)
             output = result.export()
         else:
-            self._image = cell_image
-            preprocessd_image = self._preprocess_image(self._image)
-            result = self.doctr([self._image])
+            if cell_image is None:
+                raise ValueError("cell_image must be provided when image_file is False.")
+            if isinstance(cell_image, bytes):
+                images = convert_from_bytes(cell_image)
+                pdf_file = DocumentFile.from_images(images)
+                result = self.doctr(pdf_file)
+            else:
+                self._image = cell_image
+                preprocessed_image = self._preprocess_image(self._image)
+                result = self.doctr([preprocessed_image])
             output = result.export()
         for obj1 in output['pages'][0]["blocks"]:
             for obj2 in obj1["lines"]:
                 for obj3 in obj2["words"]:
                     text += (f"{obj3['value']} ").replace("\n", "")
-            text = text + "\n"
+                text += "\n"
+            text += "\n"
         if text:
-            return text
+            return text.strip()
         return " "


@@ -34,10 +34,9 @@ async def save_uploaded_file(file: UploadFile, upload_dir: str):

 async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
-    doc = Document()
     ocr = request.state.injector.get(GetOCRText)
-    img_tab = request.state.injector.get(ImageToTable)
+    # img_tab = request.state.injector.get(ImageToTable)
+    pdf_writer = fitz.open()
     pdf_doc = fitz.open(pdf_path)

     for page_index in range(len(pdf_doc)):
@@ -55,18 +54,19 @@ async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str):
                 pix = fitz.Pixmap(fitz.csRGB, pix)(
                     "RGB", [pix.width, pix.height], pix.samples)
-                image_path = f"page_{page_index}-image_{image_index}.png"
+                image_path = f"temp_page_{page_index}_image_{image_index}.png"
                 pix.save(image_path)
                 extracted_text = ocr.extract_text(
                     image_file=True, file_path=image_path)
-                doc.add_paragraph(extracted_text)
-                table_data = img_tab.table_to_csv(image_path)
-                doc.add_paragraph(table_data)
+                # Create a new page with the same dimensions as the original page
+                pdf_page = pdf_writer.new_page(width=page.rect.width, height=page.rect.height)
+                pdf_page.insert_text((10, 10), extracted_text, fontsize=9)
                 os.remove(image_path)

-    save_path = os.path.join(
-        upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}_ocr.docx")
-    doc.save(save_path)
+    save_path = os.path.join(upload_dir, f"{os.path.splitext(os.path.basename(pdf_path))[0]}.pdf")
+    pdf_writer.save(save_path)
+    pdf_writer.close()
     return save_path
@@ -82,6 +82,7 @@ async def process_pdf_ocr(
     try:
         pdf_path = await save_uploaded_file(file, UPLOAD_DIR)
         ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
+        print("FILE PATH:", ocr_doc_path)
         ingested_documents = await common_ingest_logic(
             request=request, db=db, ocr_file=ocr_doc_path, current_user=current_user, original_file=None, log_audit=log_audit, departments=departments
         )
@@ -111,26 +112,6 @@ async def process_ocr(
             detail=f"There was an error processing OCR: {e}"
         )

-async def process_both_ocr(
-    request: Request,
-    pdf_path: str
-):
-    UPLOAD_DIR = OCR_UPLOAD
-    try:
-        ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR)
-        ingested_ocr_documents = await ingest(request=request, file_path=ocr_doc_path)  # ingest ocr
-        ingested_documents = await ingest(request=request, file_path=pdf_path)  # ingest pdf
-        return ingested_documents
-    except Exception as e:
-        print(traceback.print_exc())
-        raise HTTPException(
-            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
-            detail=f"There was an error processing OCR: {e}"
-        )

 async def process_both(
     request: Request,
     db: Session,
@@ -170,18 +151,3 @@ async def get_pdf_ocr_wrapper(
     )
 ):
     return await process_pdf_ocr(request, db, file, current_user, log_audit, departments)
-
-@pdf_router.post("/both")
-async def get_both_wrapper(
-    request: Request,
-    departments: schemas.DocumentDepartmentList = Depends(),
-    db: Session = Depends(deps.get_db),
-    log_audit: models.Audit = Depends(deps.get_audit_logger),
-    file: UploadFile = File(...),
-    current_user: models.User = Security(
-        deps.get_current_user,
-        scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]],
-    )
-):
-    return await process_both(request, db, file, current_user, log_audit, departments)
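
The docx pipeline (python-docx paragraphs plus the ImageToTable CSV step) is replaced by a searchable PDF assembled with PyMuPDF: one output page per source page, with the OCR text stamped at a fixed offset. A minimal sketch of that pattern, file names illustrative:

    import fitz  # PyMuPDF

    src = fitz.open("input.pdf")
    writer = fitz.open()  # empty document used as the output writer

    for page in src:
        # Mirror the source page dimensions, then stamp the recovered text.
        out_page = writer.new_page(width=page.rect.width, height=page.rect.height)
        extracted_text = "OCR text for this page"  # placeholder for the real OCR call
        out_page.insert_text((10, 10), extracted_text, fontsize=9)

    writer.save("input_ocr.pdf")
    writer.close()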


@@ -1,55 +0,0 @@
-# import QueryBundle
-from llama_index.core import QueryBundle
-# import NodeWithScore
-from llama_index.core.schema import NodeWithScore
-# Retrievers
-from llama_index.core.retrievers import (
-    BaseRetriever,
-    VectorIndexRetriever,
-    KeywordTableSimpleRetriever,
-)
-from typing import List
-
-
-class CustomRetriever(BaseRetriever):
-    """Custom retriever that performs both semantic search and hybrid search."""
-
-    def __init__(
-        self,
-        vector_retriever: VectorIndexRetriever,
-        keyword_retriever: KeywordTableSimpleRetriever,
-        mode: str = "AND",
-    ) -> None:
-        """Init params."""
-        self._vector_retriever = vector_retriever
-        self._keyword_retriever = keyword_retriever
-        if mode not in ("AND", "OR"):
-            raise ValueError("Invalid mode.")
-        self._mode = mode
-        super().__init__()
-
-    def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
-        """Retrieve nodes given query."""
-        vector_nodes = self._vector_retriever.retrieve(query_bundle)
-        keyword_nodes = self._keyword_retriever.retrieve(query_bundle)
-
-        vector_ids = {n.node.node_id for n in vector_nodes}
-        keyword_ids = {n.node.node_id for n in keyword_nodes}
-
-        combined_dict = {n.node.node_id: n for n in vector_nodes}
-        combined_dict.update({n.node.node_id: n for n in keyword_nodes})
-
-        if self._mode == "AND":
-            retrieve_ids = vector_ids.intersection(keyword_ids)
-        else:
-            retrieve_ids = vector_ids.union(keyword_ids)
-
-        retrieve_nodes = [combined_dict[rid] for rid in retrieve_ids]
-        return retrieve_nodes
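
This deleted file implemented a standard LlamaIndex hybrid-retrieval pattern: intersect (mode="AND") or union (mode="OR") the node IDs returned by a vector retriever and a keyword retriever. Before removal it was presumably wired roughly like this (index construction assumed):

    from llama_index.core.retrievers import VectorIndexRetriever, KeywordTableSimpleRetriever

    # Assumed: vector_index and keyword_index were built during ingestion.
    vector_retriever = VectorIndexRetriever(index=vector_index, similarity_top_k=2)
    keyword_retriever = KeywordTableSimpleRetriever(index=keyword_index)

    hybrid = CustomRetriever(vector_retriever, keyword_retriever, mode="OR")
    nodes = hybrid.retrieve("payment terms")  # union of both retrievers' results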


@@ -3,14 +3,12 @@ import typing

 from injector import inject, singleton
 from llama_index.core.indices.vector_store import VectorIndexRetriever, VectorStoreIndex
-from llama_index.core.indices import SimpleKeywordTableIndex
 from llama_index.core.vector_stores.types import (
     FilterCondition,
     MetadataFilter,
     MetadataFilters,
     VectorStore,
 )
-from private_gpt.utils.vector_store import VectorStoreIndex1

 from private_gpt.open_ai.extensions.context_filter import ContextFilter
 from private_gpt.paths import local_data_path
@@ -132,21 +130,12 @@ class VectorStoreComponent:
     def get_retriever(
         self,
-        index: VectorStoreIndex1,
+        index: VectorStoreIndex,
         context_filter: ContextFilter | None = None,
         similarity_top_k: int = 2,
     ) -> VectorIndexRetriever:
         # This way we support qdrant (using doc_ids) and the rest (using filters)
-        # from llama_index.core import get_response_synthesizer
-        # from llama_index.core.query_engine import RetrieverQueryEngine
-        from .retriever import CustomRetriever
-        from llama_index.core.retrievers import (
-            VectorIndexRetriever,
-            KeywordTableSimpleRetriever,
-        )
-        vector_retriever = VectorIndexRetriever(
+        return VectorIndexRetriever(
             index=index,
             similarity_top_k=similarity_top_k,
             doc_ids=context_filter.docs_ids if context_filter else None,
@@ -156,19 +145,6 @@ class VectorStoreComponent:
                 else None
             ),
         )
-        keyword_retriever = KeywordTableSimpleRetriever(index=index.keyword_index)
-        custom_retriever = CustomRetriever(vector_retriever, keyword_retriever)
-
-        # define response synthesizer
-        # response_synthesizer = get_response_synthesizer()
-
-        # # assemble query engine
-        # custom_query_engine = RetrieverQueryEngine(
-        #     retriever=custom_retriever,
-        #     response_synthesizer=response_synthesizer,
-        # )
-        return custom_retriever

     def close(self) -> None:
         if hasattr(self.vector_store.client, "close"):
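
With the hybrid path gone, get_retriever again returns a plain VectorIndexRetriever, and the doc_ids/filters split keeps qdrant and the other vector stores working. A hedged call-site sketch (the component, index, and query are assumed context):

    retriever = vector_store_component.get_retriever(
        index=index,  # a plain VectorStoreIndex, not the removed VectorStoreIndex1
        context_filter=context_filter,  # optional; qdrant consumes docs_ids directly
        similarity_top_k=2,
    )
    nodes = retriever.retrieve("What changed in the OCR pipeline?")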


@@ -5,7 +5,7 @@ from llama_index.core.chat_engine import ContextChatEngine, SimpleChatEngine
 from llama_index.core.chat_engine.types import (
     BaseChatEngine,
 )
-from llama_index.core.indices import VectorStoreIndex, SimpleKeywordTableIndex
+from llama_index.core.indices import VectorStoreIndex
 from llama_index.core.indices.postprocessor import MetadataReplacementPostProcessor
 from llama_index.core.llms import ChatMessage, MessageRole
 from llama_index.core.postprocessor import (
@@ -99,12 +99,6 @@ class ChatService:
             embed_model=embedding_component.embedding_model,
             show_progress=True,
         )
-        self.keyword_index = SimpleKeywordTableIndex.from_documents(
-            vector_store_component.vector_store,
-            storage_context=self.storage_context,
-            embed_model=embedding_component.embedding_model,
-            show_progress=True,
-        )

     def _chat_engine(
         self,
@@ -116,7 +110,6 @@ class ChatService:
         if use_context:
             vector_index_retriever = self.vector_store_component.get_retriever(
                 index=self.index,
-                keyword_index=self.keyword_index,
                 context_filter=context_filter,
                 similarity_top_k=self.settings.rag.similarity_top_k,
             )
@@ -195,17 +188,17 @@ class ChatService:
         )
         system_prompt = (
             """
-            You are QuickGPT, a helpful assistant by Quickfox Consulting.
-
-            Responses should be based on the context documents provided
-            and should be relevant, informative, and easy to understand.
-            You should aim to deliver high-quality responses that are
-            respectful and helpful, using clear and concise language.
-            Avoid providing information outside of the context documents unless
-            it is necessary for clarity or completeness. Focus on providing
-            accurate and reliable answers based on the given context.
-            If answer is not in the context documents, just say I don't have answer
-            in respectful way.
+            You are a helpful assistant named QuickGPT by Quickfox Consulting.
+            Your responses must be strictly and exclusively based on the context documents provided.
+            You are not allowed to use any information, knowledge, or external sources outside of the given context documents.
+            If the answer to a query is not present in the context documents,
+            you should respond with "I do not have enough information in the provided context to answer this question."
+
+            Your responses should be relevant, informative, and easy to understand.
+            Aim to deliver high-quality answers that are respectful and helpful, using clear and concise language.
+            Focus on providing accurate and reliable answers based solely on the given context.
+            Do not make assumptions, inferences, or draw upon any prior knowledge beyond what is explicitly stated in the context documents.
             """
         )

         chat_history = (


@@ -253,34 +253,6 @@ async def common_ingest_logic(
                 f.write(file.read())
                 file.seek(0)
                 ingested_documents = service.ingest_bin_data(file_name, file)
-
-        # Handling Original File
-        if original_file:
-            try:
-                print("ORIGINAL PDF FILE PATH IS :: ", original_file)
-                file_name = Path(original_file).name
-                upload_path = Path(f"{UPLOAD_DIR}/{file_name}")
-
-                if file_name is None:
-                    raise HTTPException(
-                        status_code=status.HTTP_400_BAD_REQUEST,
-                        detail="No file name provided",
-                    )
-
-                await create_documents(db, file_name, current_user, departments, log_audit)
-
-                with open(upload_path, "wb") as f:
-                    with open(original_file, "rb") as original_file_reader:
-                        f.write(original_file_reader.read())
-
-                with open(upload_path, "rb") as f:
-                    ingested_documents = service.ingest_bin_data(file_name, f)
-
-            except Exception as e:
-                print(traceback.print_exc())
-                raise HTTPException(
-                    status_code=500,
-                    detail="Internal Server Error: Unable to ingest file.",
-                )

         logger.info(
             f"{file_name} is uploaded by the {current_user.username}.")


@@ -17,7 +17,7 @@ from private_gpt.users.core.config import settings
 from private_gpt.users import crud, models, schemas
 from private_gpt.server.ingest.ingest_router import create_documents, ingest
 from private_gpt.users.models.document import MakerCheckerActionType, MakerCheckerStatus
-from private_gpt.components.ocr_components.table_ocr_api import process_both_ocr, process_ocr
+from private_gpt.components.ocr_components.table_ocr_api import process_ocr

 logger = logging.getLogger(__name__)
 router = APIRouter(prefix='/documents', tags=['Documents'])
@@ -385,8 +385,6 @@ async def verify_documents(
         if document.doc_type_id == 2:  # For OCR
             return await process_ocr(request, unchecked_path)
-        elif document.doc_type_id == 3:  # For BOTH
-            return await process_both_ocr(request, unchecked_path)
         else:
             return await ingest(request, unchecked_path)  # For pdf


@@ -1,13 +0,0 @@
-from llama_index.core.indices.vector_store import VectorStoreIndex
-from llama_index.core.indices import SimpleKeywordTableIndex
-
-
-class VectorStoreIndex1(VectorStoreIndex):
-    def __init__(self, *args, **kwargs):
-        super().__init__(*args, **kwargs)
-        self.keyword_index = None
-
-    def set_keyword_index(self, keyword_index: SimpleKeywordTableIndex):
-        self.keyword_index = keyword_index
-    def get_keyword_index(self) -> SimpleKeywordTableIndex:
-        return self.keyword_index


@@ -68,6 +68,8 @@ openpyxl = "^3.1.2"
 pandas = "^2.2.2"
 fastapi-pagination = "^0.12.23"
 xlsxwriter = "^3.2.0"
+pdf2image = "^1.17.0"
+pymupdf = "^1.24.4"

 [tool.poetry.extras]
 ui = ["gradio"]


@@ -57,8 +57,8 @@ rag:

 llamacpp:
   # llm_hf_repo_id: bartowski/Meta-Llama-3-8B-Instruct-GGUF
   # llm_hf_model_file: Meta-Llama-3-8B-Instruct-Q6_K.gguf
-  llm_hf_repo_id: NousResearch/Hermes-2-Theta-Llama-3-8B-GGUF
-  llm_hf_model_file: Hermes-2-Pro-Llama-3-Instruct-Merged-DPO-Q6_K.gguf
+  llm_hf_repo_id: qwp4w3hyb/Hermes-2-Pro-Llama-3-8B-iMat-GGUF
+  llm_hf_model_file: hermes-2-pro-llama-3-8b-imat-Q6_K.gguf
   tfs_z: 1.0  # Tail free sampling is used to reduce the impact of less probable tokens from the output. A higher value (e.g., 2.0) will reduce the impact more, while a value of 1.0 disables this setting
   top_k: 40  # Reduces the probability of generating nonsense. A higher value (e.g. 100) will give more diverse answers, while a lower value (e.g. 10) will be more conservative. (Default: 40)
   top_p: 0.9  # Works together with top-k. A higher value (e.g., 0.95) will lead to more diverse text, while a lower value (e.g., 0.5) will generate more focused and conservative text. (Default: 0.9)