Added OCRComponents

quick-karsth 2024-02-15 11:46:06 +05:45
parent 7bba4d13eb
commit a29e0f4253


@@ -1,175 +1,53 @@
import io
# from paddleocr import PaddleOCR
import cv2
import csv
import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection
from typing import Literal
from TextExtraction import GetOCRText
from doctr.models import ocr_predictor
from doctr.io import DocumentFile

device = "cuda" if torch.cuda.is_available() else "cpu"

class MaxResize(object):
    """Resize a PIL image so its longer side equals max_size, keeping the aspect ratio."""
    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image

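# NOTE: objects_to_crops() below calls an iob() helper that is neither defined
# nor imported in this file. A minimal sketch is given here, assuming bboxes are
# [x1, y1, x2, y2] pixel coordinates and the usual "intersection over the first
# box's area" convention; the original project may define it elsewhere.
def iob(bbox1, bbox2):
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    intersection = max(0, x2 - x1) * max(0, y2 - y1)
    area1 = max(0, bbox1[2] - bbox1[0]) * max(0, bbox1[3] - bbox1[1])
    return intersection / area1 if area1 > 0 else 0.0
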
class ImageToTable:
    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
        self._table_model = "microsoft/table-transformer-detection"
        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
        self._image = None
        self._table_image = None
        self._file_path = None
        self.text_data = []
        self.tokens = tokens if tokens is not None else []
        self.detection_class_thresholds = detection_class_thresholds if detection_class_thresholds is not None else {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10
        }
        # OCR backends
        self.get_ocr = GetOCRText()
        # self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
        self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)

    def _prepare_for_nn_input(self, image):
        # Resize, convert to tensor, and normalise with ImageNet statistics.
        structure_transform = transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
        ])
        pixel_values = structure_transform(image).unsqueeze(0)
        pixel_values = pixel_values.to(device)
        return pixel_values

    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
        if detection_type == "table":
            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
        elif detection_type == "table_structure":
            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
        else:
            raise ValueError(f"Unknown detection_type: {detection_type}")
        pixel_values = self._prepare_for_nn_input(image)
        pixel_values = pixel_values.to(device)
        model.to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"
        return outputs, id2label

    def _preprocess_image(self, img):
        # Upscale, convert to grayscale, and binarise (Otsu) before OCR.
        resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)
        gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray_image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        return binary

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        # Crop each detected table (plus padding) and keep the tokens that fall inside it.
        tables_crop = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue
            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)
            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]
            # Rotated tables are turned upright and their token boxes re-mapped.
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox
            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens
            tables_crop.append(cropped_table)
        return tables_crop

    def outputs_to_objects(self, outputs, img_size, id2label):
        def box_cxcywh_to_xyxy(x):
            # Convert (center_x, center_y, width, height) boxes to (x1, y1, x2, y2).
            x_c, y_c, w, h = x.unbind(-1)
            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
            return torch.stack(b, dim=1)

        def rescale_bboxes(out_bbox, size):
            img_w, img_h = size
            b = box_cxcywh_to_xyxy(out_bbox)
            b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
            return b

        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]
        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if not class_label == 'no object':
                objects.append({'label': class_label, 'score': float(score), 'bbox': [float(elem) for elem in bbox]})
        return objects

    def get_cell_coordinates_by_row(self, table_data):
        rows = [entry for entry in table_data if entry['label'] == 'table row']
        columns = [entry for entry in table_data if entry['label'] == 'table column']
        rows.sort(key=lambda x: x['bbox'][1])
        columns.sort(key=lambda x: x['bbox'][0])

        def find_cell_coordinates(row, column):
            # A cell is the intersection of a row box and a column box.
            cell_bbox = [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]
            return cell_bbox

        cell_coordinates = []
        for row in rows:
            row_cells = []
            for column in columns:
                cell_bbox = find_cell_coordinates(row, column)
                row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
            row_cells.sort(key=lambda x: x['column'][0])
            cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
        cell_coordinates.sort(key=lambda x: x['row'][1])
        return cell_coordinates

    ## PaddleOCR (kept for reference, currently disabled)
    # def extract_text(self, cell_image):
    #     text = ""
    #     self._image = cell_image
    #     preprocessed_image = self._preprocess_image(self._image)
    #     results = self.ocr.ocr(preprocessed_image, cls=True)
    #     print(results)
    #     if len(results) > 0:
    #         for result in results[0]:
    #             text += f"{result[-1][0]} "
    #     else:
    #         text = ""
    #     return text

    def apply_ocr(self, cell_coordinates):
        # Run OCR on every cell crop, row by row, and collect the text.
        for idx, row in enumerate(tqdm(cell_coordinates)):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(self._table_image.crop(cell["cell"]))
                result = self.get_ocr.extract_text(cell_image)
                row_text.append(result)
            self.text_data.append(row_text)

    def table_to_csv(self, image_path):
        # 1) Detect tables in the page image, 2) crop them, 3) detect their row/column
        # structure, 4) OCR each cell, and return the result as CSV-style text.
        self._image = Image.open(image_path).convert("RGB")
        outputs, id2label = self._detection(detection_type='table', image=self._image)
        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
        tables_crop = self.objects_to_crops(self._image, self.tokens, objects, self.detection_class_thresholds, padding=0)
        for table_crop in tables_crop:
            cropped_table = table_crop['image'].convert("RGB")
            self._table_image = cropped_table
            outputs, structure_id2label = self._detection(detection_type='table_structure', image=cropped_table)
            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
            cell_coordinates = self.get_cell_coordinates_by_row(cells)
            self.apply_ocr(cell_coordinates)
        if self.text_data:
            return "\n".join(",".join(row) for row in self.text_data)
        return ""

    ## docTR OCR
    def extract_text(self, cell_image=None, image_file=False, file_path=None):
        text = ""
        if image_file:
            # Whole-image input loaded from a file path.
            pdf_file = DocumentFile.from_images(file_path)
            result = self.doctr(pdf_file)
            output = result.export()
        else:
            # Single cell crop passed in as a numpy array.
            self._image = cell_image
            preprocessed_image = self._preprocess_image(self._image)
            # NOTE: the preprocessed image is computed but not passed to docTR here;
            # the raw crop is used instead, as in the original code.
            result = self.doctr([self._image])
            output = result.export()
        for obj1 in output['pages'][0]["blocks"]:
            for obj2 in obj1["lines"]:
                for obj3 in obj2["words"]:
                    text += (f"{obj3['value']} ").replace("\n", "")
                text = text + "\n"
        if text:
            return text
        return " "