From 7bba4d13ebc29f21c9dc47e3ff7692ec24c4fc97 Mon Sep 17 00:00:00 2001
From: quick-karsth
Date: Thu, 15 Feb 2024 11:33:37 +0545
Subject: [PATCH] Added OCRComponents

---
 .../ocr_components/TextExtraction.py          |  22 ++
 .../components/ocr_components/table_ocr.py    | 178 ++++++++++++++++++
 .../ocr_components/table_ocr_api.py           |  59 +++++
 3 files changed, 259 insertions(+)
 create mode 100644 private_gpt/components/ocr_components/TextExtraction.py
 create mode 100644 private_gpt/components/ocr_components/table_ocr.py
 create mode 100644 private_gpt/components/ocr_components/table_ocr_api.py

diff --git a/private_gpt/components/ocr_components/TextExtraction.py b/private_gpt/components/ocr_components/TextExtraction.py
new file mode 100644
index 00000000..8d4d04d7
--- /dev/null
+++ b/private_gpt/components/ocr_components/TextExtraction.py
@@ -0,0 +1,22 @@
+# TextExtraction: thin OCR wrapper used by table_ocr.py and table_ocr_api.py.
+# NOTE: this is a minimal sketch; pytesseract (backed by a system Tesseract
+# install) is an assumption here -- callers only depend on the extract_text()
+# signature below, so any OCR engine can be swapped in.
+import numpy as np
+import pytesseract
+from PIL import Image
+
+
+class GetOCRText:
+    """Extract plain text from an image passed in memory or by file path."""
+
+    def extract_text(self, image=None, image_file: bool = False, file_path: str = None) -> str:
+        # Callers either pass a numpy array / PIL image directly, or set
+        # image_file=True together with a path to an image on disk.
+        if image_file:
+            if file_path is None:
+                raise ValueError("file_path is required when image_file=True")
+            image = Image.open(file_path)
+        elif isinstance(image, np.ndarray):
+            image = Image.fromarray(image)
+        return pytesseract.image_to_string(image).strip()
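
Note: GetOCRText above is deliberately engine-agnostic; pytesseract is only an assumption of this sketch. A quick way to exercise it on its own, assuming a local test image named sample.png:

    import numpy as np
    from PIL import Image
    from TextExtraction import GetOCRText

    ocr = GetOCRText()
    # OCR straight from a file on disk...
    print(ocr.extract_text(image_file=True, file_path="sample.png"))
    # ...or from an in-memory array, as table_ocr.py does per table cell
    print(ocr.extract_text(np.array(Image.open("sample.png"))))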
diff --git a/private_gpt/components/ocr_components/table_ocr.py b/private_gpt/components/ocr_components/table_ocr.py
new file mode 100644
index 00000000..98ee04c1
--- /dev/null
+++ b/private_gpt/components/ocr_components/table_ocr.py
@@ -0,0 +1,178 @@
+import torch
+import numpy as np
+from PIL import Image
+from tqdm.auto import tqdm
+from torchvision import transforms
+
+from transformers import AutoModelForObjectDetection
+from transformers import TableTransformerForObjectDetection
+
+from typing import Literal
+
+from TextExtraction import GetOCRText
+
+device = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+def iob(bbox1, bbox2):
+    # Intersection over box: the fraction of bbox1's area that lies inside bbox2.
+    x1 = max(bbox1[0], bbox2[0])
+    y1 = max(bbox1[1], bbox2[1])
+    x2 = min(bbox1[2], bbox2[2])
+    y2 = min(bbox1[3], bbox2[3])
+    intersection = max(0, x2 - x1) * max(0, y2 - y1)
+    area = max(0, bbox1[2] - bbox1[0]) * max(0, bbox1[3] - bbox1[1])
+    return intersection / area if area > 0 else 0.0
+
+
+class MaxResize(object):
+    """Resize so the longer side is at most max_size, preserving aspect ratio."""
+    def __init__(self, max_size=800):
+        self.max_size = max_size
+
+    def __call__(self, image):
+        width, height = image.size
+        scale = self.max_size / max(width, height)
+        return image.resize((int(round(scale * width)), int(round(scale * height))))
+
+
+class ImageToTable:
+    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
+        self._table_model = "microsoft/table-transformer-detection"
+        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
+        self._image = None
+        self._table_image = None
+        self.text_data = []
+        self.tokens = tokens or []
+        self.detection_class_thresholds = detection_class_thresholds or {
+            "table": 0.5,
+            "table rotated": 0.5,
+            "no object": 10,  # a threshold above 1.0 means "no object" boxes are never cropped
+        }
+        # OCR backend for reading cell contents
+        self.get_ocr = GetOCRText()
+
+    def _prepare_for_nn_input(self, image):
+        structure_transform = transforms.Compose([
+            MaxResize(1000),
+            transforms.ToTensor(),
+            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
+        ])
+        return structure_transform(image).unsqueeze(0).to(device)
+
+    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
+        if detection_type == "table":
+            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
+        else:
+            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
+        model.to(device)
+        pixel_values = self._prepare_for_nn_input(image)
+        with torch.no_grad():
+            outputs = model(pixel_values)
+        # Work on a copy of id2label and append the extra "no object" class.
+        id2label = dict(model.config.id2label)
+        id2label[len(id2label)] = "no object"
+        return outputs, id2label
+
+    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
+        tables_crop = []
+        for obj in objects:
+            if obj['score'] < class_thresholds[obj['label']]:
+                continue
+            cropped_table = {}
+            bbox = obj['bbox']
+            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
+            cropped_img = img.crop(bbox)
+            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
+            for token in table_tokens:
+                token['bbox'] = [token['bbox'][0] - bbox[0],
+                                 token['bbox'][1] - bbox[1],
+                                 token['bbox'][2] - bbox[0],
+                                 token['bbox'][3] - bbox[1]]
+            # Rotate a "table rotated" crop upright and remap token boxes to match.
+            if obj['label'] == 'table rotated':
+                cropped_img = cropped_img.rotate(270, expand=True)
+                for token in table_tokens:
+                    bbox = token['bbox']
+                    token['bbox'] = [cropped_img.size[0] - bbox[3] - 1,
+                                     bbox[0],
+                                     cropped_img.size[0] - bbox[1] - 1,
+                                     bbox[2]]
+            cropped_table['image'] = cropped_img
+            cropped_table['tokens'] = table_tokens
+            tables_crop.append(cropped_table)
+        return tables_crop
+
+    def outputs_to_objects(self, outputs, img_size, id2label):
+        def box_cxcywh_to_xyxy(x):
+            x_c, y_c, w, h = x.unbind(-1)
+            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
+            return torch.stack(b, dim=1)
+
+        def rescale_bboxes(out_bbox, size):
+            img_w, img_h = size
+            b = box_cxcywh_to_xyxy(out_bbox)
+            return b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
+
+        m = outputs.logits.softmax(-1).max(-1)
+        pred_labels = list(m.indices.detach().cpu().numpy())[0]
+        pred_scores = list(m.values.detach().cpu().numpy())[0]
+        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
+        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
+        objects = []
+        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
+            class_label = id2label[int(label)]
+            if class_label != 'no object':
+                objects.append({'label': class_label,
+                                'score': float(score),
+                                'bbox': [float(elem) for elem in bbox]})
+        return objects
+
+    def get_cell_coordinates_by_row(self, table_data):
+        rows = [entry for entry in table_data if entry['label'] == 'table row']
+        columns = [entry for entry in table_data if entry['label'] == 'table column']
+        rows.sort(key=lambda x: x['bbox'][1])
+        columns.sort(key=lambda x: x['bbox'][0])
+
+        def find_cell_coordinates(row, column):
+            # A cell is the intersection of a column's x-range and a row's y-range.
+            return [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
+
+        cell_coordinates = []
+        for row in rows:
+            row_cells = []
+            for column in columns:
+                row_cells.append({'column': column['bbox'],
+                                  'cell': find_cell_coordinates(row, column)})
+            row_cells.sort(key=lambda x: x['column'][0])
+            cell_coordinates.append({'row': row['bbox'],
+                                     'cells': row_cells,
+                                     'cell_count': len(row_cells)})
+        cell_coordinates.sort(key=lambda x: x['row'][1])
+        return cell_coordinates
+
+    def apply_ocr(self, cell_coordinates):
+        for row in tqdm(cell_coordinates):
+            row_text = []
+            for cell in row["cells"]:
+                cell_image = np.array(self._table_image.crop(cell["cell"]))
+                row_text.append(self.get_ocr.extract_text(cell_image))
+            self.text_data.append(row_text)
+
+    def table_to_csv(self, image_path):
+        self.text_data = []
+        self._image = Image.open(image_path).convert("RGB")
+        outputs, id2label = self._detection(detection_type='table', image=self._image)
+        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
+        tables_crop = self.objects_to_crops(self._image, self.tokens, objects,
+                                            self.detection_class_thresholds, padding=0)
+        for table_crop in tables_crop:
+            cropped_table = table_crop['image'].convert("RGB")
+            self._table_image = cropped_table
+            outputs, structure_id2label = self._detection(detection_type='table_structure', image=cropped_table)
+            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
+            cell_coordinates = self.get_cell_coordinates_by_row(cells)
+            self.apply_ocr(cell_coordinates)
+        if self.text_data:
+            return "\n".join(",".join(row) for row in self.text_data)
+        return ""
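
For reference, a minimal driver for the table pipeline, assuming table.png holds a scanned table; both Table Transformer checkpoints are fetched from the Hugging Face Hub on first use:

    from table_ocr import ImageToTable

    img_tab = ImageToTable()
    csv_text = img_tab.table_to_csv("table.png")  # "" when no table is detected
    with open("table.csv", "w") as f:
        f.write(csv_text)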
diff --git a/private_gpt/components/ocr_components/table_ocr_api.py b/private_gpt/components/ocr_components/table_ocr_api.py
new file mode 100644
index 00000000..278899ef
--- /dev/null
+++ b/private_gpt/components/ocr_components/table_ocr_api.py
@@ -0,0 +1,59 @@
+from fastapi import FastAPI, File, UploadFile
+from fastapi.responses import FileResponse
+import os
+
+import fitz  # PyMuPDF
+from docx import Document
+
+from table_ocr import ImageToTable
+from TextExtraction import GetOCRText
+
+app = FastAPI()
+
+
+@app.post("/pdf_ocr")
+async def get_pdf_ocr(file: UploadFile = File(...)):
+    UPLOAD_DIR = os.getcwd()
+    try:
+        contents = await file.read()
+    except Exception:
+        return {"message": "There was an error uploading the file"}
+
+    # Save the uploaded file so PyMuPDF can open it
+    file_path = os.path.join(UPLOAD_DIR, file.filename)
+    with open(file_path, "wb") as f:
+        f.write(contents)
+
+    doc = Document()
+    ocr = GetOCRText()
+    img_tab = ImageToTable()
+    pdf_doc = fitz.open(file_path)
+    for page_index in range(len(pdf_doc)):  # iterate over pdf pages
+        page = pdf_doc[page_index]
+        image_list = page.get_images()
+
+        for image_index, img in enumerate(image_list, start=1):
+            xref = img[0]
+            pix = fitz.Pixmap(pdf_doc, xref)
+
+            # Convert CMYK (or other >3 channel) pixmaps to RGB before saving
+            if pix.n - pix.alpha > 3:
+                pix = fitz.Pixmap(fitz.csRGB, pix)
+            image_path = "page_%s-image_%s.png" % (page_index, image_index)
+            pix.save(image_path)  # save the image as png
+            pix = None
+
+            extracted_text = ocr.extract_text(image_file=True, file_path=image_path)
+            doc.add_paragraph(extracted_text)
+            table_data = img_tab.table_to_csv(image_path)
+            doc.add_paragraph(table_data)
+            # remove the temporary image file
+            os.remove(image_path)
+
+    doc.save(os.path.join(UPLOAD_DIR, "ocr_result.docx"))
+
+    return FileResponse(
+        path=os.path.join(UPLOAD_DIR, "ocr_result.docx"),
+        filename="ocr_result.docx",
+        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
+    )
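
To try the endpoint end to end, serve the app (e.g. uvicorn table_ocr_api:app --port 8000) and post a PDF; the host and port here are assumptions for a local run:

    import requests

    with open("sample.pdf", "rb") as f:
        resp = requests.post(
            "http://127.0.0.1:8000/pdf_ocr",
            files={"file": ("sample.pdf", f, "application/pdf")},
        )
    resp.raise_for_status()
    with open("ocr_result.docx", "wb") as out:
        out.write(resp.content)  # the generated Word document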