mirror of https://github.com/imartinez/privateGPT.git
Added OCRComponents
parent 7bba4d13eb
commit a29e0f4253
@@ -1,175 +1,53 @@
import io
import csv

import cv2
import numpy as np
import torch
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms
from typing import Literal

from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection
from doctr.io import DocumentFile
from doctr.models import ocr_predictor

from TextExtraction import GetOCRText

# from paddleocr import PaddleOCR  # previous OCR backend, replaced by docTR

# Run the detection models on GPU when one is available.
device = "cuda" if torch.cuda.is_available() else "cpu"
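

# `iob` is called by objects_to_crops below but is not defined anywhere in this
# diff. The following is a sketch of the standard intersection-over-box-area
# helper used alongside the table-transformer models: the fraction of bbox1's
# area that is covered by bbox2.
def iob(bbox1, bbox2):
    x1 = max(bbox1[0], bbox2[0])
    y1 = max(bbox1[1], bbox2[1])
    x2 = min(bbox1[2], bbox2[2])
    y2 = min(bbox1[3], bbox2[3])
    if x2 <= x1 or y2 <= y1:
        return 0.0
    intersection = (x2 - x1) * (y2 - y1)
    bbox1_area = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
    return intersection / bbox1_area if bbox1_area > 0 else 0.0
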
class MaxResize(object):
    """Resize a PIL image so its longest side is at most max_size, preserving
    aspect ratio (e.g. MaxResize(800) maps a 1600x1200 image to 800x600)."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image

class ImageToTable:
    """Detect tables in an image, recognize their row/column structure, and OCR
    every cell into CSV-ready text."""

    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
        # Table-transformer checkpoints: one model finds tables on the page,
        # the other recognizes the structure inside a cropped table.
        self._table_model = "microsoft/table-transformer-detection"
        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
        self._image = None
        self._table_image = None
        self._file_path = None
        self.text_data = []
        self.tokens = tokens if tokens is not None else []
        self.detection_class_thresholds = detection_class_thresholds or {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10,
        }
        # OCR backends: the shared extractor for per-cell text, plus a
        # pretrained docTR predictor for whole images and files.
        self.get_ocr = GetOCRText()
        # self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
        self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)

    def _prepare_for_nn_input(self, image):
        # Resize, convert to tensor, and normalize with ImageNet statistics,
        # matching the preprocessing the table-transformer checkpoints expect.
        structure_transform = transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        pixel_values = structure_transform(image).unsqueeze(0)  # add batch dimension
        return pixel_values.to(device)

    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
        # Pick the detection model for finding tables on the page, or the
        # structure model for finding rows/columns inside a cropped table.
        if detection_type == "table":
            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
        elif detection_type == "table_structure":
            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
        else:
            raise ValueError(f"unknown detection_type: {detection_type}")
        model.to(device)
        pixel_values = self._prepare_for_nn_input(image)
        with torch.no_grad():
            outputs = model(pixel_values)
        # Extend a copy of the label map with an explicit "no object" entry for
        # the background index produced by the DETR-style head.
        id2label = dict(model.config.id2label)
        id2label[len(id2label)] = "no object"
        return outputs, id2label

    def _preprocess_image(self, img):
        # Upscale, grayscale, and binarize (inverted Otsu threshold) to make
        # small text easier for the OCR engine to read.
        resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)
        gray_image = cv2.cvtColor(resized_image, cv2.COLOR_BGR2GRAY)
        _, binary = cv2.threshold(gray_image, 128, 255, cv2.THRESH_BINARY_INV + cv2.THRESH_OTSU)
        return binary

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        # Crop every detected table (above its confidence threshold) out of the
        # page image, shifting any word tokens into crop-local coordinates.
        tables_crop = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue
            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)
            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]
            # Rotated tables are turned upright; token boxes rotate with them.
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox
            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens
            tables_crop.append(cropped_table)
        return tables_crop

    def outputs_to_objects(self, outputs, img_size, id2label):
        # Convert raw DETR outputs into labeled, scored boxes in absolute
        # (x_min, y_min, x_max, y_max) image coordinates.
        def box_cxcywh_to_xyxy(x):
            x_c, y_c, w, h = x.unbind(-1)
            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
            return torch.stack(b, dim=1)

        def rescale_bboxes(out_bbox, size):
            img_w, img_h = size
            b = box_cxcywh_to_xyxy(out_bbox)
            return b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)

        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label,
                                'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects

    def get_cell_coordinates_by_row(self, table_data):
        # Intersect every detected row with every detected column to get one
        # bounding box per cell, ordered top-to-bottom and left-to-right.
        rows = [entry for entry in table_data if entry['label'] == 'table row']
        columns = [entry for entry in table_data if entry['label'] == 'table column']
        rows.sort(key=lambda x: x['bbox'][1])
        columns.sort(key=lambda x: x['bbox'][0])

        def find_cell_coordinates(row, column):
            # A cell spans its column horizontally and its row vertically.
            return [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]

        cell_coordinates = []
        for row in rows:
            row_cells = []
            for column in columns:
                cell_bbox = find_cell_coordinates(row, column)
                row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
            row_cells.sort(key=lambda x: x['column'][0])
            cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
        cell_coordinates.sort(key=lambda x: x['row'][1])
        return cell_coordinates

    ## paddleOCR (previous backend, kept for reference)
    # def extract_text(self, cell_image):
    #     text = ""
    #     self._image = cell_image
    #     preprocessed_image = self._preprocess_image(self._image)
    #     results = self.ocr.ocr(preprocessed_image, cls=True)
    #     if len(results) > 0:
    #         for result in results[0]:
    #             text += f"{result[-1][0]} "
    #     else:
    #         text = ""
    #     return text

    def apply_ocr(self, cell_coordinates):
        # OCR each cell crop of the current table and collect the text row by row.
        for idx, row in enumerate(tqdm(cell_coordinates)):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(self._table_image.crop(cell["cell"]))
                result = self.get_ocr.extract_text(cell_image)
                row_text.append(result)
            self.text_data.append(row_text)

    def table_to_csv(self, image_path):
        # Full pipeline: detect tables on the page, crop each one, detect its
        # row/column structure, OCR every cell, and return the result as CSV text.
        self.text_data = []  # reset rows accumulated by any previous run
        self._image = Image.open(image_path).convert("RGB")
        outputs, id2label = self._detection(detection_type='table', image=self._image)
        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
        tables_crop = self.objects_to_crops(self._image, self.tokens, objects, self.detection_class_thresholds, padding=0)
        for table_crop in tables_crop:
            cropped_table = table_crop['image'].convert("RGB")
            self._table_image = cropped_table
            outputs, structure_id2label = self._detection(detection_type='table_structure', image=cropped_table)
            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
            cell_coordinates = self.get_cell_coordinates_by_row(cells)
            self.apply_ocr(cell_coordinates)
        if self.text_data:
            # csv.writer quotes cells that themselves contain commas.
            buffer = io.StringIO()
            csv.writer(buffer, lineterminator="\n").writerows(self.text_data)
            return buffer.getvalue()
        return ""

    ## docTR OCR
    def extract_text(self, cell_image=None, image_file=False, file_path=None):
        # OCR either an image file on disk (image_file=True, file_path set) or
        # an in-memory cell crop, then flatten docTR's page/block/line/word
        # export into plain text.
        text = ""
        if image_file:
            pdf_file = DocumentFile.from_images(file_path)
            result = self.doctr(pdf_file)
        else:
            self._image = cell_image
            result = self.doctr([self._image])
        output = result.export()
        for block in output['pages'][0]["blocks"]:
            for line in block["lines"]:
                for word in line["words"]:
                    text += (f"{word['value']} ").replace("\n", "")
                # End each recognized line with a newline.
                text = text + "\n"
        if text:
            return text
        # Return a single space so downstream cells are never empty strings.
        return " "
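

if __name__ == "__main__":
    # Minimal usage sketch, not part of the original commit: "sample_table.png"
    # is a hypothetical input image used purely for illustration.
    converter = ImageToTable()
    csv_text = converter.table_to_csv("sample_table.png")
    print(csv_text)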