diff --git a/private_gpt/components/ocr_components/table_ocr.py b/private_gpt/components/ocr_components/table_ocr.py
index d317c9f7..f8248dde 100644
--- a/private_gpt/components/ocr_components/table_ocr.py
+++ b/private_gpt/components/ocr_components/table_ocr.py
@@ -1,14 +1,16 @@
 # from paddleocr import PaddleOCR
 import cv2
+import torch
 from doctr.models import ocr_predictor
 from doctr.io import DocumentFile
 
+device = "cuda" if torch.cuda.is_available() else "cpu"
 
 class GetOCRText:
     def __init__(self) -> None:
         self._image = None
         # self.ocr = PaddleOCR(use_angle_cls=True, lang='en')
-        self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True)
+        self.doctr = ocr_predictor(det_arch='db_resnet50', reco_arch='crnn_vgg16_bn', pretrained=True).to(device)
 
     def _preprocess_image(self, img):
         resized_image = cv2.resize(img, None, fx=1.6, fy=1.6, interpolation=cv2.INTER_CUBIC)