You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
181 lines
4.9 KiB
Python
181 lines
4.9 KiB
Python
from typing import List, Optional
|
|
|
|
import numpy as np
|
|
from tqdm import tqdm
|
|
|
|
from surya.input.processing import slice_bboxes_from_image
|
|
from surya.settings import settings
|
|
import os
|
|
from concurrent.futures import ProcessPoolExecutor
|
|
from surya.recognition.languages import CODE_TO_LANGUAGE
|
|
from surya.recognition import RecognitionPredictor
|
|
from surya.detection import DetectionPredictor
|
|
|
|
|
|
def surya_lang_to_tesseract(code: str) -> Optional[str]:
    """Translate a surya language code into the matching tesseract code.

    The surya code is first resolved to a language name through
    ``CODE_TO_LANGUAGE``; a failed lookup there propagates to the caller.
    That name is then looked up in ``TESS_LANGUAGE_TO_CODE``.

    Args:
        code: Key into ``CODE_TO_LANGUAGE``.

    Returns:
        The tesseract language code, or ``None`` when the language name has
        no entry in ``TESS_LANGUAGE_TO_CODE``.
    """
    language_name = CODE_TO_LANGUAGE[code]
    # A missing name maps to None rather than raising.
    return TESS_LANGUAGE_TO_CODE.get(language_name)
|
|
|
|
|
|
def tesseract_ocr(img, bboxes, lang: str):
    """Run tesseract on every bounding-box crop of a single image.

    Args:
        img: Image handed to ``slice_bboxes_from_image``.
        bboxes: Boxes handed to ``slice_bboxes_from_image`` together with ``img``.
        lang: Language string forwarded to ``pytesseract.image_to_string``.

    Returns:
        A list with one ``image_to_string`` result per crop, in the order the
        crops were produced.
    """
    # Imported lazily so the module can be loaded without pytesseract installed.
    import pytesseract

    crops = slice_bboxes_from_image(img, bboxes)
    # Point tesseract at the configured tessdata directory.
    tess_config = f'--tessdata-dir "{settings.TESSDATA_PREFIX}"'
    return [
        pytesseract.image_to_string(crop, lang=lang, config=tess_config)
        for crop in crops
    ]
|
|
|
|
|
|
def tesseract_ocr_parallel(imgs, bboxes, langs: List[str], cpus=None):
    """Run ``tesseract_ocr`` over several images using a process pool.

    Args:
        imgs: Images to process.
        bboxes: Per-image bounding boxes, consumed in lockstep with ``imgs``.
        langs: Per-image language strings, consumed in lockstep with ``imgs``.
        cpus: Upper bound on the number of cores to budget for. Any falsy
            value (``None``, ``0``) means "use ``os.cpu_count()``".

    Returns:
        A list of ``tesseract_ocr`` results, one per image, in input order.
    """
    if not cpus:
        # os.cpu_count() returns None when the count cannot be determined;
        # fall back to a single core instead of letting min() raise TypeError.
        cpus = os.cpu_count() or 1

    # NOTE(review): get_batch_size is invoked on the class, not an instance --
    # confirm it is callable that way.
    tess_parallel_cores = min(len(imgs), RecognitionPredictor.get_batch_size(), cpus)

    # Tesseract uses up to 4 processes per instance.
    # Divide by 2 because tesseract doesn't seem to saturate all 4 cores with these small images.
    # max(..., 1) keeps the pool valid when the core budget is 0 or 1.
    tess_parallel = max(tess_parallel_cores // 2, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_text = tqdm(
            executor.map(tesseract_ocr, imgs, bboxes, langs),
            total=len(imgs),
            desc="Running tesseract OCR",
        )
        # Drain the lazy map while the executor is still open.
        tess_text = list(tess_text)
    return tess_text
|
|
|
|
|
|
def tesseract_bboxes(img):
    """Detect boxes in an image with tesseract.

    Args:
        img: Image; it is converted to a ``uint8`` numpy array before being
            handed to ``pytesseract.image_to_data``.

    Returns:
        A list of ``(x0, y0, x1, y1)`` tuples, one per row of the
        ``image_to_data`` output, where ``x1 = left + width`` and
        ``y1 = top + height``.
    """
    # Imported lazily so the module can be loaded without pytesseract installed.
    import pytesseract
    from pytesseract import Output

    pixels = np.asarray(img, dtype=np.uint8)
    data = pytesseract.image_to_data(pixels, output_type=Output.DICT)

    row_count = len(data['level'])
    lefts, tops = data['left'], data['top']
    widths, heights = data['width'], data['height']

    # It is possible to merge by line here with line number, but it gives bad results.
    return [
        (lefts[row], tops[row], lefts[row] + widths[row], tops[row] + heights[row])
        for row in range(row_count)
    ]
|
|
|
|
|
|
def tesseract_parallel(imgs):
    """Run ``tesseract_bboxes`` over several images using a process pool.

    Args:
        imgs: Images to process.

    Returns:
        A list of ``tesseract_bboxes`` results, one per image, in input order.
    """
    # os.cpu_count() returns None when the count cannot be determined;
    # fall back to a single core instead of letting min() raise TypeError.
    cpus = os.cpu_count() or 1

    # NOTE(review): get_batch_size is invoked on the class, not an instance --
    # confirm it is callable that way.
    tess_parallel_cores = min(len(imgs), DetectionPredictor.get_batch_size(), cpus)

    # Tesseract uses 4 threads per instance, so budget one worker per 4 cores.
    # max(..., 1) keeps the pool valid when fewer than 4 cores are available.
    tess_parallel = max(tess_parallel_cores // 4, 1)

    with ProcessPoolExecutor(max_workers=tess_parallel) as executor:
        tess_bboxes = tqdm(
            executor.map(tesseract_bboxes, imgs),
            total=len(imgs),
            desc="Running tesseract bbox detection",
        )
        # Drain the lazy map while the executor is still open.
        tess_bboxes = list(tess_bboxes)
    return tess_bboxes
|
|
|
|
|
|
# Tesseract language code -> language name.
TESS_CODE_TO_LANGUAGE = {
    "afr": "Afrikaans",
    "amh": "Amharic",
    "ara": "Arabic",
    "asm": "Assamese",
    "aze": "Azerbaijani",
    "bel": "Belarusian",
    "ben": "Bengali",
    "bod": "Tibetan",
    "bos": "Bosnian",
    "bre": "Breton",
    "bul": "Bulgarian",
    "cat": "Catalan",
    "ceb": "Cebuano",
    "ces": "Czech",
    "chi_sim": "Chinese",
    "chr": "Cherokee",
    "cym": "Welsh",
    "dan": "Danish",
    "deu": "German",
    "dzo": "Dzongkha",
    "ell": "Greek",
    "eng": "English",
    "epo": "Esperanto",
    "est": "Estonian",
    "eus": "Basque",
    "fas": "Persian",
    "fin": "Finnish",
    "fra": "French",
    "fry": "Western Frisian",
    "guj": "Gujarati",
    "gla": "Scottish Gaelic",
    "gle": "Irish",
    "glg": "Galician",
    "heb": "Hebrew",
    "hin": "Hindi",
    "hrv": "Croatian",
    "hun": "Hungarian",
    "hye": "Armenian",
    "iku": "Inuktitut",
    "ind": "Indonesian",
    "isl": "Icelandic",
    "ita": "Italian",
    "jav": "Javanese",
    "jpn": "Japanese",
    "kan": "Kannada",
    "kat": "Georgian",
    "kaz": "Kazakh",
    "khm": "Khmer",
    "kir": "Kyrgyz",
    "kor": "Korean",
    "lao": "Lao",
    "lat": "Latin",
    "lav": "Latvian",
    "lit": "Lithuanian",
    "mal": "Malayalam",
    "mar": "Marathi",
    "mkd": "Macedonian",
    "mlt": "Maltese",
    "mon": "Mongolian",
    "msa": "Malay",
    "mya": "Burmese",
    "nep": "Nepali",
    "nld": "Dutch",
    "nor": "Norwegian",
    "ori": "Oriya",
    "pan": "Punjabi",
    "pol": "Polish",
    "por": "Portuguese",
    "pus": "Pashto",
    "ron": "Romanian",
    "rus": "Russian",
    "san": "Sanskrit",
    "sin": "Sinhala",
    "slk": "Slovak",
    "slv": "Slovenian",
    "snd": "Sindhi",
    "spa": "Spanish",
    "sqi": "Albanian",
    "srp": "Serbian",
    "swa": "Swahili",
    "swe": "Swedish",
    "syr": "Syriac",
    "tam": "Tamil",
    "tel": "Telugu",
    "tgk": "Tajik",
    "tha": "Thai",
    "tir": "Tigrinya",
    "tur": "Turkish",
    "uig": "Uyghur",
    "ukr": "Ukrainian",
    "urd": "Urdu",
    "uzb": "Uzbek",
    "vie": "Vietnamese",
    "yid": "Yiddish",
}

# Inverse lookup: language name -> tesseract language code.
TESS_LANGUAGE_TO_CODE = {
    language: tess_code for tess_code, language in TESS_CODE_TO_LANGUAGE.items()
}
|