Skip to content

Commit 21f8ac9

Browse files
committed
test(wired_table_rec): Add issue #13 unit testing
1 parent 8f55dfe commit 21f8ac9

File tree

5 files changed

+24
-25
lines changed

5 files changed

+24
-25
lines changed

.github/workflows/wired_table_rec.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ jobs:
2727
- name: Unit testings
2828
run: |
2929
pip install -r requirements.txt
30-
pip install pytest
30+
pip install pytest beautifulsoup4
3131
3232
wget https://github.com/RapidAI/TableStructureRec/releases/download/v0.0.0/wired_table_rec_models.zip
3333
unzip wired_table_rec_models.zip

demo_wired.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22
# @Author: SWHL
33
# @Contact: [email protected]
44
from pathlib import Path
5-
from wired_table_rec import WiredTableRecognition
65

6+
from wired_table_rec import WiredTableRecognition
77

88
table_rec = WiredTableRecognition()
99

10-
img_path = "tests/test_files/wired/row_span.png"
10+
img_path = "tests/test_files/wired/squeeze_error.jpeg"
1111
table_str, elapse = table_rec(img_path)
1212
print(table_str)
1313
print(elapse)
839 KB
Loading

tests/test_wired.py

Lines changed: 20 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from pathlib import Path
66

77
import pytest
8+
from bs4 import BeautifulSoup
89
from rapidocr_onnxruntime import RapidOCR
910

1011
cur_dir = Path(__file__).resolve().parent
@@ -20,36 +21,35 @@
2021
ocr_engine = RapidOCR()
2122

2223

23-
@pytest.mark.parametrize(
24-
"img_path, gt1, gt2",
25-
[
26-
("table_recognition.jpg", 1245, "d colsp"),
27-
("table2.jpg", 924, "td><td "),
28-
("row_span.png", 312, "></td><"),
29-
],
30-
)
31-
def test_input_normal(img_path, gt1, gt2):
32-
img_path = test_file_dir / img_path
24+
def get_td_nums(html: str) -> int:
25+
soup = BeautifulSoup(html, "html.parser")
26+
tds = soup.table.find_all("td")
27+
return len(tds)
28+
3329

30+
def test_squeeze_bug():
31+
img_path = test_file_dir / "squeeze_error.jpeg"
3432
ocr_result, _ = ocr_engine(img_path)
3533
table_str, _ = table_recog(str(img_path), ocr_result)
36-
37-
assert len(table_str) >= gt1
38-
assert table_str[-53:-46] == gt2
34+
td_nums = get_td_nums(table_str)
35+
assert td_nums == 153
3936

4037

4138
@pytest.mark.parametrize(
42-
"img_path, gt1, gt2",
39+
"img_path, gt_td_nums, gt2",
4340
[
44-
("table_recognition.jpg", 1245, "d colsp"),
45-
("table2.jpg", 924, "td><td "),
46-
("row_span.png", 311, "></td><"),
41+
("table_recognition.jpg", 35, "d colsp"),
42+
("table2.jpg", 22, "td><td "),
43+
("row_span.png", 17, "></td><"),
4744
],
4845
)
49-
def test_input_without_ocr(img_path, gt1, gt2):
46+
def test_input_normal(img_path, gt_td_nums, gt2):
5047
img_path = test_file_dir / img_path
5148

52-
table_str, _ = table_recog(str(img_path))
49+
ocr_result, _ = ocr_engine(img_path)
50+
table_str, _ = table_recog(str(img_path), ocr_result)
51+
td_nums = get_td_nums(table_str)
5352

54-
assert len(table_str) >= gt1
53+
assert td_nums == gt_td_nums
5554
assert table_str[-53:-46] == gt2
55+

wired_table_rec/main.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,8 +59,7 @@ def __call__(
5959
except Exception:
6060
logging.warning(traceback.format_exc())
6161
return "", 0.0
62-
else:
63-
return table_str, elapse
62+
return table_str, elapse
6463

6564

6665
def main():

0 commit comments

Comments
 (0)