diff --git a/private_gpt/components/ocr_components/TextExtraction.py b/private_gpt/components/ocr_components/TextExtraction.py index 1fb58e5f..ae6069d5 100644 --- a/private_gpt/components/ocr_components/TextExtraction.py +++ b/private_gpt/components/ocr_components/TextExtraction.py @@ -11,7 +11,7 @@ from transformers import AutoModelForObjectDetection from transformers import TableTransformerForObjectDetection from typing import Literal - +from injector import inject, singleton from private_gpt.components.ocr_components.table_ocr import GetOCRText device = "cuda" if torch.cuda.is_available() else "cpu" @@ -26,8 +26,9 @@ class MaxResize(object): resized_image = image.resize((int(round(scale*width)), int(round(scale*height)))) return resized_image - +@singleton class ImageToTable: + @inject def __init__(self, tokens:list=None, detection_class_thresholds:dict=None) -> None: self._table_model = "microsoft/table-transformer-detection" self._structure_model = "microsoft/table-structure-recognition-v1.1-all" diff --git a/private_gpt/components/ocr_components/table_ocr.py b/private_gpt/components/ocr_components/table_ocr.py index f8248dde..a14bc61e 100644 --- a/private_gpt/components/ocr_components/table_ocr.py +++ b/private_gpt/components/ocr_components/table_ocr.py @@ -3,10 +3,12 @@ import cv2 import torch from doctr.models import ocr_predictor from doctr.io import DocumentFile - +from injector import singleton, inject device = "cuda" if torch.cuda.is_available() else "cpu" +@singleton class GetOCRText: + @inject def __init__(self) -> None: self._image = None # self.ocr = PaddleOCR(use_angle_cls=True, lang='en') diff --git a/private_gpt/server/ingest/ingest_router.py b/private_gpt/server/ingest/ingest_router.py index 75fd965c..7472ac7b 100644 --- a/private_gpt/server/ingest/ingest_router.py +++ b/private_gpt/server/ingest/ingest_router.py @@ -177,75 +177,6 @@ def delete_file( status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail="Internal Server Error") -# @ingest_router.post("/ingest/file", response_model=IngestResponse, tags=["Ingestion"]) -# def ingest_file( -# request: Request, -# log_audit: models.Audit = Depends(deps.get_audit_logger), - -# db: Session = Depends(deps.get_db), -# file: UploadFile = File(...), -# current_user: models.User = Security( -# deps.get_current_user, -# scopes=[Role.ADMIN["name"], Role.SUPER_ADMIN["name"]], -# )) -> IngestResponse: -# """Ingests and processes a file, storing its chunks to be used as context.""" -# service = request.state.injector.get(IngestService) -# print("-------------------------------------->",file) -# try: -# file_ingested = crud.documents.get_by_filename(db, file_name=file.filename) -# if file_ingested: -# raise HTTPException( -# status_code=status.HTTP_409_CONFLICT, -# detail="File already exists. Choose a different file.", -# ) - -# if file.filename is None: -# raise HTTPException( -# status_code=status.HTTP_400_BAD_REQUEST, -# detail="No file name provided", -# ) - -# # try: -# docs_in = schemas.DocumentCreate(filename=file.filename, uploaded_by=current_user.id, department_id=current_user.department_id) -# crud.documents.create(db=db, obj_in=docs_in) -# # except Exception as e: -# # raise HTTPException( -# # status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, -# # detail="Unable to upload file.", -# # ) -# upload_path = Path(f"{UPLOAD_DIR}/{file.filename}") - -# with open(upload_path, "wb") as f: -# f.write(file.file.read()) - -# with open(upload_path, "rb") as f: -# ingested_documents = service.ingest_bin_data(file.filename, f) -# logger.info(f"{file.filename} is uploaded by the {current_user.fullname}.") -# response = IngestResponse( -# object="list", model="private-gpt", data=ingested_documents) - -# log_audit(model='Document', action='create', -# details={ -# 'filename': f"{file.filename} uploaded successfully", -# 'user': current_user.fullname, -# }, user_id=current_user.id) - -# return response -# except HTTPException: -# print(traceback.print_exc()) -# raise - -# except Exception as e: -# print(traceback.print_exc()) -# log_audit(model='Document', action='create', -# details={"status": 500, "detail": "Internal Server Error: Unable to ingest file.", }, user_id=current_user.id) -# logger.error(f"There was an error uploading the file(s): {str(e)}") -# raise HTTPException( -# status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, -# detail="Internal Server Error: Unable to ingest file.", -# ) - - async def common_ingest_logic( request: Request, db: Session,