From 0fa9e14a7b3129e128bc28403ccffb95782f42cf Mon Sep 17 00:00:00 2001 From: Michele Dolfi Date: Tue, 18 Feb 2025 13:07:30 +0100 Subject: [PATCH] add factory for ocr engines Signed-off-by: Michele Dolfi --- docling/cli/main.py | 31 +++---- docling/datamodel/pipeline_options.py | 47 +++++----- docling/models/base_model.py | 8 +- docling/models/base_ocr_model.py | 21 +++-- docling/models/easyocr_model.py | 14 ++- docling/models/factories/__init__.py | 14 +++ docling/models/factories/base_factory.py | 103 ++++++++++++++++++++++ docling/models/factories/ocr_factory.py | 28 ++++++ docling/models/ocr_mac_model.py | 31 ++++++- docling/models/plugins/__init__.py | 0 docling/models/plugins/ocr_engines.py | 19 ++++ docling/models/rapid_ocr_model.py | 16 +++- docling/models/tesseract_ocr_cli_model.py | 28 +++++- docling/models/tesseract_ocr_model.py | 28 +++++- docling/pipeline/standard_pdf_pipeline.py | 54 +++--------- poetry.lock | 60 ++++++------- pyproject.toml | 4 + 17 files changed, 367 insertions(+), 139 deletions(-) create mode 100644 docling/models/factories/__init__.py create mode 100644 docling/models/factories/base_factory.py create mode 100644 docling/models/factories/ocr_factory.py create mode 100644 docling/models/plugins/__init__.py create mode 100644 docling/models/plugins/ocr_engines.py diff --git a/docling/cli/main.py b/docling/cli/main.py index 6686da9a..d9bb2e16 100644 --- a/docling/cli/main.py +++ b/docling/cli/main.py @@ -29,18 +29,13 @@ AcceleratorDevice, AcceleratorOptions, EasyOcrOptions, - OcrEngine, - OcrMacOptions, - OcrOptions, PdfBackend, PdfPipelineOptions, - RapidOcrOptions, TableFormerMode, - TesseractCliOcrOptions, - TesseractOcrOptions, ) from docling.datamodel.settings import settings from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption +from docling.models.factories import get_ocr_factory warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch") warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr") @@ -50,6 +45,8 @@ err_console = Console(stderr=True) +ocr_factory = get_ocr_factory() +ocr_engines_enum = ocr_factory.get_enum() app = typer.Typer( name="Docling", @@ -194,9 +191,11 @@ def convert( help="Replace any existing text with OCR generated text over the full content.", ), ] = False, - ocr_engine: Annotated[ - OcrEngine, typer.Option(..., help="The OCR engine to use.") - ] = OcrEngine.EASYOCR, + ocr_engine: Annotated[ # type: ignore + ocr_engines_enum, + # ocr_factory.get_registered_enum(), + typer.Option(..., help="The OCR engine to use."), + ] = EasyOcrOptions.kind, ocr_lang: Annotated[ Optional[str], typer.Option( @@ -367,18 +366,8 @@ def convert( export_txt = OutputFormat.TEXT in to_formats export_doctags = OutputFormat.DOCTAGS in to_formats - if ocr_engine == OcrEngine.EASYOCR: - ocr_options: OcrOptions = EasyOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.TESSERACT_CLI: - ocr_options = TesseractCliOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.TESSERACT: - ocr_options = TesseractOcrOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.OCRMAC: - ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr) - elif ocr_engine == OcrEngine.RAPIDOCR: - ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr) - else: - raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}") + ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore + ocr_options = ocr_options_class(force_full_page_ocr=force_ocr) ocr_lang_list = _split_list(ocr_lang) if ocr_lang_list is not None: diff --git a/docling/datamodel/pipeline_options.py b/docling/datamodel/pipeline_options.py index d317e7d9..3cd3f67f 100644 --- a/docling/datamodel/pipeline_options.py +++ b/docling/datamodel/pipeline_options.py @@ -1,10 +1,9 @@ import logging import os import re -import warnings from enum import Enum from pathlib import Path -from typing import Annotated, Any, Dict, List, Literal, Optional, Union +from typing import Any, ClassVar, Dict, List, Literal, Optional, Union from pydantic import ( AnyUrl, @@ -13,11 +12,9 @@ Field, field_validator, model_validator, - validator, ) from pydantic_settings import ( BaseSettings, - PydanticBaseSettingsSource, SettingsConfigDict, ) from typing_extensions import deprecated @@ -82,6 +79,12 @@ def check_alternative_envvars(cls, data: Any) -> Any: return data +class BaseOptions(BaseModel): + """Base class for options.""" + + kind: ClassVar[str] + + class TableFormerMode(str, Enum): """Modes for the TableFormer model.""" @@ -101,10 +104,9 @@ class TableStructureOptions(BaseModel): mode: TableFormerMode = TableFormerMode.FAST -class OcrOptions(BaseModel): +class OcrOptions(BaseOptions): """OCR options.""" - kind: str lang: List[str] force_full_page_ocr: bool = False # If enabled a full page OCR is always applied bitmap_area_threshold: float = ( @@ -115,7 +117,7 @@ class OcrOptions(BaseModel): class RapidOcrOptions(OcrOptions): """Options for the RapidOCR engine.""" - kind: Literal["rapidocr"] = "rapidocr" + kind: ClassVar[Literal["rapidocr"]] = "rapidocr" # English and chinese are the most commly used models and have been tested with RapidOCR. lang: List[str] = [ @@ -154,7 +156,7 @@ class RapidOcrOptions(OcrOptions): class EasyOcrOptions(OcrOptions): """Options for the EasyOCR engine.""" - kind: Literal["easyocr"] = "easyocr" + kind: ClassVar[Literal["easyocr"]] = "easyocr" lang: List[str] = ["fr", "de", "es", "en"] use_gpu: Optional[bool] = None @@ -174,7 +176,7 @@ class EasyOcrOptions(OcrOptions): class TesseractCliOcrOptions(OcrOptions): """Options for the TesseractCli engine.""" - kind: Literal["tesseract"] = "tesseract" + kind: ClassVar[Literal["tesseract"]] = "tesseract" lang: List[str] = ["fra", "deu", "spa", "eng"] tesseract_cmd: str = "tesseract" path: Optional[str] = None @@ -187,7 +189,7 @@ class TesseractCliOcrOptions(OcrOptions): class TesseractOcrOptions(OcrOptions): """Options for the Tesseract engine.""" - kind: Literal["tesserocr"] = "tesserocr" + kind: ClassVar[Literal["tesserocr"]] = "tesserocr" lang: List[str] = ["fra", "deu", "spa", "eng"] path: Optional[str] = None @@ -199,7 +201,7 @@ class TesseractOcrOptions(OcrOptions): class OcrMacOptions(OcrOptions): """Options for the Mac OCR engine.""" - kind: Literal["ocrmac"] = "ocrmac" + kind: ClassVar[Literal["ocrmac"]] = "ocrmac" lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"] recognition: str = "accurate" framework: str = "vision" @@ -209,8 +211,7 @@ class OcrMacOptions(OcrOptions): ) -class PictureDescriptionBaseOptions(BaseModel): - kind: str +class PictureDescriptionBaseOptions(BaseOptions): batch_size: int = 8 scale: float = 2 @@ -220,7 +221,7 @@ class PictureDescriptionBaseOptions(BaseModel): class PictureDescriptionApiOptions(PictureDescriptionBaseOptions): - kind: Literal["api"] = "api" + kind: ClassVar[Literal["api"]] = "api" url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions") headers: Dict[str, str] = {} @@ -232,7 +233,7 @@ class PictureDescriptionApiOptions(PictureDescriptionBaseOptions): class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions): - kind: Literal["vlm"] = "vlm" + kind: ClassVar[Literal["vlm"]] = "vlm" repo_id: str prompt: str = "Describe this image in a few sentences." @@ -264,6 +265,7 @@ class PdfBackend(str, Enum): # Define an enum for the ocr engines +@deprecated("Use ocr_factory.registered_enum") class OcrEngine(str, Enum): """Enum of valid OCR engines.""" @@ -297,17 +299,10 @@ class PdfPipelineOptions(PipelineOptions): do_picture_description: bool = False # True: run describe pictures in documents table_structure_options: TableStructureOptions = TableStructureOptions() - ocr_options: Union[ - EasyOcrOptions, - TesseractCliOcrOptions, - TesseractOcrOptions, - OcrMacOptions, - RapidOcrOptions, - ] = Field(EasyOcrOptions(), discriminator="kind") - picture_description_options: Annotated[ - Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions], - Field(discriminator="kind"), - ] = smolvlm_picture_description + ocr_options: OcrOptions = EasyOcrOptions() + picture_description_options: PictureDescriptionBaseOptions = ( + smolvlm_picture_description + ) images_scale: float = 1.0 generate_page_images: bool = False diff --git a/docling/models/base_model.py b/docling/models/base_model.py index 9cdc0ecb..c2cce6bd 100644 --- a/docling/models/base_model.py +++ b/docling/models/base_model.py @@ -1,14 +1,20 @@ from abc import ABC, abstractmethod -from typing import Any, Generic, Iterable, Optional +from typing import Any, Generic, Iterable, Optional, Protocol, Type from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem from typing_extensions import TypeVar from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page from docling.datamodel.document import ConversionResult +from docling.datamodel.pipeline_options import BaseOptions from docling.datamodel.settings import settings +class BaseModelWithOptions(Protocol): + @classmethod + def get_options_type(cls) -> Type[BaseOptions]: ... + + class BasePageModel(ABC): @abstractmethod def __call__( diff --git a/docling/models/base_ocr_model.py b/docling/models/base_ocr_model.py index 9afb7dde..a4c9f9e3 100644 --- a/docling/models/base_ocr_model.py +++ b/docling/models/base_ocr_model.py @@ -2,7 +2,7 @@ import logging from abc import abstractmethod from pathlib import Path -from typing import Iterable, List +from typing import Iterable, List, Optional, Type import numpy as np from docling_core.types.doc import BoundingBox, CoordOrigin @@ -12,15 +12,21 @@ from docling.datamodel.base_models import Cell, OcrCell, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import OcrOptions +from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions from docling.datamodel.settings import settings -from docling.models.base_model import BasePageModel +from docling.models.base_model import BaseModelWithOptions, BasePageModel _log = logging.getLogger(__name__) -class BaseOcrModel(BasePageModel): - def __init__(self, enabled: bool, options: OcrOptions): +class BaseOcrModel(BasePageModel, BaseModelWithOptions): + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: OcrOptions, + accelerator_options: AcceleratorOptions, + ): self.enabled = enabled self.options = options @@ -187,3 +193,8 @@ def __call__( self, conv_res: ConversionResult, page_batch: Iterable[Page] ) -> Iterable[Page]: pass + + @classmethod + @abstractmethod + def get_options_type(cls) -> Type[OcrOptions]: + pass diff --git a/docling/models/easyocr_model.py b/docling/models/easyocr_model.py index 0eccb988..a71b158d 100644 --- a/docling/models/easyocr_model.py +++ b/docling/models/easyocr_model.py @@ -2,7 +2,7 @@ import warnings import zipfile from pathlib import Path -from typing import Iterable, List, Optional +from typing import Iterable, List, Optional, Type import numpy from docling_core.types.doc import BoundingBox, CoordOrigin @@ -13,6 +13,7 @@ AcceleratorDevice, AcceleratorOptions, EasyOcrOptions, + OcrOptions, ) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel @@ -33,7 +34,12 @@ def __init__( options: EasyOcrOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: EasyOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -175,3 +181,7 @@ def __call__( self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return EasyOcrOptions diff --git a/docling/models/factories/__init__.py b/docling/models/factories/__init__.py new file mode 100644 index 00000000..a1b50970 --- /dev/null +++ b/docling/models/factories/__init__.py @@ -0,0 +1,14 @@ +import logging +from functools import lru_cache + +from docling.models.factories.ocr_factory import OcrFactory + +logger = logging.getLogger(__name__) + + +@lru_cache(maxsize=1) +def get_ocr_factory(): + factory = OcrFactory() + factory.load_from_plugins() + # logger.info("Registered ocr engines: %r", factory.registered_kind) + return factory diff --git a/docling/models/factories/base_factory.py b/docling/models/factories/base_factory.py new file mode 100644 index 00000000..adb3fbb0 --- /dev/null +++ b/docling/models/factories/base_factory.py @@ -0,0 +1,103 @@ +import enum +import logging +from abc import ABCMeta +from typing import Generic, Optional, Type, TypeVar + +from pluggy import PluginManager + +from docling.datamodel.pipeline_options import BaseOptions +from docling.models.base_model import BaseModelWithOptions + +A = TypeVar("A", bound=BaseModelWithOptions) + + +logger = logging.getLogger(__name__) + + +class BaseFactory(Generic[A], metaclass=ABCMeta): + default_plugin_name = "docling" + + def __init__(self, plugin_attr_name: str, plugin_name=default_plugin_name): + self.plugin_name = plugin_name + self.plugin_attr_name = plugin_attr_name + + self._classes: dict[Type[BaseOptions], Type[A]] = {} + + @property + def registered_kind(self) -> list[str]: + return list(opt.kind for opt in self._classes.keys()) + + def get_enum(self) -> enum.Enum: + return enum.Enum( + self.plugin_attr_name + "_enum", + names={kind: kind for kind in self.registered_kind}, + type=str, + module=__name__, + ) + + @property + def classes(self): + return self._classes + + def get_class(self, options: BaseOptions, *args, **kwargs) -> Type[A]: + try: + return self._classes[type(options)] + except KeyError: + return self.on_class_not_found(options.kind, *args, **kwargs) + + def get_class_by_kind(self, kind: str, *args, **kwargs) -> Type[A]: + for opt, cls in self._classes.items(): + if opt.kind == kind: + return cls + return self.on_class_not_found(kind, *args, **kwargs) + + def get_options_class(self, kind: str, *args, **kwargs) -> Type[BaseOptions]: + for opt, cls in self._classes.items(): + if opt.kind == kind: + return opt + return self.on_class_not_found(kind, *args, **kwargs) + + def on_class_not_found(self, kind: str, *args, **kwargs): + msg = [] + + for opt, cls in self._classes.items(): + msg.append(f"\t{opt.kind!r} => {cls!r}") + + msg_str = "\n".join(msg) + + raise RuntimeError( + f"No class found with the name {kind!r}, known classes are:\n{msg_str}" + ) + + def register(self, cls: Type[A]): + opt_type = cls.get_options_type() + + if opt_type in self._classes: + raise ValueError( + f"{opt_type.kind!r} already registered to class {self._classes[opt_type]!r}" + ) + + self._classes[opt_type] = cls + + def load_from_plugins(self, plugin_name: Optional[str] = None): + plugin_name = plugin_name or self.plugin_name + + plugin_manager = PluginManager(plugin_name) + plugin_manager.load_setuptools_entrypoints(plugin_name) + + for plugin_name, plugin_module in plugin_manager.list_name_plugin(): + + attr = getattr(plugin_module, self.plugin_attr_name, None) + + if callable(attr): + logger.info("Loading plugin %r", plugin_name) + + config = attr() + self.process_plugin(config) + + def process_plugin(self, config): + for item in config[self.plugin_attr_name]: + try: + self.register(item) + except ValueError: + logger.warning("%r already registered", item) diff --git a/docling/models/factories/ocr_factory.py b/docling/models/factories/ocr_factory.py new file mode 100644 index 00000000..414d004a --- /dev/null +++ b/docling/models/factories/ocr_factory.py @@ -0,0 +1,28 @@ +import logging + +from docling.datamodel.pipeline_options import OcrOptions +from docling.models.base_ocr_model import BaseOcrModel +from docling.models.factories.base_factory import BaseFactory + +logger = logging.getLogger(__name__) + + +class OcrFactory(BaseFactory[BaseOcrModel]): + def __init__(self, *args, **kwargs): + super().__init__("ocr_engines", *args, **kwargs) + + +# def on_class_not_found(self, kind: str, *args, **kwargs): + +# raise NoSuchOcrEngine(kind, self.registered_kind) + + +# class NoSuchOcrEngine(Exception): +# def __init__(self, ocr_engine_kind, known_engines=None): +# if known_engines is None: +# known_engines = [] +# super(NoSuchOcrEngine, self).__init__( +# "No OCR engine found with the name '%s', known engines are: %r", +# ocr_engine_kind, +# [cls.__name__ for cls in known_engines], +# ) diff --git a/docling/models/ocr_mac_model.py b/docling/models/ocr_mac_model.py index 38bcf1ca..ad7046c5 100644 --- a/docling/models/ocr_mac_model.py +++ b/docling/models/ocr_mac_model.py @@ -1,12 +1,18 @@ import logging +import sys import tempfile -from typing import Iterable, Optional, Tuple +from pathlib import Path +from typing import Iterable, Optional, Tuple, Type from docling_core.types.doc import BoundingBox, CoordOrigin from docling.datamodel.base_models import OcrCell, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import OcrMacOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrMacOptions, + OcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.profiling import TimeRecorder @@ -15,13 +21,26 @@ class OcrMacModel(BaseOcrModel): - def __init__(self, enabled: bool, options: OcrMacOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: OcrMacOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: OcrMacOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. if self.enabled: + if "darwin" != sys.platform: + raise RuntimeError(f"OcrMac is only supported on Mac.") install_errmsg = ( "ocrmac is not correctly installed. " "Please install it via `pip install ocrmac` to use this OCR engine. " @@ -116,3 +135,7 @@ def __call__( self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return OcrMacOptions diff --git a/docling/models/plugins/__init__.py b/docling/models/plugins/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/docling/models/plugins/ocr_engines.py b/docling/models/plugins/ocr_engines.py new file mode 100644 index 00000000..646d8782 --- /dev/null +++ b/docling/models/plugins/ocr_engines.py @@ -0,0 +1,19 @@ +import sys + +from docling.models.easyocr_model import EasyOcrModel +from docling.models.ocr_mac_model import OcrMacModel +from docling.models.rapid_ocr_model import RapidOcrModel +from docling.models.tesseract_ocr_cli_model import TesseractOcrCliModel +from docling.models.tesseract_ocr_model import TesseractOcrModel + + +def ocr_engines(): + return { + "ocr_engines": [ + EasyOcrModel, + OcrMacModel, + RapidOcrModel, + TesseractOcrModel, + TesseractOcrCliModel, + ] + } diff --git a/docling/models/rapid_ocr_model.py b/docling/models/rapid_ocr_model.py index fa3fbedf..742bfcb1 100644 --- a/docling/models/rapid_ocr_model.py +++ b/docling/models/rapid_ocr_model.py @@ -1,5 +1,6 @@ import logging -from typing import Iterable +from pathlib import Path +from typing import Iterable, Optional, Type import numpy from docling_core.types.doc import BoundingBox, CoordOrigin @@ -9,6 +10,7 @@ from docling.datamodel.pipeline_options import ( AcceleratorDevice, AcceleratorOptions, + OcrOptions, RapidOcrOptions, ) from docling.datamodel.settings import settings @@ -23,10 +25,16 @@ class RapidOcrModel(BaseOcrModel): def __init__( self, enabled: bool, + artifacts_path: Optional[Path], options: RapidOcrOptions, accelerator_options: AcceleratorOptions, ): - super().__init__(enabled=enabled, options=options) + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: RapidOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -126,3 +134,7 @@ def __call__( self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return RapidOcrOptions diff --git a/docling/models/tesseract_ocr_cli_model.py b/docling/models/tesseract_ocr_cli_model.py index cdc5671d..a13cbbd5 100644 --- a/docling/models/tesseract_ocr_cli_model.py +++ b/docling/models/tesseract_ocr_cli_model.py @@ -3,15 +3,20 @@ import logging import os import tempfile +from pathlib import Path from subprocess import DEVNULL, PIPE, Popen -from typing import Iterable, List, Optional, Tuple +from typing import Iterable, List, Optional, Tuple, Type import pandas as pd from docling_core.types.doc import BoundingBox, CoordOrigin from docling.datamodel.base_models import Cell, OcrCell, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import TesseractCliOcrOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrOptions, + TesseractCliOcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.ocr_utils import map_tesseract_script @@ -21,8 +26,19 @@ class TesseractOcrCliModel(BaseOcrModel): - def __init__(self, enabled: bool, options: TesseractCliOcrOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: TesseractCliOcrOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: TesseractCliOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -250,3 +266,7 @@ def __call__( self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return TesseractCliOcrOptions diff --git a/docling/models/tesseract_ocr_model.py b/docling/models/tesseract_ocr_model.py index c41806f5..4ad16ebc 100644 --- a/docling/models/tesseract_ocr_model.py +++ b/docling/models/tesseract_ocr_model.py @@ -1,11 +1,16 @@ import logging -from typing import Iterable +from pathlib import Path +from typing import Iterable, Optional, Type from docling_core.types.doc import BoundingBox, CoordOrigin from docling.datamodel.base_models import Cell, OcrCell, Page from docling.datamodel.document import ConversionResult -from docling.datamodel.pipeline_options import TesseractOcrOptions +from docling.datamodel.pipeline_options import ( + AcceleratorOptions, + OcrOptions, + TesseractOcrOptions, +) from docling.datamodel.settings import settings from docling.models.base_ocr_model import BaseOcrModel from docling.utils.ocr_utils import map_tesseract_script @@ -15,8 +20,19 @@ class TesseractOcrModel(BaseOcrModel): - def __init__(self, enabled: bool, options: TesseractOcrOptions): - super().__init__(enabled=enabled, options=options) + def __init__( + self, + enabled: bool, + artifacts_path: Optional[Path], + options: TesseractOcrOptions, + accelerator_options: AcceleratorOptions, + ): + super().__init__( + enabled=enabled, + artifacts_path=artifacts_path, + options=options, + accelerator_options=accelerator_options, + ) self.options: TesseractOcrOptions self.scale = 3 # multiplier for 72 dpi == 216 dpi. @@ -195,3 +211,7 @@ def __call__( self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects) yield page + + @classmethod + def get_options_type(cls) -> Type[OcrOptions]: + return TesseractOcrOptions diff --git a/docling/pipeline/standard_pdf_pipeline.py b/docling/pipeline/standard_pdf_pipeline.py index ae4ed478..88658c88 100644 --- a/docling/pipeline/standard_pdf_pipeline.py +++ b/docling/pipeline/standard_pdf_pipeline.py @@ -29,6 +29,7 @@ ) from docling.models.ds_glm_model import GlmModel, GlmOptions from docling.models.easyocr_model import EasyOcrModel +from docling.models.factories import get_ocr_factory from docling.models.layout_model import LayoutModel from docling.models.ocr_mac_model import OcrMacModel from docling.models.page_assemble_model import PageAssembleModel, PageAssembleOptions @@ -78,10 +79,7 @@ def __init__(self, pipeline_options: PdfPipelineOptions): self.glm_model = GlmModel(options=GlmOptions()) - if (ocr_model := self.get_ocr_model(artifacts_path=artifacts_path)) is None: - raise RuntimeError( - f"The specified OCR kind is not supported: {pipeline_options.ocr_options.kind}." - ) + ocr_model = self.get_ocr_model(artifacts_path=artifacts_path) self.build_pipe = [ # Pre-processing @@ -163,42 +161,18 @@ def download_models_hf( output_dir = download_models(output_dir=local_dir, force=force, progress=False) return output_dir - def get_ocr_model( - self, artifacts_path: Optional[Path] = None - ) -> Optional[BaseOcrModel]: - if isinstance(self.pipeline_options.ocr_options, EasyOcrOptions): - return EasyOcrModel( - enabled=self.pipeline_options.do_ocr, - artifacts_path=artifacts_path, - options=self.pipeline_options.ocr_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - elif isinstance(self.pipeline_options.ocr_options, TesseractCliOcrOptions): - return TesseractOcrCliModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - elif isinstance(self.pipeline_options.ocr_options, TesseractOcrOptions): - return TesseractOcrModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - elif isinstance(self.pipeline_options.ocr_options, RapidOcrOptions): - return RapidOcrModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - accelerator_options=self.pipeline_options.accelerator_options, - ) - elif isinstance(self.pipeline_options.ocr_options, OcrMacOptions): - if "darwin" != sys.platform: - raise RuntimeError( - f"The specified OCR type is only supported on Mac: {self.pipeline_options.ocr_options.kind}." - ) - return OcrMacModel( - enabled=self.pipeline_options.do_ocr, - options=self.pipeline_options.ocr_options, - ) - return None + def get_ocr_model(self, artifacts_path: Optional[Path] = None) -> BaseOcrModel: + ocr_factory = get_ocr_factory() + ocr_engine_cls = ocr_factory.get_class( + options=self.pipeline_options.ocr_options + ) + + return ocr_engine_cls( + enabled=self.pipeline_options.do_ocr, + artifacts_path=artifacts_path, + options=self.pipeline_options.ocr_options, + accelerator_options=self.pipeline_options.accelerator_options, + ) def get_picture_description_model( self, artifacts_path: Optional[Path] = None diff --git a/poetry.lock b/poetry.lock index 329e4ae7..291d2778 100644 --- a/poetry.lock +++ b/poetry.lock @@ -924,39 +924,39 @@ transformers = [ [[package]] name = "docling-parse" -version = "3.3.1" +version = "3.4.0" description = "Simple package to extract text with coordinates from programmatic PDFs" optional = false python-versions = "<4.0,>=3.9" files = [ - {file = "docling_parse-3.3.1-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:abf3a0c9ea35fc33fbd288031096826688d1e787f7c51e174cc9fea6a22d2f67"}, - {file = "docling_parse-3.3.1-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:1e07cc6e3603ff246affa11bec25a82d90f79c6b92c370d993f2bd6318476b7c"}, - {file = "docling_parse-3.3.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:338000deee0251f7e2ebdfde2bcd6392c388624206555410867cfc93608d84fe"}, - {file = "docling_parse-3.3.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:46f14a9871e840a021a642ea0ece4b675cc9584224eba3b85cd269feda892b76"}, - {file = "docling_parse-3.3.1-cp310-cp310-win_amd64.whl", hash = "sha256:c0bad3db594e05bca2366d46e630a0b8050b6eb37fcae2cbcd5020b06ac0879a"}, - {file = "docling_parse-3.3.1-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:db0cebe28b299f78f1da58b5567c22de6f5b30aa5b6fe4fb2daae9f372bd022c"}, - {file = "docling_parse-3.3.1-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:72e901cc6dfac9e4f5e13ddd841f758b41484e61b7092b891c693e2c036461ac"}, - {file = "docling_parse-3.3.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fcfc7309b46e9b0941cc5513560b06f0b1c221ff3a2d5e516eb752ae7f2ccb81"}, - {file = "docling_parse-3.3.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:76dc12fc510d08d1f76741dc429666a5791db003275067aa1f18da02b7a98925"}, - {file = "docling_parse-3.3.1-cp311-cp311-win_amd64.whl", hash = "sha256:4c096f5c2460a6eb308e046e3045bb0100b6b602ef4394924cfd4846cee5800e"}, - {file = "docling_parse-3.3.1-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:3cc23f0d6aa91d015117b8962162bf4a482e2208d2068abfada34fda112ef077"}, - {file = "docling_parse-3.3.1-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:7a62afe01b6f008f3f50a12e3d8feddb28d045bc2b96321d48933ab23ff1e201"}, - {file = "docling_parse-3.3.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd93d2a23a22b84c61213fa62db906ae444201e4e404d7dc2b6152d64d69ec50"}, - {file = "docling_parse-3.3.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:87a8678b8e37a1f0bd41fb0227938a7bba0dc7e6f6ce69777b1ee947ef0e28ef"}, - {file = "docling_parse-3.3.1-cp312-cp312-win_amd64.whl", hash = "sha256:c6ae62864d3d0e1e3bfb467e217c90ae938b0773c671412ff3ca110081b024ea"}, - {file = "docling_parse-3.3.1-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:79cc92e3b1d3d8957df11c8dfd5c8f89aaae06d5ac49f019a59a0aad301ba59c"}, - {file = "docling_parse-3.3.1-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:8d8b988ca77dd88111fb8e5d961806fcb26b3ea146841d7d304d1d52b82ed27b"}, - {file = "docling_parse-3.3.1-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:781a88c3a4eb27dd09598fe5956f8cd874acb49c102c2d35ccf0fbbeb3fc714d"}, - {file = "docling_parse-3.3.1-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:cb803dd61d79a7f876ff7febb4ce2f43de19646f05196031f35b891a8c24b57a"}, - {file = "docling_parse-3.3.1-cp313-cp313-win_amd64.whl", hash = "sha256:2052d3dc4711975fdcfa343947a1fbb9502c6a81ccc5834af41615868e61fb94"}, - {file = "docling_parse-3.3.1-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:7ccb186369f706b5df8d6751c6cfff2a4355c3c843c68b0210e3f53a2bdf9bf6"}, - {file = "docling_parse-3.3.1-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:5a00c37ed9923f7d7317044135d8ff81829474d1d47730dfc8bd2d2a3e3e60cd"}, - {file = "docling_parse-3.3.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:72beb63683f5e581c15d1c3370480dbe4457031f447944342a09bd23a66b378e"}, - {file = "docling_parse-3.3.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d5ac8f3bb64e50cf58959ce591574cb0ba3d6ebe9bdbedfabbff1817aaf34664"}, - {file = "docling_parse-3.3.1-cp39-cp39-win_amd64.whl", hash = "sha256:24ea10d7bda2ea35c6cc24b8db3fdea4a1e05182890ea44364fcd703e5090e54"}, - {file = "docling_parse-3.3.1-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:048157bc8640e3c03c082b4296f9f8946516c624a5469e10c7c9a32dcb0dc5c8"}, - {file = "docling_parse-3.3.1-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:c6f293644856d05a1fced3e100dd052374da793e7e3a0e6023c69b9d3eb64881"}, - {file = "docling_parse-3.3.1.tar.gz", hash = "sha256:536f581e7564cbfd37bff2e79d2cb17e7dbaa0d34d054cfdb28d648da31da85b"}, + {file = "docling_parse-3.4.0-cp310-cp310-macosx_13_0_x86_64.whl", hash = "sha256:96e95e63ab722dfe5340fcb04d0e07bd1c0a0ba2f62e93c91ac26dda0a312a44"}, + {file = "docling_parse-3.4.0-cp310-cp310-macosx_14_0_arm64.whl", hash = "sha256:f9e14a7a0b92526d4dfd3f390f3d7e075f59d14d6b8a0a564fbc26299e56cd47"}, + {file = "docling_parse-3.4.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fdef1d51291e841e5b6a32689a39a9f35986389f863b415eaa1790b29d021101"}, + {file = "docling_parse-3.4.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:68652610d6c34adc684dbaa77b5d596b25d004912a78e85ec4ae57910bf7086f"}, + {file = "docling_parse-3.4.0-cp310-cp310-win_amd64.whl", hash = "sha256:daad07fe93f306d8e2378acb24ef2fa68535ccdb960a1b99d6b36ab8c299fef1"}, + {file = "docling_parse-3.4.0-cp311-cp311-macosx_13_0_x86_64.whl", hash = "sha256:6f30c5fd3c04bd3d1a7d06baeae2e5c3adbebc284071a9a52b0150bcd4917a3d"}, + {file = "docling_parse-3.4.0-cp311-cp311-macosx_14_0_arm64.whl", hash = "sha256:2c3664e4c8980dc44e0d026b1b01fbc94f0dac9adf7be835071d4a761977c36d"}, + {file = "docling_parse-3.4.0-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3febf7515453d18df03c275356db2bb5b0618ba9fc033aba05d58318a9846b1a"}, + {file = "docling_parse-3.4.0-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:75aeb038bb7f6400ecde99cf6c4ef35867c528ac21676071a822ed72d0653149"}, + {file = "docling_parse-3.4.0-cp311-cp311-win_amd64.whl", hash = "sha256:8d20e3584022542448c21ed0ac868b2457ae35211cea63ed20142e375549e633"}, + {file = "docling_parse-3.4.0-cp312-cp312-macosx_13_0_x86_64.whl", hash = "sha256:ddfe2bd730ed08363f25954a0480da021e6e6bdb175276643cc2913a6bbd98e2"}, + {file = "docling_parse-3.4.0-cp312-cp312-macosx_14_0_arm64.whl", hash = "sha256:faf8ba9eaab8c17ea72516be5d440f754fcca27f37488dcf126a0f3ac3a63058"}, + {file = "docling_parse-3.4.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:9eb5e7e50b3057690d0d4fa651363cafd7735bb952378dd8a4ca6c7d359507db"}, + {file = "docling_parse-3.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:452334b387e2c699f69acf37a4ea4ae7097d062a2dd1980c573b73051c031158"}, + {file = "docling_parse-3.4.0-cp312-cp312-win_amd64.whl", hash = "sha256:1ba00147ccb0a1dc10cdf58645e67f4ee895c6920bc583bc6f25d27cd562bfed"}, + {file = "docling_parse-3.4.0-cp313-cp313-macosx_13_0_x86_64.whl", hash = "sha256:2b22a33a2d2f3616a7ac0f4b2f2ba6099f8a5dc6fa328be0f17c9c506455d7c1"}, + {file = "docling_parse-3.4.0-cp313-cp313-macosx_14_0_arm64.whl", hash = "sha256:0dd2440a94d555f98b702e88bfe7cc5a585d9191f4ea93884b02e286e7af3a06"}, + {file = "docling_parse-3.4.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f5828744a0e33136e09e8c61ca0b2c0ead8f76595f2e0955beaac16adce51f5"}, + {file = "docling_parse-3.4.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:26fff6e36809d17ff855532f985df3738ada8d86a9fc746049ea6e6524d5e0a2"}, + {file = "docling_parse-3.4.0-cp313-cp313-win_amd64.whl", hash = "sha256:13fc442f64171280db98dc4507274ffa0a65bac94eecbcc60c3cbf41f433b556"}, + {file = "docling_parse-3.4.0-cp39-cp39-macosx_13_0_x86_64.whl", hash = "sha256:16d570ab655ea5a25d9cd1e27bc4d6905372784907d679cde4cef2fb22df61c7"}, + {file = "docling_parse-3.4.0-cp39-cp39-macosx_14_0_arm64.whl", hash = "sha256:05bd405635be2379ef6cb0c7c39dc08edf3ba93788eb0fca7426b2218538bce1"}, + {file = "docling_parse-3.4.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:f6c92f0353bbae7ca9b39553cc4d03f5fefdab33ecd26809ab710cc752fac03c"}, + {file = "docling_parse-3.4.0-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8e883326ec4121891c48d365d064e5ae30c5b90a2dac44ed61ac02e7da41345d"}, + {file = "docling_parse-3.4.0-cp39-cp39-win_amd64.whl", hash = "sha256:b2a0fe1e1d88c3814553137daa597ee34dc310f50fe415e1f8a1c6e611d95e42"}, + {file = "docling_parse-3.4.0-pp310-pypy310_pp73-win_amd64.whl", hash = "sha256:930f5a5d78404de573c0ba302d313b6647f1e86714766e5a1cdc09af014ca111"}, + {file = "docling_parse-3.4.0-pp39-pypy39_pp73-win_amd64.whl", hash = "sha256:328fd72f274b939d454e3ff20a73074d99664cb4a51e6ccdaf195a6626691b95"}, + {file = "docling_parse-3.4.0.tar.gz", hash = "sha256:36cdd17bcc4a833b5c9af9ae3dc461ed18a975c1b084ccfd19a9d9cde4f66e14"}, ] [package.dependencies] @@ -7842,4 +7842,4 @@ vlm = ["transformers", "transformers"] [metadata] lock-version = "2.0" python-versions = "^3.9" -content-hash = "63f9271160d39cac74fa3fc959dbb0f91530d76a693c69d81ced006477d04315" +content-hash = "ccc7e0aec5519c63d5dee8482f1242c94b8241fea831e86071cde94ae1fc04d6" diff --git a/pyproject.toml b/pyproject.toml index 0c04acf4..aeb56631 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -65,6 +65,7 @@ transformers = [ ] pillow = ">=10.0.0,<12.0.0" tqdm = "^4.65.0" +pluggy = "^1.0.0" [tool.poetry.group.dev.dependencies] black = {extras = ["jupyter"], version = "^24.4.2"} @@ -132,6 +133,9 @@ rapidocr = ["rapidocr-onnxruntime", "onnxruntime"] docling = "docling.cli.main:app" docling-tools = "docling.cli.tools:app" +[tool.poetry.plugins."docling"] +"docling_ocr_engines" = "docling.models.plugins.ocr_engines" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api"