Skip to content

Commit

Permalink
Merge pull request #1664 from myhloli/dev
Browse files Browse the repository at this point in the history
feat(pdf_parse): improve OCR processing and contrast filtering
  • Loading branch information
myhloli authored Feb 9, 2025
2 parents 9bb2d58 + 5561ac9 commit 6e1fba9
Show file tree
Hide file tree
Showing 6 changed files with 65 additions and 20 deletions.
2 changes: 1 addition & 1 deletion magic_pdf/filter/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ def classify(pdf_bytes: bytes) -> SupportedPdfParseMethod:
pdf_meta['image_info_per_page'],
pdf_meta['text_len_per_page'],
pdf_meta['imgs_per_page'],
pdf_meta['text_layout_per_page'],
# pdf_meta['text_layout_per_page'],
pdf_meta['invalid_chars'],
)
if is_text_pdf:
Expand Down
10 changes: 6 additions & 4 deletions magic_pdf/filter/pdf_classify_by_type.py
Original file line number Diff line number Diff line change
Expand Up @@ -305,7 +305,8 @@ def is_narrow_strip(img):


def classify(total_page: int, page_width, page_height, img_sz_list: list, text_len_list: list, img_num_list: list,
text_layout_list: list, invalid_chars: bool):
# text_layout_list: list,
invalid_chars: bool):
"""
这里的图片和页面长度单位是pts
:param total_page:
Expand All @@ -321,7 +322,7 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
'by_text_len': classify_by_text_len(text_len_list, total_page),
'by_avg_words': classify_by_avg_words(text_len_list),
'by_img_num': classify_by_img_num(img_sz_list, img_num_list),
'by_text_layout': classify_by_text_layout(text_layout_list),
# 'by_text_layout': classify_by_text_layout(text_layout_list),
'by_img_narrow_strips': classify_by_img_narrow_strips(page_width, page_height, img_sz_list),
'by_invalid_chars': invalid_chars,
}
Expand All @@ -332,9 +333,10 @@ def classify(total_page: int, page_width, page_height, img_sz_list: list, text_l
return False, results
else:
logger.warning(
f"pdf is not classified by area and text_len, by_image_area: {results['by_image_area']},"
f"OCR needed based on classification result, by_image_area: {results['by_image_area']},"
f" by_text: {results['by_text_len']}, by_avg_words: {results['by_avg_words']}, by_img_num: {results['by_img_num']},"
f" by_text_layout: {results['by_text_layout']}, by_img_narrow_strips: {results['by_img_narrow_strips']},"
# f" by_text_layout: {results['by_text_layout']},"
f" by_img_narrow_strips: {results['by_img_narrow_strips']},"
f" by_invalid_chars: {results['by_invalid_chars']}",
file=sys.stderr) # 利用这种情况可以快速找出来哪些pdf比较特殊,针对性修正分类算法
return False, results
Expand Down
8 changes: 4 additions & 4 deletions magic_pdf/filter/pdf_meta_scan.py
Original file line number Diff line number Diff line change
Expand Up @@ -356,9 +356,9 @@ def pdf_meta_scan(pdf_bytes: bytes):
# logger.info(f"image_info_per_page: {image_info_per_page}, junk_img_bojids: {junk_img_bojids}")
text_len_per_page = get_pdf_textlen_per_page(doc)
# logger.info(f"text_len_per_page: {text_len_per_page}")
text_layout_per_page = get_pdf_text_layout_per_page(doc)
# text_layout_per_page = get_pdf_text_layout_per_page(doc)
# logger.info(f"text_layout_per_page: {text_layout_per_page}")
text_language = get_language(doc)
# text_language = get_language(doc)
# logger.info(f"text_language: {text_language}")
invalid_chars = check_invalid_chars(pdf_bytes)
# logger.info(f"invalid_chars: {invalid_chars}")
Expand All @@ -372,8 +372,8 @@ def pdf_meta_scan(pdf_bytes: bytes):
'page_height_pts': int(page_height_pts),
'image_info_per_page': image_info_per_page,
'text_len_per_page': text_len_per_page,
'text_layout_per_page': text_layout_per_page,
'text_language': text_language,
# 'text_layout_per_page': text_layout_per_page,
# 'text_language': text_language,
# "svgs_per_page": svgs_per_page,
'imgs_per_page': imgs_per_page, # 增加每页img数量list
'junk_img_bojids': junk_img_bojids, # 增加垃圾图片的bojid list
Expand Down
12 changes: 11 additions & 1 deletion magic_pdf/libs/pdf_check.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
import re
from io import BytesIO
from pdfminer.high_level import extract_text
from pdfminer.layout import LAParams


def calculate_sample_count(total_page: int):
Expand Down Expand Up @@ -41,7 +42,16 @@ def detect_invalid_chars(src_pdf_bytes: bytes) -> bool:
sample_docs = extract_pages(src_pdf_bytes)
sample_pdf_bytes = sample_docs.tobytes()
sample_pdf_file_like_object = BytesIO(sample_pdf_bytes)
text = extract_text(sample_pdf_file_like_object)
laparams = LAParams(
line_overlap=0.5,
char_margin=2.0,
line_margin=0.5,
word_margin=0.1,
boxes_flow=None,
detect_vertical=False,
all_texts=False,
)
text = extract_text(pdf_file=sample_pdf_file_like_object, laparams=laparams)
text = text.replace("\n", "")
# logger.info(text)
'''乱码文本用pdfminer提取出来的文本特征是(cid:xxx)'''
Expand Down
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
# Copyright (c) Opendatalab. All rights reserved.
import time
from collections import Counter
from uuid import uuid4

Expand Down Expand Up @@ -102,9 +103,9 @@ def do_detect(self, images: list):
temp_images = split_images(image)
for temp_image in temp_images:
all_images.append(resize_images_to_224(temp_image))

images_lang_res = self.batch_predict(all_images, batch_size=8)
# logger.info(f"images_lang_res: {images_lang_res}")
# langdetect_start = time.time()
images_lang_res = self.batch_predict(all_images, batch_size=256)
# logger.info(f"image number of langdetect: {len(images_lang_res)}, langdetect time: {round(time.time() - langdetect_start, 2)}")
if len(images_lang_res) > 0:
count_dict = Counter(images_lang_res)
language = max(count_dict, key=count_dict.get)
Expand Down
46 changes: 39 additions & 7 deletions magic_pdf/pdf_parse_union_core_v2.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,10 @@
import time
from typing import List

import cv2
import fitz
import torch
import numpy as np
from loguru import logger

from magic_pdf.config.enums import SupportedPdfParseMethod
Expand Down Expand Up @@ -127,16 +129,15 @@ def fill_char_in_spans(spans, all_chars):
span['chars'].append(char)
break

empty_spans = []

need_ocr_spans = []
for span in spans:
chars_to_content(span)
# 有的span中虽然没有字但有一两个空的占位符,用宽高和content长度过滤
if len(span['content']) * span['height'] < span['width'] * 0.5:
# logger.info(f"maybe empty span: {len(span['content'])}, {span['height']}, {span['width']}")
empty_spans.append(span)
need_ocr_spans.append(span)
del span['height'], span['width']
return empty_spans
return need_ocr_spans


# 使用鲁棒性更强的中心点坐标判断
Expand Down Expand Up @@ -190,6 +191,31 @@ def remove_tilted_line(text_blocks):
block['lines'].remove(line)


def calculate_contrast(img, img_mode) -> float:
"""
计算给定图像的对比度。
:param img: 图像,类型为numpy.ndarray
:Param img_mode = 图像的色彩通道,'rgb' 或 'bgr'
:return: 图像的对比度值
"""
if img_mode == 'rgb':
# 将RGB图像转换为灰度图
gray_img = cv2.cvtColor(img, cv2.COLOR_RGB2GRAY)
elif img_mode == 'bgr':
# 将BGR图像转换为灰度图
gray_img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
else:
raise ValueError("Invalid image mode. Please provide 'rgb' or 'bgr'.")

# 计算均值和标准差
mean_value = np.mean(gray_img)
std_dev = np.std(gray_img)
# 对比度定义为标准差除以平均值(加上小常数避免除零错误)
contrast = std_dev / (mean_value + 1e-6)
# logger.info(f"contrast: {contrast}")
return round(contrast, 2)


def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang):
# cid用0xfffd表示,连字符拆开
# text_blocks_raw = pdf_page.get_text('rawdict', flags=fitz.TEXT_PRESERVE_WHITESPACE | fitz.TEXT_MEDIABOX_CLIP)['blocks']
Expand Down Expand Up @@ -274,9 +300,9 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
span['chars'] = []
new_spans.append(span)

empty_spans = fill_char_in_spans(new_spans, all_pymu_chars)
need_ocr_spans = fill_char_in_spans(new_spans, all_pymu_chars)

if len(empty_spans) > 0:
if len(need_ocr_spans) > 0:

# 初始化ocr模型
atom_model_manager = AtomModelSingleton()
Expand All @@ -287,9 +313,15 @@ def txt_spans_extract_v2(pdf_page, spans, all_bboxes, all_discarded_blocks, lang
lang=lang
)

for span in empty_spans:
for span in need_ocr_spans:
# 对span的bbox截图再ocr
span_img = cut_image_to_pil_image(span['bbox'], pdf_page, mode='cv2')

# 计算span的对比度,低于0.20的span不进行ocr
if calculate_contrast(span_img, img_mode='bgr') <= 0.20:
spans.remove(span)
continue

ocr_res = ocr_model.ocr(span_img, det=False)
if ocr_res and len(ocr_res) > 0:
if len(ocr_res[0]) > 0:
Expand Down

0 comments on commit 6e1fba9

Please sign in to comment.