Skip to content

Commit 855e250

Browse files
authored
Merge pull request #157 from GNEHUY/dev
Update OCR module
2 parents 504d147 + d235439 commit 855e250

File tree

7 files changed

+246
-13
lines changed

7 files changed

+246
-13
lines changed

AUTHORS.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@
2222

2323
[Shangzi Xue](https://github.com/ShangziXue)
2424

25+
[Heng Yu](https://github.com/GNEHUY)
26+
2527
The starred contributors are the corresponding authors.

EduNLP/SIF/parser/ocr.py

+156
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,156 @@
1+
# coding: utf-8
2+
# 2024/3/5 @ yuheng
3+
import json
4+
import requests
5+
from EduNLP.utils import image2base64
6+
7+
8+
class FormulaRecognitionError(Exception):
    """Error signalling that formula recognition did not yield a result.

    Parameters
    ----------
    message : str, optional
        Human-readable description of the failure. It is passed to
        ``Exception`` and also kept on the instance as ``message``.
    """

    def __init__(self, message="Formula recognition failed"):
        super().__init__(message)
        self.message = message
13+
14+
15+
def ocr_formula_figure(image_PIL_or_base64, is_base64=False, timeout=60):
    """
    Recognizes mathematical formulas in an image and returns their LaTeX representation.

    Parameters
    ----------
    image_PIL_or_base64 : PngImageFile or str
        The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
    is_base64 : bool, optional
        Indicates whether the image_PIL_or_base64 parameter is a PngImageFile or a base64 encoded string.
    timeout : float or tuple or None, optional
        Timeout in seconds handed to ``requests.post``, by default 60.
        Pass None to wait indefinitely.

    Returns
    -------
    latex : str
        The LaTeX representation of the mathematical formula recognized in the image.

    Raises
    ------
    FormulaRecognitionError
        If the HTTP request itself fails (connection error, timeout, ...),
        if the HTTP request does not return a 200 status code,
        if there is an error processing the response (invalid JSON or missing fields),
        if the image is not recognized as a mathematical formula.

    Examples
    --------
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> print(ocr_formula_figure(image_PIL))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append, image2base64
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> image_base64 = image2base64(image_PIL)
    >>> print(ocr_formula_figure(image_base64, is_base64=True))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}

    Notes
    -----
    This function relies on an external service "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1",
    and the `requests` library to make HTTP requests. Make sure the required libraries are installed before use.
    """
    url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"

    image = image_PIL_or_base64 if is_base64 else image2base64(image_PIL_or_base64)

    # The service receives a JSON list of {qid, image} entries; one image is sent per call.
    payload = [{
        'qid': 0,
        'image': image
    }]

    try:
        # Without a timeout a stalled service would block the caller forever.
        resp = requests.post(url, data=json.dumps(payload), timeout=timeout)
    except requests.RequestException as e:
        raise FormulaRecognitionError(f"Request failed: {e}") from e

    if resp.status_code != 200:
        raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")

    try:
        result = json.loads(resp.content)
    except (ValueError, TypeError) as e:
        raise FormulaRecognitionError(f"Error processing response: {e}") from e

    try:
        data = result['data']
        recognized = (
            data['success'] == 1
            and data['is_formula'] == 1
            and data['detect_formula'] == 1
        )
        # 'latex' is only read on success, so a failure payload may omit it.
        latex = data['latex'] if recognized else None
    except (KeyError, TypeError) as e:
        # Valid JSON with an unexpected shape is still a response-processing error.
        raise FormulaRecognitionError(
            f"Error processing response: unexpected payload ({e!r})"
        ) from e

    if not recognized:
        raise FormulaRecognitionError("Image is not recognized as a formula")

    return latex
93+
94+
95+
def ocr(src, is_base64=False, figure_instances: dict = None):
    """
    Recognizes mathematical formulas within figures from a given source,
    which can be either a base64 string or an identifier for a figure within a provided dictionary.

    Parameters
    ----------
    src : str
        The figure markup without the surrounding dollar signs.
        It wraps a base64 encoded image (FormFigureBase64 markup) if is_base64 is True,
        or a figure identifier (FormFigureID markup) if is_base64 is False.
    is_base64 : bool, optional
        Indicates whether the src parameter is a base64 encoded string or an identifier, by default False.
    figure_instances : dict, optional
        A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None.
        When is_base64 is False the figure is looked up in it.
        When is_base64 is True its content is not read, but recognition is only performed
        if it is not None (the example below passes True).

    Returns
    -------
    formula_figure_latex : str or None
        The LaTeX representation of the mathematical formula recognized within the figure.
        Returns None if figure_instances is None, or
        if the figure_instances dictionary does not contain the specified figure identifier when is_base64 is False.

    Raises
    ------
    FormulaRecognitionError
        Propagated from `ocr_formula_figure` when recognition fails.

    Examples
    --------
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> figure_instances = {"1": image_PIL}
    >>> src_id = r"$\\FormFigureID{1}$"
    >>> print(ocr(src_id[1:-1], figure_instances=figure_instances))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}
    >>> import os
    >>> from PIL import Image
    >>> from EduNLP.utils import abs_current_dir, path_append, image2base64
    >>> img_dir = os.path.abspath(path_append(abs_current_dir(__file__), "..", "..", "..", "asset", "_static"))
    >>> image_PIL = Image.open(path_append(img_dir, "item_ocr_formula.png", to_str=True))
    >>> image_base64 = image2base64(image_PIL)
    >>> src_base64 = r"$\\FormFigureBase64{%s}$" % (image_base64)
    >>> print(ocr(src_base64[1:-1], is_base64=True, figure_instances=True))
    f(x)=\\left (\\frac {1}{3}\\right )^{x}-\\sqrt {x}}

    Notes
    -----
    This function relies on `ocr_formula_figure` for the actual OCR (Optical Character Recognition) process.
    Ensure that `ocr_formula_figure` is correctly implemented and can handle base64 encoded strings and PngImageFile.
    """
    # No figure context at all: nothing is recognized for either input kind.
    if figure_instances is None:
        return None

    if is_base64:
        # Strip the command name plus "{" in front and the closing "}" behind.
        figure = src[len(r"\FormFigureBase64") + 1: -1]
        return ocr_formula_figure(figure, is_base64)

    figure_id = src[len(r"\FormFigureID") + 1: -1]
    try:
        figure = figure_instances[figure_id]
    except KeyError:
        # Documented contract: an unknown figure identifier yields None.
        return None
    return ocr_formula_figure(figure, is_base64)

EduNLP/SIF/segment/segment.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
from contextlib import contextmanager
77
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
8+
from ..parser.ocr import ocr
89

910

1011
class TextSegment(str):
@@ -93,7 +94,7 @@ class SegmentList(object):
9394
>>> SegmentList(test_item)
9495
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
9596
"""
96-
def __init__(self, item, figures: dict = None):
97+
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
9798
self._segments = []
9899
self._text_segments = []
99100
self._formula_segments = []
@@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
112113
if not re.match(r"\$.+?\$", segment):
113114
self.append(TextSegment(segment))
114115
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
115-
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
116+
if convert_image_to_latex:
117+
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
118+
else:
119+
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
116120
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
117-
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
121+
if convert_image_to_latex:
122+
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
123+
else:
124+
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
118125
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
119126
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
120127
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
@@ -271,7 +278,7 @@ def describe(self):
271278
}
272279

273280

274-
def seg(item, figures=None, symbol=None):
281+
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
275282
r"""
276283
It is a interface for SegmentList. And show it in an appropriate way.
277284
@@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
346353
>>> s2.text_segments
347354
['已知', ',则以下说法中正确的是']
348355
"""
349-
segments = SegmentList(item, figures)
356+
segments = SegmentList(item, figures, convert_image_to_latex)
350357
if symbol is not None:
351358
segments.symbolize(symbol)
352359
return segments

EduNLP/SIF/sif.py

+2-2
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ def to_sif(item, check_formula=True, parser: Parser = None):
9797

9898

9999
def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str = None, tokenization=True,
100-
tokenization_params=None, errors="raise"):
100+
tokenization_params=None, convert_image_to_latex=False, errors="raise"):
101101
r"""
102102
103103
Default to use linear Tokenizer, change the tokenizer by specifying tokenization_params
@@ -260,7 +260,7 @@ def sif4sci(item: str, figures: (dict, bool) = None, mode: int = 2, symbol: str
260260
"Unknown mode %s, use only 0 or 1 or 2." % mode
261261
)
262262

263-
ret = seg(item, figures, symbol)
263+
ret = seg(item, figures, symbol, convert_image_to_latex)
264264

265265
if tokenization is True:
266266
ret = tokenize(ret, **(tokenization_params if tokenization_params is not None else {}))

EduNLP/Tokenizer/tokenizer.py

+12-6
Original file line numberDiff line numberDiff line change
@@ -56,7 +56,8 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
5656
for item in items:
5757
yield self._tokenize(item, key=key, **kwargs)
5858

59-
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
59+
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
60+
convert_image_to_latex=False, **kwargs):
6061
"""Tokenize one item, return token list
6162
6263
Parameters
@@ -67,7 +68,8 @@ def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
6768
determine how to get the text of item, by default lambdax: x
6869
"""
6970
symbol = self.symbol if symbol is None else symbol
70-
return tokenize(seg(key(item), symbol=symbol, figures=self.figures),
71+
return tokenize(seg(key(item), symbol=symbol, figures=self.figures,
72+
convert_image_to_latex=convert_image_to_latex),
7173
**self.tokenization_params, **kwargs).tokens
7274

7375

@@ -191,9 +193,11 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
191193
for item in items:
192194
yield self._tokenize(item, key=key, **kwargs)
193195

194-
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
196+
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
197+
convert_image_to_latex=False, **kwargs):
195198
symbol = self.symbol if symbol is None else symbol
196-
return tokenize(seg(key(item), symbol=symbol), **self.tokenization_params, **kwargs).tokens
199+
return tokenize(seg(key(item), symbol=symbol, convert_image_to_latex=convert_image_to_latex),
200+
**self.tokenization_params, **kwargs).tokens
197201

198202

199203
class AstFormulaTokenizer(Tokenizer):
@@ -235,11 +239,13 @@ def __call__(self, items: Iterable, key=lambda x: x, **kwargs):
235239
for item in items:
236240
yield self._tokenize(item, key=key, **kwargs)
237241

238-
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None, **kwargs):
242+
def _tokenize(self, item: Union[str, dict], key=lambda x: x, symbol: str = None,
243+
convert_image_to_latex=False, **kwargs):
239244
mode = kwargs.pop("mode", 0)
240245
symbol = self.symbol if symbol is None else symbol
241246
ret = sif4sci(key(item), figures=self.figures, mode=mode, symbol=symbol,
242-
tokenization_params=self.tokenization_params, errors="ignore", **kwargs)
247+
tokenization_params=self.tokenization_params, convert_image_to_latex=convert_image_to_latex,
248+
errors="ignore", **kwargs)
243249
ret = [] if ret is None else ret.tokens
244250
return ret
245251

asset/_static/item_ocr_formula.png

5.69 KB
Loading

tests/test_sif/test_ocr.py

+62
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,62 @@
1+
# 2024/3/5 @ yuheng
2+
3+
import pytest
4+
import json
5+
6+
from EduNLP.SIF.segment import seg
7+
from EduNLP.SIF.parser.ocr import ocr_formula_figure, FormulaRecognitionError
8+
from unittest.mock import patch
9+
10+
11+
def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
    """Run ``seg`` with ``convert_image_to_latex=True`` on id-based and base64-based items.

    Also checks that appending a plain string to the resulting segment list
    raises ``TypeError`` and that ``textf`` markup is merged into one text segment.
    """
    # Formula figure referenced by identifier.
    id_item = r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$"
    id_figures = {"0": figure0, "1": figure1}
    seg(id_item, figures=id_figures, convert_image_to_latex=True)

    # Formula figure embedded as base64.
    base64_item = r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$" % (
        figure0_base64, figure1_base64
    )
    segments = seg(base64_item, figures=True, convert_image_to_latex=True)
    with pytest.raises(TypeError):
        segments.append("123")

    # Plain text item without any figure.
    text_item = r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$"
    text_segments = seg(text_item, figures=True).text_segments
    assert text_segments == ['如图所示,有三组机器人在踢足球']
32+
33+
34+
def test_ocr_formula_figure_exceptions(figure0_base64):
    """Check that ``ocr_formula_figure`` raises ``FormulaRecognitionError`` on each failure path.

    ``requests.post`` is patched, so the mocked response drives every case.
    """
    post_target = 'EduNLP.SIF.parser.ocr.requests.post'

    # Case 1: the service answers with a non-200 status code.
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 404
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
        assert "HTTP error 404" in str(caught.value)

    # Case 2: the body is not valid JSON.
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 200
        fake_post.return_value.content = b"invalid_json_response"
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
        assert "Error processing response" in str(caught.value)

    # Case 3: the service reports that the image is not a formula.
    not_formula = {
        "data": {
            'success': 1,
            'is_formula': 0,
            'detect_formula': 0
        }
    }
    with patch(post_target) as fake_post:
        fake_post.return_value.status_code = 200
        fake_post.return_value.content = json.dumps(not_formula).encode('utf-8')
        with pytest.raises(FormulaRecognitionError) as caught:
            ocr_formula_figure(figure0_base64, is_base64=True)
        assert "Image is not recognized as a formula" in str(caught.value)

0 commit comments

Comments
 (0)