Skip to content

Commit 0fa9e14

Browse files
committed
add factory for ocr engines
Signed-off-by: Michele Dolfi <[email protected]>
1 parent 7450050 commit 0fa9e14

17 files changed

+367
-139
lines changed

docling/cli/main.py

Lines changed: 10 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -29,18 +29,13 @@
2929
AcceleratorDevice,
3030
AcceleratorOptions,
3131
EasyOcrOptions,
32-
OcrEngine,
33-
OcrMacOptions,
34-
OcrOptions,
3532
PdfBackend,
3633
PdfPipelineOptions,
37-
RapidOcrOptions,
3834
TableFormerMode,
39-
TesseractCliOcrOptions,
40-
TesseractOcrOptions,
4135
)
4236
from docling.datamodel.settings import settings
4337
from docling.document_converter import DocumentConverter, FormatOption, PdfFormatOption
38+
from docling.models.factories import get_ocr_factory
4439

4540
warnings.filterwarnings(action="ignore", category=UserWarning, module="pydantic|torch")
4641
warnings.filterwarnings(action="ignore", category=FutureWarning, module="easyocr")
@@ -50,6 +45,8 @@
5045

5146
err_console = Console(stderr=True)
5247

48+
ocr_factory = get_ocr_factory()
49+
ocr_engines_enum = ocr_factory.get_enum()
5350

5451
app = typer.Typer(
5552
name="Docling",
@@ -194,9 +191,11 @@ def convert(
194191
help="Replace any existing text with OCR generated text over the full content.",
195192
),
196193
] = False,
197-
ocr_engine: Annotated[
198-
OcrEngine, typer.Option(..., help="The OCR engine to use.")
199-
] = OcrEngine.EASYOCR,
194+
ocr_engine: Annotated[ # type: ignore
195+
ocr_engines_enum,
196+
# ocr_factory.get_registered_enum(),
197+
typer.Option(..., help="The OCR engine to use."),
198+
] = EasyOcrOptions.kind,
200199
ocr_lang: Annotated[
201200
Optional[str],
202201
typer.Option(
@@ -367,18 +366,8 @@ def convert(
367366
export_txt = OutputFormat.TEXT in to_formats
368367
export_doctags = OutputFormat.DOCTAGS in to_formats
369368

370-
if ocr_engine == OcrEngine.EASYOCR:
371-
ocr_options: OcrOptions = EasyOcrOptions(force_full_page_ocr=force_ocr)
372-
elif ocr_engine == OcrEngine.TESSERACT_CLI:
373-
ocr_options = TesseractCliOcrOptions(force_full_page_ocr=force_ocr)
374-
elif ocr_engine == OcrEngine.TESSERACT:
375-
ocr_options = TesseractOcrOptions(force_full_page_ocr=force_ocr)
376-
elif ocr_engine == OcrEngine.OCRMAC:
377-
ocr_options = OcrMacOptions(force_full_page_ocr=force_ocr)
378-
elif ocr_engine == OcrEngine.RAPIDOCR:
379-
ocr_options = RapidOcrOptions(force_full_page_ocr=force_ocr)
380-
else:
381-
raise RuntimeError(f"Unexpected OCR engine type {ocr_engine}")
369+
ocr_options_class = ocr_factory.get_options_class(kind=str(ocr_engine.value)) # type: ignore
370+
ocr_options = ocr_options_class(force_full_page_ocr=force_ocr)
382371

383372
ocr_lang_list = _split_list(ocr_lang)
384373
if ocr_lang_list is not None:

docling/datamodel/pipeline_options.py

Lines changed: 21 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,9 @@
11
import logging
22
import os
33
import re
4-
import warnings
54
from enum import Enum
65
from pathlib import Path
7-
from typing import Annotated, Any, Dict, List, Literal, Optional, Union
6+
from typing import Any, ClassVar, Dict, List, Literal, Optional, Union
87

98
from pydantic import (
109
AnyUrl,
@@ -13,11 +12,9 @@
1312
Field,
1413
field_validator,
1514
model_validator,
16-
validator,
1715
)
1816
from pydantic_settings import (
1917
BaseSettings,
20-
PydanticBaseSettingsSource,
2118
SettingsConfigDict,
2219
)
2320
from typing_extensions import deprecated
@@ -82,6 +79,12 @@ def check_alternative_envvars(cls, data: Any) -> Any:
8279
return data
8380

8481

82+
class BaseOptions(BaseModel):
    """Base class for options.

    Concrete options classes identify themselves through ``kind``, a
    class-level string such as ``"easyocr"`` or ``"api"``. The value is
    read from the class itself (e.g. ``EasyOcrOptions.kind``) and is the
    key passed to a factory's ``get_options_class(kind=...)``.
    """

    # Declared as ClassVar so the identifier belongs to the class rather
    # than being a per-instance model field.
    kind: ClassVar[str]
86+
87+
8588
class TableFormerMode(str, Enum):
8689
"""Modes for the TableFormer model."""
8790

@@ -101,10 +104,9 @@ class TableStructureOptions(BaseModel):
101104
mode: TableFormerMode = TableFormerMode.FAST
102105

103106

104-
class OcrOptions(BaseModel):
107+
class OcrOptions(BaseOptions):
105108
"""OCR options."""
106109

107-
kind: str
108110
lang: List[str]
109111
force_full_page_ocr: bool = False # If enabled a full page OCR is always applied
110112
bitmap_area_threshold: float = (
@@ -115,7 +117,7 @@ class OcrOptions(BaseModel):
115117
class RapidOcrOptions(OcrOptions):
116118
"""Options for the RapidOCR engine."""
117119

118-
kind: Literal["rapidocr"] = "rapidocr"
120+
kind: ClassVar[Literal["rapidocr"]] = "rapidocr"
119121

120122
# English and Chinese are the most commonly used models and have been tested with RapidOCR.
121123
lang: List[str] = [
@@ -154,7 +156,7 @@ class RapidOcrOptions(OcrOptions):
154156
class EasyOcrOptions(OcrOptions):
155157
"""Options for the EasyOCR engine."""
156158

157-
kind: Literal["easyocr"] = "easyocr"
159+
kind: ClassVar[Literal["easyocr"]] = "easyocr"
158160
lang: List[str] = ["fr", "de", "es", "en"]
159161

160162
use_gpu: Optional[bool] = None
@@ -174,7 +176,7 @@ class EasyOcrOptions(OcrOptions):
174176
class TesseractCliOcrOptions(OcrOptions):
175177
"""Options for the TesseractCli engine."""
176178

177-
kind: Literal["tesseract"] = "tesseract"
179+
kind: ClassVar[Literal["tesseract"]] = "tesseract"
178180
lang: List[str] = ["fra", "deu", "spa", "eng"]
179181
tesseract_cmd: str = "tesseract"
180182
path: Optional[str] = None
@@ -187,7 +189,7 @@ class TesseractCliOcrOptions(OcrOptions):
187189
class TesseractOcrOptions(OcrOptions):
188190
"""Options for the Tesseract engine."""
189191

190-
kind: Literal["tesserocr"] = "tesserocr"
192+
kind: ClassVar[Literal["tesserocr"]] = "tesserocr"
191193
lang: List[str] = ["fra", "deu", "spa", "eng"]
192194
path: Optional[str] = None
193195

@@ -199,7 +201,7 @@ class TesseractOcrOptions(OcrOptions):
199201
class OcrMacOptions(OcrOptions):
200202
"""Options for the Mac OCR engine."""
201203

202-
kind: Literal["ocrmac"] = "ocrmac"
204+
kind: ClassVar[Literal["ocrmac"]] = "ocrmac"
203205
lang: List[str] = ["fr-FR", "de-DE", "es-ES", "en-US"]
204206
recognition: str = "accurate"
205207
framework: str = "vision"
@@ -209,8 +211,7 @@ class OcrMacOptions(OcrOptions):
209211
)
210212

211213

212-
class PictureDescriptionBaseOptions(BaseModel):
213-
kind: str
214+
class PictureDescriptionBaseOptions(BaseOptions):
214215
batch_size: int = 8
215216
scale: float = 2
216217

@@ -220,7 +221,7 @@ class PictureDescriptionBaseOptions(BaseModel):
220221

221222

222223
class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
223-
kind: Literal["api"] = "api"
224+
kind: ClassVar[Literal["api"]] = "api"
224225

225226
url: AnyUrl = AnyUrl("http://localhost:8000/v1/chat/completions")
226227
headers: Dict[str, str] = {}
@@ -232,7 +233,7 @@ class PictureDescriptionApiOptions(PictureDescriptionBaseOptions):
232233

233234

234235
class PictureDescriptionVlmOptions(PictureDescriptionBaseOptions):
235-
kind: Literal["vlm"] = "vlm"
236+
kind: ClassVar[Literal["vlm"]] = "vlm"
236237

237238
repo_id: str
238239
prompt: str = "Describe this image in a few sentences."
@@ -264,6 +265,7 @@ class PdfBackend(str, Enum):
264265

265266

266267
# Define an enum for the ocr engines
268+
@deprecated("Use ocr_factory.registered_enum")
267269
class OcrEngine(str, Enum):
268270
"""Enum of valid OCR engines."""
269271

@@ -297,17 +299,10 @@ class PdfPipelineOptions(PipelineOptions):
297299
do_picture_description: bool = False # True: run describe pictures in documents
298300

299301
table_structure_options: TableStructureOptions = TableStructureOptions()
300-
ocr_options: Union[
301-
EasyOcrOptions,
302-
TesseractCliOcrOptions,
303-
TesseractOcrOptions,
304-
OcrMacOptions,
305-
RapidOcrOptions,
306-
] = Field(EasyOcrOptions(), discriminator="kind")
307-
picture_description_options: Annotated[
308-
Union[PictureDescriptionApiOptions, PictureDescriptionVlmOptions],
309-
Field(discriminator="kind"),
310-
] = smolvlm_picture_description
302+
ocr_options: OcrOptions = EasyOcrOptions()
303+
picture_description_options: PictureDescriptionBaseOptions = (
304+
smolvlm_picture_description
305+
)
311306

312307
images_scale: float = 1.0
313308
generate_page_images: bool = False

docling/models/base_model.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,20 @@
11
from abc import ABC, abstractmethod
2-
from typing import Any, Generic, Iterable, Optional
2+
from typing import Any, Generic, Iterable, Optional, Protocol, Type
33

44
from docling_core.types.doc import BoundingBox, DocItem, DoclingDocument, NodeItem
55
from typing_extensions import TypeVar
66

77
from docling.datamodel.base_models import ItemAndImageEnrichmentElement, Page
88
from docling.datamodel.document import ConversionResult
9+
from docling.datamodel.pipeline_options import BaseOptions
910
from docling.datamodel.settings import settings
1011

1112

13+
class BaseModelWithOptions(Protocol):
    """Structural interface for models that can report their options class."""

    @classmethod
    def get_options_type(cls) -> Type[BaseOptions]:
        """Return the ``BaseOptions`` subclass used to configure this model."""
        ...
16+
17+
1218
class BasePageModel(ABC):
1319
@abstractmethod
1420
def __call__(

docling/models/base_ocr_model.py

Lines changed: 16 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import logging
33
from abc import abstractmethod
44
from pathlib import Path
5-
from typing import Iterable, List
5+
from typing import Iterable, List, Optional, Type
66

77
import numpy as np
88
from docling_core.types.doc import BoundingBox, CoordOrigin
@@ -12,15 +12,21 @@
1212

1313
from docling.datamodel.base_models import Cell, OcrCell, Page
1414
from docling.datamodel.document import ConversionResult
15-
from docling.datamodel.pipeline_options import OcrOptions
15+
from docling.datamodel.pipeline_options import AcceleratorOptions, OcrOptions
1616
from docling.datamodel.settings import settings
17-
from docling.models.base_model import BasePageModel
17+
from docling.models.base_model import BaseModelWithOptions, BasePageModel
1818

1919
_log = logging.getLogger(__name__)
2020

2121

22-
class BaseOcrModel(BasePageModel):
23-
def __init__(self, enabled: bool, options: OcrOptions):
22+
class BaseOcrModel(BasePageModel, BaseModelWithOptions):
23+
def __init__(
24+
self,
25+
enabled: bool,
26+
artifacts_path: Optional[Path],
27+
options: OcrOptions,
28+
accelerator_options: AcceleratorOptions,
29+
):
2430
self.enabled = enabled
2531
self.options = options
2632

@@ -187,3 +193,8 @@ def __call__(
187193
self, conv_res: ConversionResult, page_batch: Iterable[Page]
188194
) -> Iterable[Page]:
189195
pass
196+
197+
@classmethod
@abstractmethod
def get_options_type(cls) -> Type[OcrOptions]:
    """Return the ``OcrOptions`` subclass that configures this OCR model.

    Abstract: every concrete OCR model must override this and return its
    own options class.
    """

docling/models/easyocr_model.py

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
import warnings
33
import zipfile
44
from pathlib import Path
5-
from typing import Iterable, List, Optional
5+
from typing import Iterable, List, Optional, Type
66

77
import numpy
88
from docling_core.types.doc import BoundingBox, CoordOrigin
@@ -13,6 +13,7 @@
1313
AcceleratorDevice,
1414
AcceleratorOptions,
1515
EasyOcrOptions,
16+
OcrOptions,
1617
)
1718
from docling.datamodel.settings import settings
1819
from docling.models.base_ocr_model import BaseOcrModel
@@ -33,7 +34,12 @@ def __init__(
3334
options: EasyOcrOptions,
3435
accelerator_options: AcceleratorOptions,
3536
):
36-
super().__init__(enabled=enabled, options=options)
37+
super().__init__(
38+
enabled=enabled,
39+
artifacts_path=artifacts_path,
40+
options=options,
41+
accelerator_options=accelerator_options,
42+
)
3743
self.options: EasyOcrOptions
3844

3945
self.scale = 3 # multiplier for 72 dpi == 216 dpi.
@@ -175,3 +181,7 @@ def __call__(
175181
self.draw_ocr_rects_and_cells(conv_res, page, ocr_rects)
176182

177183
yield page
184+
185+
@classmethod
def get_options_type(cls) -> Type[OcrOptions]:
    """Return the options class that configures this engine: ``EasyOcrOptions``."""
    return EasyOcrOptions

docling/models/factories/__init__.py

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
import logging
2+
from functools import lru_cache
3+
4+
from docling.models.factories.ocr_factory import OcrFactory
5+
6+
logger = logging.getLogger(__name__)
7+
8+
9+
@lru_cache(maxsize=1)
def get_ocr_factory() -> OcrFactory:
    """Return the shared OCR factory, creating it on the first call.

    A new ``OcrFactory`` is built and ``load_from_plugins()`` is invoked on
    it once. ``lru_cache`` then memoizes the result, so every later call
    returns the same factory instance without repeating that setup.
    """
    factory = OcrFactory()
    factory.load_from_plugins()
    return factory

0 commit comments

Comments
 (0)