Added OCRComponents

This commit is contained in:
quick-karsth 2024-02-15 11:33:37 +05:45
parent b9949204de
commit 7bba4d13eb
3 changed files with 405 additions and 0 deletions


@@ -0,0 +1,175 @@
import csv
import io
from typing import Literal

import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from transformers import AutoModelForObjectDetection, TableTransformerForObjectDetection

from TextExtraction import GetOCRText
device = "cuda" if torch.cuda.is_available() else "cpu"
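# Transform that shrinks an image so its longer side is at most `max_size`
# pixels, preserving the aspect ratio, before tensor conversion.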
class MaxResize(object):
def __init__(self, max_size=800):
self.max_size = max_size
def __call__(self, image):
width, height = image.size
current_max_size = max(width, height)
scale = self.max_size / current_max_size
resized_image = image.resize((int(round(scale*width)), int(round(scale*height))))
return resized_image
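
# `objects_to_crops` below calls an `iob` helper that this file never defines
# or imports. A minimal sketch is assumed here (intersection area divided by
# the area of the first box, i.e. "intersection over box", as used in the
# Table Transformer examples):
def iob(bbox1, bbox2):
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    return intersection / area1 if area1 > 0 else 0.0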
class ImageToTable:
    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
        self._table_model = "microsoft/table-transformer-detection"
        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
        self._image = None
        self._table_image = None
        self._file_path = None
        self.text_data = []
        # Honor caller-supplied tokens/thresholds; fall back to defaults otherwise.
        self.tokens = tokens if tokens is not None else []
        self.detection_class_thresholds = detection_class_thresholds or {
            "table": 0.5,
            "table rotated": 0.5,
            # "no object" detections are never kept, hence the unreachable threshold.
            "no object": 10,
        }
        # OCR engine used to read the text inside each table cell.
        self.get_ocr = GetOCRText()
    def _prepare_for_nn_input(self, image):
        # Resize so the longest side is 1000 px, then normalize with ImageNet statistics.
        structure_transform = transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        pixel_values = structure_transform(image).unsqueeze(0)  # add batch dimension
        return pixel_values.to(device)
    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
        # Note: the checkpoint is reloaded from the local Hub cache on every call;
        # caching the instantiated models would avoid repeated initialization cost.
        if detection_type == "table":
            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
        elif detection_type == "table_structure":
            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
        else:
            raise ValueError(f"Unknown detection_type: {detection_type!r}")
        pixel_values = self._prepare_for_nn_input(image)
        model.to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        # Extend the label map with an explicit "no object" class for background queries.
        id2label = model.config.id2label
        id2label[len(id2label)] = "no object"
        return outputs, id2label
def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
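        # Keep detections that clear their class score threshold, crop each table
        # region with padding, and translate overlapping word tokens (IoB >= 0.5)
        # into the crop's coordinate frame.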
tables_crop = []
for obj in objects:
if obj['score'] < class_thresholds[obj['label']]:
continue
cropped_table = {}
bbox = obj['bbox']
bbox = [bbox[0]-padding, bbox[1]-padding, bbox[2]+padding, bbox[3]+padding]
cropped_img = img.crop(bbox)
table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
for token in table_tokens:
token['bbox'] = [token['bbox'][0]-bbox[0],
token['bbox'][1]-bbox[1],
token['bbox'][2]-bbox[0],
token['bbox'][3]-bbox[1]]
if obj['label'] == 'table rotated':
cropped_img = cropped_img.rotate(270, expand=True)
for token in table_tokens:
bbox = token['bbox']
bbox = [cropped_img.size[0]-bbox[3]-1,
bbox[0],
cropped_img.size[0]-bbox[1]-1,
bbox[2]]
token['bbox'] = bbox
cropped_table['image'] = cropped_img
cropped_table['tokens'] = table_tokens
tables_crop.append(cropped_table)
return tables_crop
def outputs_to_objects(self, outputs, img_size, id2label):
def box_cxcywh_to_xyxy(x):
x_c, y_c, w, h = x.unbind(-1)
b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
return torch.stack(b, dim=1)
def rescale_bboxes(out_bbox, size):
img_w, img_h = size
b = box_cxcywh_to_xyxy(out_bbox)
b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
return b
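        # Pick the best class per query; predicted boxes are normalized
        # (cx, cy, w, h), so rescale them to absolute (x1, y1, x2, y2) pixels.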
m = outputs.logits.softmax(-1).max(-1)
pred_labels = list(m.indices.detach().cpu().numpy())[0]
pred_scores = list(m.values.detach().cpu().numpy())[0]
pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label,
                                'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects
def get_cell_coordinates_by_row(self, table_data):
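        # A cell box is the intersection of a detected "table row" and
        # "table column": x-extent from the column, y-extent from the row.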
rows = [entry for entry in table_data if entry['label'] == 'table row']
columns = [entry for entry in table_data if entry['label'] == 'table column']
rows.sort(key=lambda x: x['bbox'][1])
columns.sort(key=lambda x: x['bbox'][0])
def find_cell_coordinates(row, column):
cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
return cell_bbox
cell_coordinates = []
for row in rows:
row_cells = []
for column in columns:
cell_bbox = find_cell_coordinates(row, column)
row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
row_cells.sort(key=lambda x: x['column'][0])
cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
cell_coordinates.sort(key=lambda x: x['row'][1])
return cell_coordinates
    def apply_ocr(self, cell_coordinates):
        # OCR every cell crop, row by row, accumulating rows of text.
        for row in tqdm(cell_coordinates):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(self._table_image.crop(cell["cell"]))
                row_text.append(self.get_ocr.extract_text(cell_image))
            self.text_data.append(row_text)
    def table_to_csv(self, image_path):
        self._image = Image.open(image_path).convert("RGB")
        # Stage 1: find table regions on the page.
        outputs, id2label = self._detection(detection_type='table', image=self._image)
        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
        tables_crop = self.objects_to_crops(self._image, self.tokens, objects,
                                            self.detection_class_thresholds, padding=0)
        # Reset so repeated calls do not accumulate text from earlier images.
        self.text_data = []
        for table_crop in tables_crop:
            cropped_table = table_crop['image'].convert("RGB")
            self._table_image = cropped_table
            # Stage 2: recognize the row/column structure inside each cropped table.
            outputs, structure_id2label = self._detection(detection_type='table_structure', image=cropped_table)
            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
            cell_coordinates = self.get_cell_coordinates_by_row(cells)
            self.apply_ocr(cell_coordinates)
        if self.text_data:
            # csv.writer handles quoting of cells that themselves contain commas.
            buffer = io.StringIO()
            csv.writer(buffer).writerows(self.text_data)
            return buffer.getvalue()
        return ""


@@ -0,0 +1,55 @@
import os

import fitz  # PyMuPDF
from docx import Document
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse

from table_ocr import ImageToTable
from TextExtraction import GetOCRText

app = FastAPI()
@app.post("/pdf_ocr")
async def get_pdf_ocr(file: UploadFile = File(...)):
UPLOAD_DIR = os.getcwd()
try:
contents = await file.read()
except Exception:
return {"message": "There was an error uploading the file"}
# Save the uploaded file to the dir
file_path = os.path.join(UPLOAD_DIR, file.filename)
with open(file_path, "wb") as f:
f.write(contents)
doc = Document()
ocr = GetOCRText()
img_tab = ImageToTable()
pdf_doc = fitz.open(file_path)
    for page_index in range(len(pdf_doc)):  # iterate over pdf pages
        page = pdf_doc[page_index]  # get the page
        image_list = page.get_images()
        for image_index, img in enumerate(image_list, start=1):  # enumerate the images on the page
            xref = img[0]
            pix = fitz.Pixmap(pdf_doc, xref)
            if pix.n - pix.alpha > 3:  # CMYK or similar: convert to RGB first
                pix = fitz.Pixmap(fitz.csRGB, pix)
            image_path = "page_%s-image_%s.png" % (page_index, image_index)
            pix.save(image_path)  # save the image as png
            pix = None  # release the pixmap
            # Plain-text OCR of the whole image, then table extraction as CSV.
            extracted_text = ocr.extract_text(image_file=True, file_path=image_path)
            doc.add_paragraph(extracted_text)
            table_data = img_tab.table_to_csv(image_path)
            doc.add_paragraph(table_data)
            # remove the temporary image file
            os.remove(image_path)
    doc.save(os.path.join(UPLOAD_DIR, "ocr_result.docx"))
    return FileResponse(
        path=os.path.join(UPLOAD_DIR, "ocr_result.docx"),
        filename="ocr_result.docx",
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )
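
# A hypothetical invocation of the endpoint above, assuming the app is served
# with a default `uvicorn main:app` run (module name, host, and port are
# assumptions, as is the sample file name):
#   curl -X POST -F "file=@scanned.pdf" http://127.0.0.1:8000/pdf_ocr -o ocr_result.docx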