From a87531c14122b7649f3c39cc159d13caba4ce33a Mon Sep 17 00:00:00 2001 From: Saurab-Shrestha Date: Thu, 29 Feb 2024 12:22:25 +0545 Subject: [PATCH] Updated single injector for pdf ocr --- .../ocr_components/table_ocr_api.py | 13 +++++++---- private_gpt/launcher.py | 23 +++++++++---------- 2 files changed, 19 insertions(+), 17 deletions(-) diff --git a/private_gpt/components/ocr_components/table_ocr_api.py b/private_gpt/components/ocr_components/table_ocr_api.py index d27c745b..612720d1 100644 --- a/private_gpt/components/ocr_components/table_ocr_api.py +++ b/private_gpt/components/ocr_components/table_ocr_api.py @@ -39,10 +39,13 @@ async def save_uploaded_file(file: UploadFile, upload_dir: str): return file_path -async def process_images_and_generate_doc(pdf_path: str, upload_dir: str): +async def process_images_and_generate_doc(request: Request, pdf_path: str, upload_dir: str): doc = Document() - ocr = GetOCRText() - img_tab = ImageToTable() + + ocr = request.state.injector.get(GetOCRText) + img_tab = request.state.injector.get(ImageToTable) + # ocr = GetOCRText() + # img_tab = ImageToTable() pdf_doc = fitz.open(pdf_path) for page_index in range(len(pdf_doc)): @@ -87,7 +90,7 @@ async def process_pdf_ocr( print("The file name is: ", file.filename) pdf_path = await save_uploaded_file(file, UPLOAD_DIR) print("The file path: ", pdf_path) - ocr_doc_path = await process_images_and_generate_doc(pdf_path, UPLOAD_DIR) + ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR) ingested_documents = await common_ingest_logic( request=request, db=db, ocr_file=ocr_doc_path, current_user=current_user, original_file=None, log_audit=log_audit ) @@ -110,7 +113,7 @@ async def process_both( UPLOAD_DIR = OCR_UPLOAD try: pdf_path = await save_uploaded_file(file, UPLOAD_DIR) - ocr_doc_path = await process_images_and_generate_doc(pdf_path, UPLOAD_DIR) + ocr_doc_path = await process_images_and_generate_doc(request, pdf_path, UPLOAD_DIR) ingested_documents = await common_ingest_logic( request=request, db=db, ocr_file=ocr_doc_path, current_user=current_user, original_file=pdf_path, log_audit=log_audit ) diff --git a/private_gpt/launcher.py b/private_gpt/launcher.py index cc2bb03a..b5f99b01 100644 --- a/private_gpt/launcher.py +++ b/private_gpt/launcher.py @@ -1,21 +1,20 @@ """FastAPI app creation, logger configuration and main API routes.""" import logging +from injector import Injector from fastapi import Depends, FastAPI, Request from fastapi.middleware.cors import CORSMiddleware -from injector import Injector - -from private_gpt.server.chat.chat_router import chat_router -from private_gpt.server.chunks.chunks_router import chunks_router -from private_gpt.server.completions.completions_router import completions_router -from private_gpt.server.embeddings.embeddings_router import embeddings_router -from private_gpt.server.health.health_router import health_router -from private_gpt.server.ingest.ingest_router import ingest_router -from private_gpt.users.api.v1.api import api_router -from private_gpt.components.ocr_components.table_ocr_api import pdf_router from private_gpt.settings.settings import Settings -from private_gpt.home import home_router +from private_gpt.users.api.v1.api import api_router +from private_gpt.server.chat.chat_router import chat_router +from private_gpt.server.health.health_router import health_router +from private_gpt.server.chunks.chunks_router import chunks_router +from private_gpt.server.ingest.ingest_router import ingest_router +from private_gpt.components.ocr_components.table_ocr_api import pdf_router +from private_gpt.server.completions.completions_router import completions_router +from private_gpt.server.embeddings.embeddings_router import embeddings_router + logger = logging.getLogger(__name__) @@ -34,8 +33,8 @@ def create_app(root_injector: Injector) -> FastAPI: app.include_router(health_router) app.include_router(api_router) - # app.include_router(home_router) app.include_router(pdf_router) + settings = root_injector.get(Settings) if settings.server.cors.enabled: logger.debug("Setting up CORS middleware")