Skip to content

Commit e583fb9

Browse files
committed
update ocr
1 parent 504d147 commit e583fb9

File tree

4 files changed

+169
-5
lines changed

4 files changed

+169
-5
lines changed

Diff for: AUTHORS.md

+2
Original file line numberDiff line numberDiff line change
@@ -22,4 +22,6 @@
2222

2323
[Shangzi Xue](https://github.com/ShangziXue)
2424

25+
[Heng Yu](https://github.com/GNEHUY)
26+
2527
The stared contributors are the corresponding authors.

Diff for: EduNLP/SIF/parser/ocr.py

+127
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,127 @@
1+
# coding: utf-8
2+
# 2024/3/5 @ yuheng
3+
import json
4+
import requests
5+
from EduNLP.utils import image2base64
6+
7+
class FormulaRecognitionError(Exception):
8+
"""Exception raised when formula recognition fails."""
9+
def __init__(self, message="Formula recognition failed"):
10+
self.message = message
11+
super().__init__(self.message)
12+
13+
def ocr_formula_figure(image_PIL_or_base64, is_base64=False):
14+
"""
15+
Recognizes mathematical formulas in an image and returns their LaTeX representation.
16+
17+
Parameters
18+
----------
19+
image_PIL_or_base64 : PngImageFile or str
20+
The PngImageFile if is_base64 is False, or the base64 encoded string of the image if is_base64 is True.
21+
is_base64 : bool, optional
22+
Indicates whether the image_PIL_or_base64 parameter is an PngImageFile or a base64 encoded string, by default False.
23+
24+
Returns
25+
-------
26+
latex : str
27+
The LaTeX representation of the mathematical formula recognized in the image. Raises an exception if the image is not recognized as containing a mathematical formula.
28+
29+
Raises
30+
------
31+
FormulaRecognitionError
32+
If the HTTP request does not return a 200 status code, if there is an error processing the response, or if the image is not recognized as a mathematical formula.
33+
34+
Examples
35+
--------
36+
>>> from PIL import Image
37+
>>> image_PIL = Image.open("path/to/your/image.jpg")
38+
>>> print(ocr_formula_figure(image_PIL))
39+
Or
40+
>>> image_base64 = "base64_encoded_image_string"
41+
>>> print(ocr_formula_figure(image_base64, is_base64=True))
42+
43+
Notes
44+
-----
45+
This function relies on an external service "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1",
46+
and the `requests` library to make HTTP requests. Make sure the required libraries are installed before use.
47+
"""
48+
url = "https://formula-recognition-service-47-production.env.iai.bdaa.pro/v1"
49+
50+
if is_base64:
51+
image = image_PIL_or_base64
52+
else:
53+
image = image2base64(image_PIL_or_base64)
54+
55+
data = [{
56+
'qid': 0,
57+
'image': image
58+
}]
59+
60+
resp = requests.post(url, data=json.dumps(data))
61+
62+
if resp.status_code != 200:
63+
raise FormulaRecognitionError(f"HTTP error {resp.status_code}: {resp.text}")
64+
65+
try:
66+
res = json.loads(resp.content)
67+
except Exception as e:
68+
raise FormulaRecognitionError(f"Error processing response: {e}")
69+
70+
res = json.loads(resp.content)
71+
data = res['data']
72+
if data['success'] == 1 and data['is_formula'] == 1 and data['detect_formula'] == 1:
73+
latex = data['latex']
74+
else:
75+
latex = None
76+
raise FormulaRecognitionError("Image is not recognized as a formula")
77+
78+
return latex
79+
80+
def ocr(src, is_base64=False, figure_instances: dict = None):
81+
"""
82+
Recognizes mathematical formulas within figures from a given source, which can be either a base64 string or an identifier for a figure within a provided dictionary.
83+
84+
Parameters
85+
----------
86+
src : str
87+
The source from which the figure is to be recognized. It can be a base64 encoded string of the image if is_base64 is True, or an identifier for the figure if is_base64 is False.
88+
is_base64 : bool, optional
89+
Indicates whether the src parameter is a base64 encoded string or an identifier, by default False.
90+
figure_instances : dict, optional
91+
A dictionary mapping figure identifiers to their corresponding PngImageFile, by default None. This is only required and used if is_base64 is False.
92+
93+
Returns
94+
-------
95+
forumla_figure_latex : str or None
96+
The LaTeX representation of the mathematical formula recognized within the figure. Returns None if no formula is recognized or if the figure_instances dictionary does not contain the specified figure identifier when is_base64 is False.
97+
98+
Examples
99+
--------
100+
>>> src_base64 = r"\FormFigureBase64{base64_encoded_image_string}"
101+
>>> print(ocr(src_base64, is_base64=True))
102+
Or
103+
>>> from PIL import Image
104+
>>> image_PIL = Image.open("path/to/your/image.jpg")
105+
>>> figure_instances = {"figure1": image_PIL}
106+
>>> src_id = r"\FormFigureID{figure1}"
107+
>>> print(ocr(src_id, figure_instances=figure_instances))
108+
109+
Notes
110+
-----
111+
This function relies on `ocr_formula_figure` for the actual OCR (Optical Character Recognition) process. Ensure that `ocr_formula_figure` is correctly implemented and can handle both base64 encoded strings and PngImageFile as input.
112+
"""
113+
forumla_figure_latex = None
114+
if is_base64:
115+
figure = src[len(r"\FormFigureBase64") + 1: -1]
116+
if figure_instances is not None:
117+
forumla_figure_latex = ocr_formula_figure(figure, is_base64)
118+
else:
119+
figure = src[len(r"\FormFigureID") + 1: -1]
120+
if figure_instances is not None:
121+
figure = figure_instances[figure]
122+
forumla_figure_latex = ocr_formula_figure(figure, is_base64)
123+
124+
return forumla_figure_latex
125+
126+
127+

Diff for: EduNLP/SIF/segment/segment.py

+12-5
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
import re
66
from contextlib import contextmanager
77
from ..constants import Symbol, TEXT_SYMBOL, FORMULA_SYMBOL, FIGURE_SYMBOL, QUES_MARK_SYMBOL, TAG_SYMBOL, SEP_SYMBOL
8+
from ..parser.ocr import ocr
89

910

1011
class TextSegment(str):
@@ -93,7 +94,7 @@ class SegmentList(object):
9394
>>> SegmentList(test_item)
9495
['如图所示,则三角形', 'ABC', '的面积是', '\\\\SIFBlank', '。', \\FigureID{1}]
9596
"""
96-
def __init__(self, item, figures: dict = None):
97+
def __init__(self, item, figures: dict = None, convert_image_to_latex=False):
9798
self._segments = []
9899
self._text_segments = []
99100
self._formula_segments = []
@@ -112,9 +113,15 @@ def __init__(self, item, figures: dict = None):
112113
if not re.match(r"\$.+?\$", segment):
113114
self.append(TextSegment(segment))
114115
elif re.match(r"\$\\FormFigureID\{.+?}\$", segment):
115-
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
116+
if convert_image_to_latex:
117+
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=False, figure_instances=figures)))
118+
else:
119+
self.append(FigureFormulaSegment(segment[1:-1], is_base64=False, figure_instances=figures))
116120
elif re.match(r"\$\\FormFigureBase64\{.+?}\$", segment):
117-
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
121+
if convert_image_to_latex:
122+
self.append(LatexFormulaSegment(ocr(segment[1:-1], is_base64=True, figure_instances=figures)))
123+
else:
124+
self.append(FigureFormulaSegment(segment[1:-1], is_base64=True, figure_instances=figures))
118125
elif re.match(r"\$\\FigureID\{.+?}\$", segment):
119126
self.append(FigureSegment(segment[1:-1], is_base64=False, figure_instances=figures))
120127
elif re.match(r"\$\\FigureBase64\{.+?}\$", segment):
@@ -271,7 +278,7 @@ def describe(self):
271278
}
272279

273280

274-
def seg(item, figures=None, symbol=None):
281+
def seg(item, figures=None, symbol=None, convert_image_to_latex=False):
275282
r"""
276283
It is a interface for SegmentList. And show it in an appropriate way.
277284
@@ -346,7 +353,7 @@ def seg(item, figures=None, symbol=None):
346353
>>> s2.text_segments
347354
['已知', ',则以下说法中正确的是']
348355
"""
349-
segments = SegmentList(item, figures)
356+
segments = SegmentList(item, figures, convert_image_to_latex)
350357
if symbol is not None:
351358
segments.symbolize(symbol)
352359
return segments

Diff for: tests/test_sif/test_ocr.py

+28
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# 2024/3/5 @ yuheng
2+
3+
import pytest
4+
5+
from EduNLP.SIF.segment import seg
6+
7+
8+
def test_ocr(figure0, figure1, figure0_base64, figure1_base64):
9+
seg(
10+
r"如图所示,则$\FormFigureID{0}$的面积是$\SIFBlank$。$\FigureID{1}$",
11+
figures={
12+
"0": figure0,
13+
"1": figure1
14+
},
15+
convert_image_to_latex=True
16+
)
17+
s = seg(
18+
r"如图所示,则$\FormFigureBase64{%s}$的面积是$\SIFBlank$。$\FigureBase64{%s}$" % (figure0_base64, figure1_base64),
19+
figures=True,
20+
convert_image_to_latex=True
21+
)
22+
with pytest.raises(TypeError):
23+
s.append("123")
24+
seg_test_text = seg(
25+
r"如图所示,有三组$\textf{机器人,bu}$在踢$\textf{足球,b}$",
26+
figures=True
27+
)
28+
assert seg_test_text.text_segments == ['如图所示,有三组机器人在踢足球']

0 commit comments

Comments
 (0)