mirror of https://github.com/imartinez/privateGPT.git (synced 2025-06-29 08:47:19 +00:00)
Added OCRComponents
This commit is contained in: parent b9949204de, commit 7bba4d13eb
private_gpt/components/ocr_components/TextExtraction.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms

from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection

from typing import Literal

from TextExtraction import GetOCRText

device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize:
    """Resize a PIL image so its longer side is at most max_size, preserving aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image


class ImageToTable:
    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
        self._table_model = "microsoft/table-transformer-detection"
        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
        self._image = None
        self._table_image = None
        self._file_path = None
        self.text_data = []
        # Honour caller-supplied tokens/thresholds instead of silently ignoring them.
        self.tokens = tokens if tokens is not None else []
        self.detection_class_thresholds = detection_class_thresholds or {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10,
        }
        # OCR helper
        self.get_ocr = GetOCRText()

    def _prepare_for_nn_input(self, image):
        # Resize, convert to tensor and normalise with ImageNet statistics.
        structure_transform = transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        pixel_values = structure_transform(image).unsqueeze(0)  # add batch dimension
        return pixel_values.to(device)

    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
        if detection_type == "table":
            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
        elif detection_type == "table_structure":
            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
        else:
            raise ValueError(f"unknown detection_type: {detection_type!r}")
        pixel_values = self._prepare_for_nn_input(image)  # already on `device`
        model.to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        # Append an extra "no object" class id, used by outputs_to_objects below.
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"
        return outputs, id2label

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        tables_crop = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue
            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)
            # iob (intersection over box) is assumed to be defined elsewhere in the
            # codebase; with the default empty tokens list this filter is never reached.
            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]
            # Rotate 270° counter-clockwise (90° clockwise) so rotated tables are
            # upright, remapping token boxes into the rotated frame.
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox
            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens
            tables_crop.append(cropped_table)
        return tables_crop

    def outputs_to_objects(self, outputs, img_size, id2label):
        def box_cxcywh_to_xyxy(x):
            # (center_x, center_y, width, height) -> (x_min, y_min, x_max, y_max)
            x_c, y_c, w, h = x.unbind(-1)
            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
            return torch.stack(b, dim=1)

        def rescale_bboxes(out_bbox, size):
            img_w, img_h = size
            b = box_cxcywh_to_xyxy(out_bbox)
            b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
            return b

        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label,
                                'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects

    def get_cell_coordinates_by_row(self, table_data):
        rows = [entry for entry in table_data if entry['label'] == 'table row']
        columns = [entry for entry in table_data if entry['label'] == 'table column']
        rows.sort(key=lambda x: x['bbox'][1])
        columns.sort(key=lambda x: x['bbox'][0])

        def find_cell_coordinates(row, column):
            # A cell is the intersection of a row box and a column box.
            return [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]

        cell_coordinates = []
        for row in rows:
            row_cells = []
            for column in columns:
                cell_bbox = find_cell_coordinates(row, column)
                row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
            row_cells.sort(key=lambda x: x['column'][0])
            cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
        cell_coordinates.sort(key=lambda x: x['row'][1])
        return cell_coordinates

    def apply_ocr(self, cell_coordinates):
        for row in tqdm(cell_coordinates):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(self._table_image.crop(cell["cell"]))
                result = self.get_ocr.extract_text(cell_image)
                row_text.append(result)
            self.text_data.append(row_text)

    def table_to_csv(self, image_path):
        self._image = Image.open(image_path).convert("RGB")
        # Stage 1: find table regions in the full image.
        outputs, id2label = self._detection(detection_type='table', image=self._image)
        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
        tables_crop = self.objects_to_crops(self._image, self.tokens, objects,
                                            self.detection_class_thresholds, padding=0)
        # Stage 2: recognise rows/columns inside each cropped table and OCR every cell.
        for table_crop in tables_crop:
            cropped_table = table_crop['image'].convert("RGB")
            self._table_image = cropped_table
            outputs, structure_id2label = self._detection(detection_type='table_structure',
                                                          image=cropped_table)
            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
            cell_coordinates = self.get_cell_coordinates_by_row(cells)
            self.apply_ocr(cell_coordinates)
        if self.text_data:
            return "\n".join(",".join(row) for row in self.text_data)
        return ""
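A minimal driver sketch for the class above, assuming an image file table.png that contains a table (the import path follows how table_ocr_api.py below imports the class):

    from table_ocr import ImageToTable

    img_tab = ImageToTable()
    csv_text = img_tab.table_to_csv("table.png")  # one comma-joined line per detected row
    with open("table.csv", "w") as f:
        f.write(csv_text)

Note that text_data accumulates across calls, so use a fresh ImageToTable instance per image when converting several files.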
private_gpt/components/ocr_components/table_ocr.py (new file, 175 lines)
@@ -0,0 +1,175 @@
import torch
import numpy as np
from PIL import Image
from tqdm.auto import tqdm
from torchvision import transforms

from transformers import AutoModelForObjectDetection
from transformers import TableTransformerForObjectDetection

from typing import Literal

from TextExtraction import GetOCRText

device = "cuda" if torch.cuda.is_available() else "cpu"


class MaxResize:
    """Resize a PIL image so its longer side is at most max_size, preserving aspect ratio."""

    def __init__(self, max_size=800):
        self.max_size = max_size

    def __call__(self, image):
        width, height = image.size
        current_max_size = max(width, height)
        scale = self.max_size / current_max_size
        resized_image = image.resize((int(round(scale * width)), int(round(scale * height))))
        return resized_image


class ImageToTable:
    def __init__(self, tokens: list = None, detection_class_thresholds: dict = None) -> None:
        self._table_model = "microsoft/table-transformer-detection"
        self._structure_model = "microsoft/table-structure-recognition-v1.1-all"
        self._image = None
        self._table_image = None
        self._file_path = None
        self.text_data = []
        # Honour caller-supplied tokens/thresholds instead of silently ignoring them.
        self.tokens = tokens if tokens is not None else []
        self.detection_class_thresholds = detection_class_thresholds or {
            "table": 0.5,
            "table rotated": 0.5,
            "no object": 10,
        }
        # OCR helper
        self.get_ocr = GetOCRText()

    def _prepare_for_nn_input(self, image):
        # Resize, convert to tensor and normalise with ImageNet statistics.
        structure_transform = transforms.Compose([
            MaxResize(1000),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])
        pixel_values = structure_transform(image).unsqueeze(0)  # add batch dimension
        return pixel_values.to(device)

    def _detection(self, detection_type: Literal['table', 'table_structure'], image):
        if detection_type == "table":
            model = AutoModelForObjectDetection.from_pretrained(self._table_model)
        elif detection_type == "table_structure":
            model = TableTransformerForObjectDetection.from_pretrained(self._structure_model)
        else:
            raise ValueError(f"unknown detection_type: {detection_type!r}")
        pixel_values = self._prepare_for_nn_input(image)  # already on `device`
        model.to(device)
        with torch.no_grad():
            outputs = model(pixel_values)
        # Append an extra "no object" class id, used by outputs_to_objects below.
        id2label = model.config.id2label
        id2label[len(model.config.id2label)] = "no object"
        return outputs, id2label

    def objects_to_crops(self, img, tokens, objects, class_thresholds, padding=10):
        tables_crop = []
        for obj in objects:
            if obj['score'] < class_thresholds[obj['label']]:
                continue
            cropped_table = {}
            bbox = obj['bbox']
            bbox = [bbox[0] - padding, bbox[1] - padding, bbox[2] + padding, bbox[3] + padding]
            cropped_img = img.crop(bbox)
            # iob (intersection over box) is assumed to be defined elsewhere in the
            # codebase; with the default empty tokens list this filter is never reached.
            table_tokens = [token for token in tokens if iob(token['bbox'], bbox) >= 0.5]
            for token in table_tokens:
                token['bbox'] = [token['bbox'][0] - bbox[0],
                                 token['bbox'][1] - bbox[1],
                                 token['bbox'][2] - bbox[0],
                                 token['bbox'][3] - bbox[1]]
            # Rotate 270° counter-clockwise (90° clockwise) so rotated tables are
            # upright, remapping token boxes into the rotated frame.
            if obj['label'] == 'table rotated':
                cropped_img = cropped_img.rotate(270, expand=True)
                for token in table_tokens:
                    bbox = token['bbox']
                    bbox = [cropped_img.size[0] - bbox[3] - 1,
                            bbox[0],
                            cropped_img.size[0] - bbox[1] - 1,
                            bbox[2]]
                    token['bbox'] = bbox
            cropped_table['image'] = cropped_img
            cropped_table['tokens'] = table_tokens
            tables_crop.append(cropped_table)
        return tables_crop

    def outputs_to_objects(self, outputs, img_size, id2label):
        def box_cxcywh_to_xyxy(x):
            # (center_x, center_y, width, height) -> (x_min, y_min, x_max, y_max)
            x_c, y_c, w, h = x.unbind(-1)
            b = [(x_c - 0.5 * w), (y_c - 0.5 * h), (x_c + 0.5 * w), (y_c + 0.5 * h)]
            return torch.stack(b, dim=1)

        def rescale_bboxes(out_bbox, size):
            img_w, img_h = size
            b = box_cxcywh_to_xyxy(out_bbox)
            b = b * torch.tensor([img_w, img_h, img_w, img_h], dtype=torch.float32)
            return b

        m = outputs.logits.softmax(-1).max(-1)
        pred_labels = list(m.indices.detach().cpu().numpy())[0]
        pred_scores = list(m.values.detach().cpu().numpy())[0]
        pred_bboxes = outputs['pred_boxes'].detach().cpu()[0]
        pred_bboxes = [elem.tolist() for elem in rescale_bboxes(pred_bboxes, img_size)]

        objects = []
        for label, score, bbox in zip(pred_labels, pred_scores, pred_bboxes):
            class_label = id2label[int(label)]
            if class_label != 'no object':
                objects.append({'label': class_label,
                                'score': float(score),
                                'bbox': [float(elem) for elem in bbox]})
        return objects

    def get_cell_coordinates_by_row(self, table_data):
        rows = [entry for entry in table_data if entry['label'] == 'table row']
        columns = [entry for entry in table_data if entry['label'] == 'table column']
        rows.sort(key=lambda x: x['bbox'][1])
        columns.sort(key=lambda x: x['bbox'][0])

        def find_cell_coordinates(row, column):
            # A cell is the intersection of a row box and a column box.
            return [column['bbox'][0], row['bbox'][1], column['bbox'][2], row['bbox'][3]]

        cell_coordinates = []
        for row in rows:
            row_cells = []
            for column in columns:
                cell_bbox = find_cell_coordinates(row, column)
                row_cells.append({'column': column['bbox'], 'cell': cell_bbox})
            row_cells.sort(key=lambda x: x['column'][0])
            cell_coordinates.append({'row': row['bbox'], 'cells': row_cells, 'cell_count': len(row_cells)})
        cell_coordinates.sort(key=lambda x: x['row'][1])
        return cell_coordinates

    def apply_ocr(self, cell_coordinates):
        for row in tqdm(cell_coordinates):
            row_text = []
            for cell in row["cells"]:
                cell_image = np.array(self._table_image.crop(cell["cell"]))
                result = self.get_ocr.extract_text(cell_image)
                row_text.append(result)
            self.text_data.append(row_text)

    def table_to_csv(self, image_path):
        self._image = Image.open(image_path).convert("RGB")
        # Stage 1: find table regions in the full image.
        outputs, id2label = self._detection(detection_type='table', image=self._image)
        objects = self.outputs_to_objects(outputs=outputs, img_size=self._image.size, id2label=id2label)
        tables_crop = self.objects_to_crops(self._image, self.tokens, objects,
                                            self.detection_class_thresholds, padding=0)
        # Stage 2: recognise rows/columns inside each cropped table and OCR every cell.
        for table_crop in tables_crop:
            cropped_table = table_crop['image'].convert("RGB")
            self._table_image = cropped_table
            outputs, structure_id2label = self._detection(detection_type='table_structure',
                                                          image=cropped_table)
            cells = self.outputs_to_objects(outputs, cropped_table.size, structure_id2label)
            cell_coordinates = self.get_cell_coordinates_by_row(cells)
            self.apply_ocr(cell_coordinates)
        if self.text_data:
            return "\n".join(",".join(row) for row in self.text_data)
        return ""
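Both copies of objects_to_crops call an iob helper that this commit neither defines nor imports. Below is a sketch of the usual intersection-over-box overlap measure; the name and exact semantics are assumptions inferred from the call site, which keeps tokens overlapping a table box by at least 0.5:

    def iob(bbox1, bbox2):
        # Area of the intersection of bbox1 and bbox2, divided by the area of bbox1.
        x1 = max(bbox1[0], bbox2[0])
        y1 = max(bbox1[1], bbox2[1])
        x2 = min(bbox1[2], bbox2[2])
        y2 = min(bbox1[3], bbox2[3])
        intersection = max(0, x2 - x1) * max(0, y2 - y1)
        area1 = (bbox1[2] - bbox1[0]) * (bbox1[3] - bbox1[1])
        return intersection / area1 if area1 > 0 else 0.0

With the default empty tokens list the helper is never reached, which is presumably why the missing definition went unnoticed.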
private_gpt/components/ocr_components/table_ocr_api.py (new file, 55 lines)
@@ -0,0 +1,55 @@
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse
from docx import Document
import os
import fitz  # PyMuPDF

from table_ocr import ImageToTable
from TextExtraction import GetOCRText

app = FastAPI()


@app.post("/pdf_ocr")
async def get_pdf_ocr(file: UploadFile = File(...)):
    UPLOAD_DIR = os.getcwd()
    try:
        contents = await file.read()
    except Exception:
        return {"message": "There was an error uploading the file"}

    # Save the uploaded file to the dir
    file_path = os.path.join(UPLOAD_DIR, file.filename)
    with open(file_path, "wb") as f:
        f.write(contents)

    doc = Document()
    ocr = GetOCRText()
    img_tab = ImageToTable()
    pdf_doc = fitz.open(file_path)
    for page_index in range(len(pdf_doc)):  # iterate over pdf pages
        page = pdf_doc[page_index]  # get the page
        image_list = page.get_images()

        for image_index, img in enumerate(image_list, start=1):  # enumerate the image list
            xref = img[0]
            pix = fitz.Pixmap(pdf_doc, xref)

            if pix.n - pix.alpha > 3:  # CMYK or similar: convert to RGB first
                pix = fitz.Pixmap(fitz.csRGB, pix)
            image_path = "page_%s-image_%s.png" % (page_index, image_index)
            pix.save(image_path)  # save the image as png
            pix = None
            extracted_text = ocr.extract_text(image_file=True, file_path=image_path)
            doc.add_paragraph(extracted_text)
            table_data = img_tab.table_to_csv(image_path)
            doc.add_paragraph(table_data)
            # remove the temporary image file
            os.remove(image_path)

    doc.save(os.path.join(UPLOAD_DIR, "ocr_result.docx"))

    return FileResponse(
        path=os.path.join(UPLOAD_DIR, "ocr_result.docx"),
        filename="ocr_result.docx",
        media_type="application/vnd.openxmlformats-officedocument.wordprocessingml.document",
    )
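Once the app is running (for example under uvicorn; the host, port, and file name below are assumptions), the endpoint can be exercised with a short client script:

    import requests

    # Post a PDF and save the returned .docx; the URL assumes a local server on port 8000.
    with open("document.pdf", "rb") as f:
        resp = requests.post("http://localhost:8000/pdf_ocr", files={"file": f})
    with open("ocr_result.docx", "wb") as out:
        out.write(resp.content)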