Skip to content

Commit

Permalink
add factory for ocr engines
Browse files Browse the repository at this point in the history
Signed-off-by: Michele Dolfi <[email protected]>
  • Loading branch information
dolfim-ibm committed Feb 18, 2025
1 parent 7450050 commit 0fa9e14
Show file tree
Hide file tree
Showing 17 changed files with 367 additions and 139 deletions.
31 changes: 10 additions & 21 deletions docling/cli/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,18 +29,13 @@
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrEngine,
OcrMacOptions,
OcrOptions,
PdfBackend,
PdfPipelineOptions,
RapidOcrOptions,
TableFormerMode,
TesseractCliOcrOptions,
TesseractOcrOptions,
)
from docling.datamodel.settings import settings
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
from docling.models.factories import get_ocr_factory

warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
Expand All @@ -50,6 +45,8 @@

err_console = Console(stderr=True)

ocr_factory = get_ocr_factory()
ocr_engines_enum = ocr_factory.get_enum()

app = typer.Typer(
name="Docling",
Expand Down Expand Up @@ -194,9 +191,11 @@ def convert(
help="Replace any existing text with OCR generated text over the full content.",
),
] = False,
ocr_engine: Annotated[
OcrEngine, typer.Option(..., help="The OCR engine to use.")
] = OcrEngine.EASYOCR,
ocr_engine: Annotated[ # type: ignore
ocr_engines_enum,
# ocr_factory.get_registered_enum(),
typer.Option(..., help="The OCR engine to use."),
] = EasyOcrOptions.kind,
ocr_lang: Annotated[
Optional[str],
typer.Option(
Expand Down Expand Up @@ -367,18 +366,8 @@ def convert(
export_txt = OutputFormat.TEXT in to_formats
export_doctags = OutputFormat.DOCTAGS in to_formats

if ocr_engine == OcrEngine.EASYOCR:
ocr_options: OcrOptions = EasyOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.TESSERACT_CLI:
ocr_options = TesseractCliOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.TESSERACT:
ocr_options = TesseractOcrOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.OCRMAC:
ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr)
elif ocr_engine == OcrEngine.RAPIDOCR:
ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr)
else:
raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}")
ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore
ocr_options = ocr_options_class(force_full_page_ocr=force_ocr)

ocr_lang_list = _split_list(ocr_lang)
if ocr_lang_list is not None:
Expand Down
47 changes: 21 additions & 26 deletions docling/datamodel/pipeline_options.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import logging
import os
import re
import warnings
from enum import Enum
from pathlib import Path
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union

from pydantic import (
AnyUrl,
Expand All @@ -13,11 +12,9 @@
Field,
field_validator,
model_validator,
validator,
)
from pydantic_settings import (
BaseSettings,
PydanticBaseSettingsSource,
SettingsConfigDict,
)
from typing_extensions import deprecated
Expand Down Expand Up @@ -82,6 +79,12 @@ def check_alternative_envvars(cls, data: Any) -> Any:
return data


class BaseOptions(BaseModel):
"""Base class for options."""

kind: ClassVar[str]


class TableFormerMode(str, Enum):
"""Modes for the TableFormer model."""

Expand All @@ -101,10 +104,9 @@ class TableStructureOptions(BaseModel):
mode: TableFormerMode = TableFormerMode.FAST


class OcrOptions(BaseModel):
class OcrOptions(BaseOptions):
"""OCR options."""

kind: str
lang: List[str]
force_full_page_ocr: bool = False # If enabled a full page OCR is always applied
bitmap_area_threshold: float = (
Expand All @@ -115,7 +117,7 @@ class OcrOptions(BaseModel):
class RapidOcrOptions(OcrOptions):
"""Options for the RapidOCR engine."""

kind: Literal["rapidocr"] = "rapidocr"
kind: ClassVar[Literal["rapidocr"]] = "rapidocr"

# English and chinese are the most commly used models and have been tested with RapidOCR.
lang: List[str] = [
Expand Down Expand Up @@ -154,7 +156,7 @@ class RapidOcrOptions(OcrOptions):
class EasyOcrOptions(OcrOptions):
"""Options for the EasyOCR engine."""

kind: Literal["easyocr"] = "easyocr"
kind: ClassVar[Literal["easyocr"]] = "easyocr"
lang: List[str] = ["fr", "de", "es", "en"]

use_gpu: Optional[bool] = None
Expand All @@ -174,7 +176,7 @@ class EasyOcrOptions(OcrOptions):
class TesseractCliOcrOptions(OcrOptions):
"""Options for the TesseractCli engine."""

kind: Literal["tesseract"] = "tesseract"
kind: ClassVar[Literal["tesseract"]] = "tesseract"
lang: List[str] = ["fra", "deu", "spa", "eng"]
tesseract_cmd: str = "tesseract"
path: Optional[str] = None
Expand All @@ -187,7 +189,7 @@ class TesseractCliOcrOptions(OcrOptions):
class TesseractOcrOptions(OcrOptions):
"""Options for the Tesseract engine."""

kind: Literal["tesserocr"] = "tesserocr"
kind: ClassVar[Literal["tesserocr"]] = "tesserocr"
lang: List[str] = ["fra", "deu", "spa", "eng"]
path: Optional[str] = None

Expand All @@ -199,7 +201,7 @@ class TesseractOcrOptions(OcrOptions):
class OcrMacOptions(OcrOptions):
"""Options for the Mac OCR engine."""

kind: Literal["ocrmac"] = "ocrmac"
kind: ClassVar[Literal["ocrmac"]] = "ocrmac"
lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"]
recognition: str = "accurate"
framework: str = "vision"
Expand All @@ -209,8 +211,7 @@ class OcrMacOptions(OcrOptions):
)


class PictureDescriptionBaseOptions(BaseModel):
kind: str
class PictureDescriptionBaseOptions(BaseOptions):
batch_size: int = 8
scale: float = 2

Expand All @@ -220,7 +221,7 @@ class PictureDescriptionBaseOptions(BaseModel):


class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
kind: Literal["api"] = "api"
kind: ClassVar[Literal["api"]] = "api"

url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
headers: Dict[str, str] = {}
Expand All @@ -232,7 +233,7 @@ class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):


class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
kind: Literal["vlm"] = "vlm"
kind: ClassVar[Literal["vlm"]] = "vlm"

repo_id: str
prompt: str = "Describe this image in a few sentences."
Expand Down Expand Up @@ -264,6 +265,7 @@ class PdfBackend(str, Enum):


# Define an enum for the ocr engines
@deprecated("Use ocr_factory.registered_enum")
class OcrEngine(str, Enum):
"""Enum of valid OCR engines."""

Expand Down Expand Up @@ -297,17 +299,10 @@ class PdfPipelineOptions(PipelineOptions):
do_picture_description: bool = False # True: run describe pictures in documents

table_structure_options: TableStructureOptions = TableStructureOptions()
ocr_options: Union[
EasyOcrOptions,
TesseractCliOcrOptions,
TesseractOcrOptions,
OcrMacOptions,
RapidOcrOptions,
] = Field(EasyOcrOptions(), discriminator="kind")
picture_description_options: Annotated[
Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions],
Field(discriminator="kind"),
] = smolvlm_picture_description
ocr_options: OcrOptions = EasyOcrOptions()
picture_description_options: PictureDescriptionBaseOptions = (
smolvlm_picture_description
)

images_scale: float = 1.0
generate_page_images: bool = False
Expand Down
8 changes: 7 additions & 1 deletion docling/models/base_model.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,20 @@
from abc import ABC, abstractmethod
from typing import Any, Generic, Iterable, Optional
from typing import Any, Generic, Iterable, Optional, Protocol, Type

from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
from typing_extensions import TypeVar

from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import BaseOptions
from docling.datamodel.settings import settings


class BaseModelWithOptions(Protocol):
@classmethod
def get_options_type(cls) -> Type[BaseOptions]: ...


class BasePageModel(ABC):
@abstractmethod
def __call__(
Expand Down
21 changes: 16 additions & 5 deletions docling/models/base_ocr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from abc import abstractmethod
from pathlib import Path
from typing import Iterable, List
from typing import Iterable, List, Optional, Type

import numpy as np
from docling_core.types.doc import BoundingBox, CoordOrigin
Expand All @@ -12,15 +12,21 @@

from docling.datamodel.base_models import Cell, OcrCell, Page
from docling.datamodel.document import ConversionResult
from docling.datamodel.pipeline_options import OcrOptions
from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions
from docling.datamodel.settings import settings
from docling.models.base_model import BasePageModel
from docling.models.base_model import BaseModelWithOptions, BasePageModel

_log = logging.getLogger(__name__)


class BaseOcrModel(BasePageModel):
def __init__(self, enabled: bool, options: OcrOptions):
class BaseOcrModel(BasePageModel, BaseModelWithOptions):
def __init__(
self,
enabled: bool,
artifacts_path: Optional[Path],
options: OcrOptions,
accelerator_options: AcceleratorOptions,
):
self.enabled = enabled
self.options = options

Expand Down Expand Up @@ -187,3 +193,8 @@ def __call__(
self, conv_res: ConversionResult, page_batch: Iterable[Page]
) -> Iterable[Page]:
pass

@classmethod
@abstractmethod
def get_options_type(cls) -> Type[OcrOptions]:
pass
14 changes: 12 additions & 2 deletions docling/models/easyocr_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import warnings
import zipfile
from pathlib import Path
from typing import Iterable, List, Optional
from typing import Iterable, List, Optional, Type

import numpy
from docling_core.types.doc import BoundingBox, CoordOrigin
Expand All @@ -13,6 +13,7 @@
AcceleratorDevice,
AcceleratorOptions,
EasyOcrOptions,
OcrOptions,
)
from docling.datamodel.settings import settings
from docling.models.base_ocr_model import BaseOcrModel
Expand All @@ -33,7 +34,12 @@ def __init__(
options: EasyOcrOptions,
accelerator_options: AcceleratorOptions,
):
super().__init__(enabled=enabled, options=options)
super().__init__(
enabled=enabled,
artifacts_path=artifacts_path,
options=options,
accelerator_options=accelerator_options,
)
self.options: EasyOcrOptions

self.scale = 3 # multiplier for 72 dpi == 216 dpi.
Expand Down Expand Up @@ -175,3 +181,7 @@ def __call__(
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)

yield page

@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
return EasyOcrOptions
14 changes: 14 additions & 0 deletions docling/models/factories/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,14 @@
import logging
from functools import lru_cache

from docling.models.factories.ocr_factory import OcrFactory

logger = logging.getLogger(__name__)


@lru_cache(maxsize=1)
def get_ocr_factory():
factory = OcrFactory()
factory.load_from_plugins()
# logger.info("Registered ocr engines: %r", factory.registered_kind)
return factory
Loading

0 comments on commit 0fa9e14

Please sign in to comment.