Skip to content

Commit 1dc25f5

Browse files
authored
implemented boilerplate transforms of documents. (#668)
* implemented boilerplate into pipeline
* fixed tests
* linted
* fixed mypy
* linted and fixed tests
* updated neo4j auth
* addressed vinayak comments
* linted
* added tests for new map functions
* changed neo4j auth
* linted
1 parent 06c462c commit 1dc25f5

File tree

5 files changed

+186
-112
lines changed

5 files changed

+186
-112
lines changed

.github/workflows/testing.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,9 +127,9 @@ jobs:
127127
neo4j:
128128
image: neo4j:5.21.0
129129
env:
130-
NEO4J_AUTH: neo4j/koala-stereo-comedy-spray-figure-6974 # DO NOT SET PASSWORD LESS THAN 8 CHARACTERS
131130
NEO4J_dbms_memory_heap_initial__size: 2G
132131
NEO4J_dbms_memory_heap_max__size: 2G
132+
NEO4J_dbms_security_auth__enabled: false
133133
NEO4J_apoc_export_file_enabled: true
134134
NEO4J_apoc_import_file_enabled: true
135135
NEO4J_apoc_import_file_use__neo4j__config: true

lib/sycamore/sycamore/docset.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -549,6 +549,9 @@ def extract_graph_structure(self, extractors: list[GraphExtractor], **kwargs) ->
549549
.explode()
550550
)
551551
"""
552+
from sycamore.transforms.extract_graph import ExtractDocumentStructure
553+
554+
self.plan = ExtractDocumentStructure(self.plan)
552555
docset = self
553556
for extractor in extractors:
554557
docset = extractor.extract(docset)

lib/sycamore/sycamore/tests/integration/connectors/neo4j/test_docset_to_neo4j.py

Lines changed: 2 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -1,104 +1,15 @@
11
import sycamore
22
from sycamore.tests.config import TEST_DIR
3-
from sycamore.data import HierarchicalDocument, Document
43
from sycamore.transforms.partition import SycamorePartitioner
54

65

76
def test_to_neo4j():
8-
# helper function
9-
def restructure_doc(doc: Document) -> HierarchicalDocument:
10-
doc = HierarchicalDocument(doc.data)
11-
return doc
12-
13-
# helper function
14-
def children_to_section(doc: HierarchicalDocument) -> HierarchicalDocument:
15-
import uuid
16-
17-
# if the first element is not a section header, insert generic placeholder
18-
if len(doc.children) > 0 and doc.children[0]["type"] != "Section-header":
19-
initial_page = HierarchicalDocument(
20-
{
21-
"type": "Section-header",
22-
"bbox": (0, 0, 0, 0),
23-
"properties": {"score": 1, "page_number": 1},
24-
"text_representation": "Front Page",
25-
"binary_representation": b"Front Page",
26-
}
27-
)
28-
doc.children.insert(0, initial_page) # O(n) insert :( we should use deque for everything
29-
30-
if "relationships" not in doc.data:
31-
doc.data["relationships"] = {}
32-
if "label" not in doc.data:
33-
doc.data["label"] = "DOCUMENT"
34-
35-
sections = []
36-
37-
section: HierarchicalDocument = None
38-
element: HierarchicalDocument = None
39-
for child in doc.children:
40-
if "relationships" not in child.data:
41-
child.data["relationships"] = {}
42-
if (
43-
child.type == "Section-header"
44-
and "text_representation" in child.data
45-
and len(child.data["text_representation"]) > 0
46-
):
47-
if section is not None:
48-
next = {
49-
"TYPE": "NEXT",
50-
"properties": {},
51-
"START_ID": section.doc_id,
52-
"START_LABEL": "SECTION",
53-
"END_ID": child.doc_id,
54-
"END_LABEL": "SECTION",
55-
}
56-
child.data["relationships"][str(uuid.uuid4())] = next
57-
element = None
58-
rel = {
59-
"TYPE": "SECTION_OF",
60-
"properties": {},
61-
"START_ID": child.doc_id,
62-
"START_LABEL": "SECTION",
63-
"END_ID": doc.doc_id,
64-
"END_LABEL": "DOCUMENT",
65-
}
66-
child.data["relationships"][str(uuid.uuid4())] = rel
67-
child.data["label"] = "SECTION"
68-
section = child
69-
sections.append(section)
70-
else:
71-
if element is not None:
72-
next = {
73-
"TYPE": "NEXT",
74-
"properties": {},
75-
"START_ID": element.doc_id,
76-
"START_LABEL": "ELEMENT",
77-
"END_ID": child.doc_id,
78-
"END_LABEL": "ELEMENT",
79-
}
80-
child.data["relationships"][str(uuid.uuid4())] = next
81-
rel = {
82-
"TYPE": "PART_OF",
83-
"properties": {},
84-
"START_ID": child.doc_id,
85-
"START_LABEL": "ELEMENT",
86-
"END_ID": section.doc_id,
87-
"END_LABEL": "SECTION",
88-
}
89-
child.data["relationships"][str(uuid.uuid4())] = rel
90-
child.data["label"] = "ELEMENT"
91-
element = child
92-
section.data["children"].append(element)
93-
94-
doc.children = sections
95-
return doc
967

978
## actual test ##
989
path = str(TEST_DIR / "resources/data/pdfs/Ray_page11.pdf")
9910
context = sycamore.init()
10011
URI = "neo4j://localhost:7687"
101-
AUTH = ("neo4j", "koala-stereo-comedy-spray-figure-6974")
12+
AUTH = None
10213
DATABASE = "neo4j"
10314

10415
ds = (
@@ -107,8 +18,7 @@ def children_to_section(doc: HierarchicalDocument) -> HierarchicalDocument:
10718
partitioner=SycamorePartitioner(extract_table_structure=True, use_ocr=True, extract_images=True),
10819
num_gpus=0.2,
10920
)
110-
.map(restructure_doc)
111-
.map(children_to_section)
21+
.extract_graph_structure(extractors=[])
11222
.explode()
11323
)
11424

lib/sycamore/sycamore/tests/unit/transforms/test_graph_extractor.py

Lines changed: 63 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,18 @@
11
from typing import Optional
22
import sycamore
3+
from sycamore.data.document import Document
4+
from sycamore.data.element import Element
35
from sycamore.llms.llms import LLM
46
from sycamore.reader import DocSetReader
57
from sycamore.transforms.extract_graph import GraphMetadata, MetadataExtractor, GraphEntity, EntityExtractor
8+
from sycamore.transforms.extract_graph import ExtractSummaries, ExtractDocumentStructure
69
from sycamore.data import HierarchicalDocument
710
from collections import defaultdict
811

12+
import logging
13+
14+
logger = logging.getLogger(__name__)
15+
916

1017
class TestGraphExtractor:
1118
metadata_docs = [
@@ -42,34 +49,38 @@ class TestGraphExtractor:
4249
]
4350

4451
entity_docs = [
45-
HierarchicalDocument(
52+
Document(
4653
{
4754
"doc_id": "1",
48-
"label": "Document",
4955
"type": "pdf",
50-
"relationships": {},
5156
"properties": {"company": "3M", "sector": "Industrial", "doctype": "10K"},
52-
"children": [
53-
HierarchicalDocument(
57+
"elements": [
58+
Element(
5459
{
55-
"doc_id": "2",
56-
"label": "Document",
57-
"type": "pdf",
58-
"relationships": {},
59-
"summary": "...",
60+
"type": "Section-header",
61+
"text_representation": "header-1",
6062
"properties": {},
61-
"children": [],
6263
}
6364
),
64-
HierarchicalDocument(
65+
Element(
6566
{
66-
"doc_id": "3",
67-
"label": "Document",
68-
"type": "pdf",
69-
"relationships": {},
70-
"summary": "...",
67+
"type": "text",
68+
"text_representation": "i'm text-1",
69+
"properties": {},
70+
}
71+
),
72+
Element(
73+
{
74+
"type": "Section-header",
75+
"text_representation": "header-2",
76+
"properties": {},
77+
}
78+
),
79+
Element(
80+
{
81+
"type": "text",
82+
"text_representation": "i'm text-2",
7183
"properties": {},
72-
"children": [],
7384
}
7485
),
7586
],
@@ -172,3 +183,37 @@ def test_entity_extractor(self):
172183
assert len(nested_dict["Company"]["Microsoft"]) == 2
173184
assert len(nested_dict["Company"]["Google"]) == 2
174185
assert len(nested_dict["Company"]["3M"]) == 2
186+
187+
def test_extract_document_structure(self):
    """Check the label written at each level after ExtractDocumentStructure runs.

    The top-level document must be labelled "DOCUMENT", each of its children
    "SECTION", and each child of a section "ELEMENT".
    """
    context = sycamore.init()
    docset = DocSetReader(context).document(self.entity_docs)
    docset.plan = ExtractDocumentStructure(docset.plan)

    for document in docset.take_all():
        assert document.data["label"] == "DOCUMENT"
        for section in document.children:
            assert section.data["label"] == "SECTION"
            for element in section.children:
                assert element.data["label"] == "ELEMENT"
201+
202+
def test_summarize_sections(self):
    """Check the per-section summary text produced by ExtractSummaries.

    ExtractDocumentStructure is applied first so that the document has
    sections; each section's "summary" is then compared, by position, with
    the expected title-plus-elements text.
    """
    expected_summaries = [
        "-----SECTION TITLE: header-1-----\n---Element Type: text---\ni'm text-1\n",
        "-----SECTION TITLE: header-2-----\n---Element Type: text---\ni'm text-2\n",
    ]

    context = sycamore.init()
    docset = DocSetReader(context).document(self.entity_docs)
    docset.plan = ExtractSummaries(ExtractDocumentStructure(docset.plan))

    for document in docset.take_all():
        for position, section in enumerate(document.children):
            actual = section.data["summary"]
            # Logged so the generated summary is visible in the test output.
            logger.warning(actual)
            assert actual == expected_summaries[position]

lib/sycamore/sycamore/transforms/extract_graph.py

Lines changed: 117 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from abc import ABC, abstractmethod
2-
from typing import TYPE_CHECKING, Awaitable, Dict, Any
2+
from typing import TYPE_CHECKING, Awaitable, Dict, Any, Optional
33
from sycamore.plan_nodes import Node
44
from sycamore.transforms.map import Map
55
from sycamore.data import Document, MetadataDocument, HierarchicalDocument
@@ -203,6 +203,7 @@ def extract(self, docset: "DocSet") -> "DocSet":
203203
"""
204204
Extracts entities from documents then creates a document in the docset where they are stored as nodes
205205
"""
206+
docset.plan = ExtractSummaries(docset.plan)
206207
docset.plan = ExtractFeatures(docset.plan, self)
207208
docset = self.resolve(docset)
208209
return docset
@@ -300,6 +301,121 @@ def GraphEntityExtractorPrompt(entities, query):
300301
Output:"""
301302

302303

304+
class ExtractDocumentStructure(Map):
    """
    Extracts the structure of the document organizing document elements by their
    respective section headers.

    The mapped function turns each document into a HierarchicalDocument whose
    children are sections (labelled "SECTION"); every other child is attached
    to the most recent section and labelled "ELEMENT". Each node records its
    graph edges in ``data["relationships"]``, keyed by a random UUID string.
    """

    def __init__(self, child: Node, **resource_args):
        super().__init__(child, f=ExtractDocumentStructure.structure_by_section, **resource_args)

    @staticmethod
    def structure_by_section(doc: Document) -> HierarchicalDocument:
        """
        Regroup the children of ``doc`` under their section headers.

        A child opens a new section when its type is "Section-header" and it
        has a non-empty ``text_representation``; any other child becomes an
        element of the current section. If the first child does not open a
        section, a placeholder "Front Page" header is inserted so that every
        element has a section to belong to.

        Relationships written (all stored on the child being processed):
          * SECTION_OF: section -> document
          * PART_OF:    element -> section
          * NEXT:       previous section -> this section, and previous element
                        -> this element within the same section

        Returns the restructured HierarchicalDocument, whose ``children`` are
        the sections in their original order.
        """
        import uuid

        def opens_section(node: HierarchicalDocument) -> bool:
            # Single definition of "section start", used both for the
            # placeholder decision and inside the loop so the two cannot
            # disagree.
            return node.type == "Section-header" and bool(node.data.get("text_representation"))

        def attach(owner, rel_type, start_id, start_label, end_id, end_label) -> None:
            # Record one edge on ``owner`` under a fresh UUID key.
            owner.data["relationships"][str(uuid.uuid4())] = {
                "TYPE": rel_type,
                "properties": {},
                "START_ID": start_id,
                "START_LABEL": start_label,
                "END_ID": end_id,
                "END_LABEL": end_label,
            }

        doc = HierarchicalDocument(doc.data)

        # If the first child does not open a section, insert a generic
        # placeholder. This includes a leading "Section-header" whose text is
        # empty or missing: it cannot open a section itself, so without the
        # placeholder it would have no section to be attached to.
        if len(doc.children) > 0 and not opens_section(doc.children[0]):
            initial_page = HierarchicalDocument(
                {
                    "type": "Section-header",
                    "bbox": (0, 0, 0, 0),
                    "properties": {"score": 1, "page_number": 1},
                    "text_representation": "Front Page",
                    "binary_representation": b"Front Page",
                }
            )
            doc.children.insert(0, initial_page)  # single O(n) insert per document

        doc.data["relationships"] = doc.get("relationships", {})
        doc.data["label"] = doc.get("label", "DOCUMENT")

        sections = []
        section: Optional[HierarchicalDocument] = None
        element: Optional[HierarchicalDocument] = None

        for child in doc.children:
            child.data["relationships"] = child.get("relationships", {})

            if opens_section(child):
                if section is not None:
                    attach(child, "NEXT", section.doc_id, "SECTION", child.doc_id, "SECTION")
                # Element NEXT-chains do not cross section boundaries.
                element = None
                attach(child, "SECTION_OF", child.doc_id, "SECTION", doc.doc_id, "DOCUMENT")
                child.data["label"] = "SECTION"
                section = child
                sections.append(section)
                continue

            if section is None:
                # Unreachable given the placeholder logic above; kept as an
                # explicit guard instead of an assert, which -O would strip.
                raise RuntimeError("document element encountered before any section header")

            if element is not None:
                attach(child, "NEXT", element.doc_id, "ELEMENT", child.doc_id, "ELEMENT")
            attach(child, "PART_OF", child.doc_id, "ELEMENT", section.doc_id, "SECTION")
            child.data["label"] = "ELEMENT"
            element = child
            section.data["children"].append(element)

        doc.children = sections
        return doc
391+
392+
393+
class ExtractSummaries(Map):
    """
    Extracts summaries from child documents to be used for entity extraction. This function
    generates summaries for sections within documents which are used during entity extraction.
    """

    def __init__(self, child: Node, **resource_args):
        super().__init__(child, f=ExtractSummaries.summarize_sections, **resource_args)

    @staticmethod
    def summarize_sections(doc: HierarchicalDocument) -> HierarchicalDocument:
        """
        Write a plain-text summary into ``data["summary"]`` of every section of ``doc``.

        The summary is a title line followed by one "Element Type" header and
        the stripped text for each element of the section. Documents carrying
        an "EXTRACTED_NODES" entry are returned untouched.

        Table elements have their ``text_representation`` overwritten with the
        CSV rendering of ``data["table"]`` when that entry is present.
        Elements that end up with no text representation are left out of the
        summary rather than aborting the whole transform.
        """
        if "EXTRACTED_NODES" in doc.data:
            return doc

        for section in doc.children:
            title = (section.text_representation or "").strip()
            parts = [f"-----SECTION TITLE: {title}-----\n"]

            for element in section.children:
                if element.type == "table":
                    table = element.data.get("table")
                    if table is not None:
                        element.text_representation = table.to_csv()

                text = element.text_representation
                if text is None:
                    # Nothing to contribute to the summary; presumably elements
                    # such as images can lack text (confirm against the
                    # partitioner output).
                    continue

                element_type = (element.type or "").strip()
                parts.append(f"---Element Type: {element_type}---\n{text.strip()}\n")

            section.data["summary"] = "".join(parts)
        return doc
417+
418+
303419
class ExtractFeatures(Map):
304420
"""
305421
Extracts features determined by a specific extractor from each document

0 commit comments

Comments
 (0)