Skip to content

Commit a55315f

Browse files
committed
update llm format
1 parent 8cb227c commit a55315f

File tree

5 files changed

+162
-74
lines changed

5 files changed

+162
-74
lines changed

app/system/db/base.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,21 +153,29 @@ class QueryLog(Base): # pylint: disable=too-few-public-methods
153153
# deep dives
154154

155155

156+
# Fallback segmentation settings; used as the column defaults for
# DeepDivePrompt.chunk_size and DeepDivePrompt.chunk_padding.
LLM_CHUNK_SIZE = 4_000
LLM_CHUNK_PADDING = 1_000
158+
159+
156160
class DeepDivePrompt(Base):  # pylint: disable=too-few-public-methods
    """ORM mapping of the ``deep_dive_prompt`` table.

    A row holds the prompt texts together with the chunking settings
    (``chunk_size`` / ``chunk_padding``) stored alongside them.
    """
    __tablename__ = "deep_dive_prompt"

    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
    # prompt names must be unique across the table
    name = sa.Column(sa.Text(), unique=True, nullable=False)
    main_prompt = sa.Column(sa.Text(), nullable=False)
    post_prompt = sa.Column(sa.Text(), nullable=True)
    # NOTE(review): stored as plain text while readers of this column are
    # typed as ``list[str] | None`` -- confirm how the value is encoded.
    categories = sa.Column(sa.Text(), nullable=True)
    # chunking settings fall back to the module level constants
    chunk_size = sa.Column(
        sa.Integer, default=LLM_CHUNK_SIZE, nullable=False)
    chunk_padding = sa.Column(
        sa.Integer, default=LLM_CHUNK_PADDING, nullable=False)
164172

165173

166174
class DeepDiveProcess(Base): # pylint: disable=too-few-public-methods
167175
__tablename__ = "deep_dive_process"
168176

169177
id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
170-
name = sa.Column(sa.Text(), nullable=False)
178+
name = sa.Column(sa.Text(), nullable=False, unique=True)
171179
verify_id = sa.Column(
172180
sa.Integer,
173181
sa.ForeignKey(

app/system/deepdive/collection.py

Lines changed: 141 additions & 55 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
# You should have received a copy of the GNU General Public License
1515
# along with this program. If not, see <https://www.gnu.org/licenses/>.
1616
from collections.abc import Iterable
17-
from typing import cast, get_args, Literal, TypedDict
17+
from typing import cast, Literal, overload, TypedDict
1818
from uuid import UUID
1919

2020
import sqlalchemy as sa
@@ -24,24 +24,32 @@
2424
from app.system.db.base import (
2525
DeepDiveCollection,
2626
DeepDiveElement,
27+
DeepDiveProcess,
28+
DeepDivePrompt,
2729
DeepDiveSegment,
2830
)
2931
from app.system.db.db import DBConnector
3032
from app.system.prep.snippify import snippify_text
3133

3234

33-
DeepDiveName = Literal["circular_economy", "circular_economy_undp"]
34-
DEEP_DIVE_NAMES: tuple[DeepDiveName] = get_args(DeepDiveName)
35+
class DeepDiveProcessRow(TypedDict):
    """Ids read from one ``DeepDiveProcess`` row.

    ``verify`` is filled from ``verify_id`` and ``categories`` from
    ``categories_id``.
    """
    verify: int
    categories: int
3539

3640

37-
LLM_CHUNK_SIZE = 4000
38-
LLM_CHUNK_PADDING = 2000
41+
DeepDiveSegmentationInfo = TypedDict('DeepDiveSegmentationInfo', {
42+
"categories": list[str] | None,
43+
"chunk_size": int,
44+
"chunk_padding": int,
45+
})
3946

4047

41-
def get_deep_dive_name(name: str) -> DeepDiveName:
42-
if name not in DEEP_DIVE_NAMES:
43-
raise ValueError(f"{name} is not a deep dive ({DEEP_DIVE_NAMES})")
44-
return cast(DeepDiveName, name)
48+
DeepDivePromptInfo = TypedDict('DeepDivePromptInfo', {
49+
"main_prompt": str,
50+
"post_prompt": str | None,
51+
"categories": list[str] | None,
52+
})
4553

4654

4755
CollectionObj = TypedDict('CollectionObj', {
@@ -65,13 +73,7 @@ def get_deep_dive_name(name: str) -> DeepDiveName:
6573

6674
class DeepDiveResult(TypedDict):
    """Result of a deep dive: a textual reason plus one integer per category.

    The per-category integers live under ``values`` and are keyed by the
    category name.
    """
    reason: str
    values: dict[str, int]
7678

7779

@@ -116,26 +118,77 @@ def get_deep_dive_name(name: str) -> DeepDiveName:
116118
})
117119

118120

119-
def get_deep_dive_keys(deep_dive: DeepDiveName) -> tuple[str, str]:
120-
if deep_dive == "circular_economy":
121-
return ("verify_circular_economy", "rate_circular_economy")
122-
if deep_dive == "circular_economy_undp":
123-
return ("verify_circular_economy_no_acclab", "rate_circular_economy")
124-
raise ValueError(f"unknown {deep_dive=}")
121+
def get_deep_dives(session: Session) -> list[str]:
    """Return the ``name`` of every ``DeepDiveProcess`` row.

    Args:
        session: The database session used to run the query.

    Returns:
        The names in the order the database returns them.
    """
    names = session.execute(sa.select(DeepDiveProcess.name)).scalars()
    return list(names)
124+
125+
126+
def get_process_id(session: Session, name: str) -> int:
    """Look up the id of the deep dive process with the given name.

    Args:
        session: The database session used to run the query.
        name: The process name to search for.

    Returns:
        The id of the matching ``DeepDiveProcess`` row.

    Raises:
        ValueError: If no process with that name exists.
    """
    query = (
        sa.select(DeepDiveProcess.id)
        .where(DeepDiveProcess.name == name)
        .limit(1)
    )
    found = session.execute(query).scalar()
    if found is not None:
        return int(found)
    raise ValueError(f"could not find deep dive for {name=}")
133+
134+
135+
def get_deep_dive_process(
        session: Session, process_id: int) -> DeepDiveProcessRow:
    """Fetch the verify and categories ids of a deep dive process.

    Args:
        session: The database session used to run the query.
        process_id: The id of the ``DeepDiveProcess`` row.

    Returns:
        A dictionary with the row's ``verify_id`` under ``"verify"`` and its
        ``categories_id`` under ``"categories"``.

    Raises:
        ValueError: If no row with the given id exists.
    """
    query = (
        sa.select(DeepDiveProcess.verify_id, DeepDiveProcess.categories_id)
        .where(DeepDiveProcess.id == process_id)
        .limit(1)
    )
    record = session.execute(query).first()
    if record is None:
        raise ValueError(f"{process_id=} not found")
    return {
        "verify": int(record.verify_id),
        "categories": int(record.categories_id),
    }
146+
147+
148+
def get_deep_dive_segmentation_info(
        session: Session, prompt_id: int) -> DeepDiveSegmentationInfo:
    """Read the segmentation settings stored with a prompt.

    Args:
        session: The database session used to run the query.
        prompt_id: The id of the ``DeepDivePrompt`` row.

    Returns:
        The row's ``categories``, ``chunk_size``, and ``chunk_padding``.

    Raises:
        ValueError: If no row with the given id exists.
    """
    query = (
        sa.select(
            DeepDivePrompt.categories,
            DeepDivePrompt.chunk_size,
            DeepDivePrompt.chunk_padding)
        .where(DeepDivePrompt.id == prompt_id)
        .limit(1)
    )
    record = session.execute(query).first()
    if record is None:
        raise ValueError(f"{prompt_id=} not found")
    # NOTE(review): ``categories`` is passed through as stored; confirm the
    # column value really is ``list[str] | None`` as the return type states.
    return {
        "categories": record.categories,
        "chunk_size": record.chunk_size,
        "chunk_padding": record.chunk_padding,
    }
163+
164+
165+
def get_deep_dive_prompt_info(
        session: Session, prompt_id: int) -> DeepDivePromptInfo:
    """Read the prompt texts and categories stored for a prompt.

    Args:
        session: The database session used to run the query.
        prompt_id: The id of the ``DeepDivePrompt`` row.

    Returns:
        The row's ``main_prompt``, ``post_prompt``, and ``categories``.

    Raises:
        ValueError: If no row with the given id exists.
    """
    query = (
        sa.select(
            DeepDivePrompt.main_prompt,
            DeepDivePrompt.post_prompt,
            DeepDivePrompt.categories)
        .where(DeepDivePrompt.id == prompt_id)
        .limit(1)
    )
    record = session.execute(query).first()
    if record is None:
        raise ValueError(f"{prompt_id=} not found")
    # NOTE(review): ``categories`` is passed through as stored; confirm the
    # column value really is ``list[str] | None`` as the return type states.
    return {
        "main_prompt": record.main_prompt,
        "post_prompt": record.post_prompt,
        "categories": record.categories,
    }
125180

126181

127182
def add_collection(
128183
db: DBConnector,
129184
user: UUID,
130185
name: str,
131-
deep_dive: DeepDiveName) -> int:
132-
verify_key, deep_dive_key = get_deep_dive_keys(deep_dive)
186+
process_id: int) -> int:
133187
with db.get_session() as session:
134188
stmt = sa.insert(DeepDiveCollection).values(
135189
name=name,
136190
user=user,
137-
verify_key=verify_key,
138-
deep_dive_key=deep_dive_key)
191+
process=process_id)
139192
stmt = stmt.returning(DeepDiveCollection.id)
140193
row_id = session.execute(stmt).scalar()
141194
if row_id is None:
@@ -227,6 +280,32 @@ def add_documents(
227280
return res
228281

229282

283+
@overload
def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult:
    ...


@overload
def convert_deep_dive_result(ddr: None) -> None:
    ...


def convert_deep_dive_result(
        ddr: dict[str, int | str] | None) -> DeepDiveResult | None:
    """Bring a stored deep dive result into the nested ``values`` layout.

    Results without a ``"values"`` key are treated as the flat layout where
    every key except ``"reason"`` is a category; those entries are moved
    into a ``"values"`` dictionary and converted to ``int``.

    Args:
        ddr: The stored result or ``None``.

    Returns:
        ``None`` if ``ddr`` is ``None``, the input itself if it already has
        a ``"values"`` key, otherwise a newly built nested result.
    """
    if ddr is None:
        return None
    if "values" in ddr:
        # already nested; hand the same object back
        return cast(DeepDiveResult, ddr)
    reason = cast(str, ddr["reason"])
    scores: dict[str, int] = {}
    for field, raw in ddr.items():
        if field == "reason":
            continue
        scores[field] = int(raw)
    return {
        "reason": reason,
        "values": scores,
    }
307+
308+
230309
def get_documents(
231310
db: DBConnector,
232311
collection_id: int,
@@ -265,7 +344,8 @@ def get_documents(
265344
"deep_dive_key": row.deep_dive_key,
266345
"is_valid": row.is_valid,
267346
"verify_reason": row.verify_reason,
268-
"deep_dive_result": row.deep_dive_result,
347+
"deep_dive_result": convert_deep_dive_result(
348+
row.deep_dive_result),
269349
"error": row.error,
270350
"tag": row.tag,
271351
"tag_reason": row.tag_reason,
@@ -450,23 +530,28 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]:
450530
"deep_dive_key": row.deep_dive_key,
451531
"is_valid": row.is_valid,
452532
"verify_reason": row.verify_reason,
453-
"deep_dive_result": row.deep_dive_result,
533+
"deep_dive_result": convert_deep_dive_result(
534+
row.deep_dive_result),
454535
"error": row.error,
455536
"tag": row.tag,
456537
"tag_reason": row.tag_reason,
457538
}
458539

459540

460-
def add_segments(db: DBConnector, doc: DocumentObj, full_text: str) -> int:
541+
def add_segments(
542+
db: DBConnector,
543+
doc: DocumentObj,
544+
full_text: str,
545+
segmentation_info: DeepDiveSegmentationInfo) -> int:
461546
page = 0
462547
with db.get_session() as session:
463548
collection_id = doc["deep_dive"]
464549
main_id = doc["main_id"]
465550
remove_segments(session, collection_id, [main_id])
466551
for content, _ in snippify_text(
467552
full_text,
468-
chunk_size=LLM_CHUNK_SIZE,
469-
chunk_padding=LLM_CHUNK_PADDING):
553+
chunk_size=segmentation_info["chunk_size"],
554+
chunk_padding=segmentation_info["chunk_padding"]):
470555
stmt = sa.insert(DeepDiveSegment).values(
471556
main_id=main_id,
472557
page=page,
@@ -517,7 +602,8 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]:
517602
"content": row.content,
518603
"is_valid": row.is_valid,
519604
"verify_reason": row.verify_reason,
520-
"deep_dive_result": row.deep_dive_result,
605+
"deep_dive_result": convert_deep_dive_result(
606+
row.deep_dive_result),
521607
"error": row.error,
522608
}
523609

@@ -554,7 +640,7 @@ def get_segments(
554640
"content": row.content,
555641
"is_valid": row.is_valid,
556642
"verify_reason": row.verify_reason,
557-
"deep_dive_result": row.deep_dive_result,
643+
"deep_dive_result": convert_deep_dive_result(row.deep_dive_result),
558644
"error": row.error,
559645
}
560646

@@ -602,7 +688,9 @@ def remove_segments(
602688

603689
def combine_segments(
604690
db: DBConnector,
605-
doc: DocumentObj) -> Literal["empty", "incomplete", "done"]:
691+
doc: DocumentObj,
692+
categories: list[str],
693+
) -> Literal["empty", "incomplete", "done"]:
606694
with db.get_session() as session:
607695
collection_id = doc["deep_dive"]
608696
main_id = doc["main_id"]
@@ -615,13 +703,10 @@ def combine_segments(
615703
verify_msg = ""
616704
results: DeepDiveResult = {
617705
"reason": "",
618-
"cultural": 0,
619-
"economic": 0,
620-
"educational": 0,
621-
"institutional": 0,
622-
"legal": 0,
623-
"political": 0,
624-
"technological": 0,
706+
"values": {
707+
cat: 0
708+
for cat in categories
709+
},
625710
}
626711
for segment in get_segments(session, collection_id, main_id):
627712
no_segments = False
@@ -651,13 +736,11 @@ def combine_segments(
651736
p_reason = results["reason"]
652737
results["reason"] = (
653738
f"{p_reason}\n\n[{page=}]:\n{verify_reason}".lstrip())
654-
for key, prev in results.items():
655-
if key == "reason":
656-
continue
657-
prev_val: int = int(prev) # type: ignore
658-
incoming_val: int = int(deep_dive_result[key]) # type: ignore
739+
for key, prev in results["values"].items():
740+
prev_val: int = int(prev)
741+
incoming_val: int = int(deep_dive_result["values"][key])
659742
next_val = max(prev_val, incoming_val)
660-
results[key] = next_val # type: ignore
743+
results["values"][key] = next_val
661744
if no_segments:
662745
return "empty"
663746
if is_incomplete:
@@ -674,13 +757,10 @@ def combine_segments(
674757
"reason": (
675758
"Document did not pass filter! "
676759
"No interpretation performed!"),
677-
"cultural": 0,
678-
"economic": 0,
679-
"educational": 0,
680-
"institutional": 0,
681-
"legal": 0,
682-
"political": 0,
683-
"technological": 0,
760+
"values": {
761+
cat: 0
762+
for cat in categories
763+
},
684764
}
685765
set_deep_dive(session, doc_id, results)
686766
if not is_error:
@@ -712,4 +792,10 @@ def segment_stats(db: DBConnector) -> Iterable[SegmentStats]:
712792

713793

714794
def create_deep_dive_tables(db: DBConnector) -> None:
    """Create all deep dive tables via ``db.create_tables``.

    Args:
        db: The database connector used to create the tables.
    """
    # the order of the original list is kept as is
    tables = [
        DeepDivePrompt,
        DeepDiveProcess,
        DeepDiveCollection,
        DeepDiveElement,
        DeepDiveSegment,
    ]
    db.create_tables(tables)

app/system/deepdive/diver.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -259,7 +259,7 @@ def process_segments(
259259
vres["is_hit"],
260260
vres["reason"])
261261
else:
262-
ddres, derror = interpret_deep_dive(text)
262+
ddres, derror = interpret_deep_dive(text, categories)
263263
if ddres is None:
264264
derror = (
265265
""
@@ -331,21 +331,20 @@ def interpret_verify(text: str) -> tuple[VerifyResult | None, str | None]:
331331
return (None, traceback.format_exc())
332332

333333

334-
def interpret_deep_dive(text: str) -> tuple[DeepDiveResult | None, str | None]:
334+
def interpret_deep_dive(
335+
text: str,
336+
categories: list[str]) -> tuple[DeepDiveResult | None, str | None]:
335337
obj, error = parse_json(text)
336338
if obj is None:
337339
return (None, error)
338340
try:
339341
return (
340342
{
341343
"reason": f"{obj['reason']}",
342-
"cultural": int(obj["cultural"]),
343-
"economic": int(obj["economic"]),
344-
"educational": int(obj["educational"]),
345-
"institutional": int(obj["institutional"]),
346-
"legal": int(obj["legal"]),
347-
"political": int(obj["political"]),
348-
"technological": int(obj["technological"]),
344+
"values": {
345+
key: int(obj[key])
346+
for key in categories
347+
},
349348
},
350349
None,
351350
)

0 commit comments

Comments
 (0)