diff --git a/app/system/db/base.py b/app/system/db/base.py index 26cffedc6..88c621769 100644 --- a/app/system/db/base.py +++ b/app/system/db/base.py @@ -153,21 +153,29 @@ class QueryLog(Base): # pylint: disable=too-few-public-methods # deep dives +LLM_CHUNK_SIZE = 4000 +LLM_CHUNK_PADDING = 1000 + + class DeepDivePrompt(Base): # pylint: disable=too-few-public-methods __tablename__ = "deep_dive_prompt" id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.Text(), nullable=False) + name = sa.Column(sa.Text(), nullable=False, unique=True) main_prompt = sa.Column(sa.Text(), nullable=False) post_prompt = sa.Column(sa.Text(), nullable=True) categories = sa.Column(sa.Text(), nullable=True) + chunk_size = sa.Column( + sa.Integer, nullable=False, default=LLM_CHUNK_SIZE) + chunk_padding = sa.Column( + sa.Integer, nullable=False, default=LLM_CHUNK_PADDING) class DeepDiveProcess(Base): # pylint: disable=too-few-public-methods __tablename__ = "deep_dive_process" id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) - name = sa.Column(sa.Text(), nullable=False) + name = sa.Column(sa.Text(), nullable=False, unique=True) verify_id = sa.Column( sa.Integer, sa.ForeignKey( diff --git a/app/system/deepdive/collection.py b/app/system/deepdive/collection.py index 5348f4912..2bd9bec21 100644 --- a/app/system/deepdive/collection.py +++ b/app/system/deepdive/collection.py @@ -14,7 +14,7 @@ # You should have received a copy of the GNU General Public License # along with this program. If not, see . from collections.abc import Iterable -from typing import cast, get_args, Literal, TypedDict +from typing import cast, Literal, overload, TypedDict from uuid import UUID import sqlalchemy as sa @@ -24,24 +24,32 @@ from app.system.db.base import ( DeepDiveCollection, DeepDiveElement, + DeepDiveProcess, + DeepDivePrompt, DeepDiveSegment, ) from app.system.db.db import DBConnector from app.system.prep.snippify import snippify_text -DeepDiveName = Literal["circular_economy", "circular_economy_undp"] -DEEP_DIVE_NAMES: tuple[DeepDiveName] = get_args(DeepDiveName) +DeepDiveProcessRow = TypedDict('DeepDiveProcessRow', { + "verify": int, + "categories": int, +}) -LLM_CHUNK_SIZE = 4000 -LLM_CHUNK_PADDING = 2000 +DeepDiveSegmentationInfo = TypedDict('DeepDiveSegmentationInfo', { + "categories": list[str] | None, + "chunk_size": int, + "chunk_padding": int, +}) -def get_deep_dive_name(name: str) -> DeepDiveName: - if name not in DEEP_DIVE_NAMES: - raise ValueError(f"{name} is not a deep dive ({DEEP_DIVE_NAMES})") - return cast(DeepDiveName, name) +DeepDivePromptInfo = TypedDict('DeepDivePromptInfo', { + "main_prompt": str, + "post_prompt": str | None, + "categories": list[str] | None, +}) CollectionObj = TypedDict('CollectionObj', { @@ -65,13 +73,7 @@ def get_deep_dive_name(name: str) -> DeepDiveName: DeepDiveResult = TypedDict('DeepDiveResult', { "reason": str, - "cultural": int, - "economic": int, - "educational": int, - "institutional": int, - "legal": int, - "political": int, - "technological": int, + "values": dict[str, int], }) @@ -116,26 +118,77 @@ def get_deep_dive_name(name: str) -> DeepDiveName: }) -def get_deep_dive_keys(deep_dive: DeepDiveName) -> tuple[str, str]: - if deep_dive == "circular_economy": - return ("verify_circular_economy", "rate_circular_economy") - if deep_dive == "circular_economy_undp": - return ("verify_circular_economy_no_acclab", "rate_circular_economy") - raise ValueError(f"unknown {deep_dive=}") +def get_deep_dives(session: Session) -> list[str]: + stmt = sa.select(DeepDiveProcess.name) + return [row.name for row in session.execute(stmt)] + + +def get_process_id(session: Session, name: str) -> int: + stmt = sa.select(DeepDiveProcess.id).where(DeepDiveProcess.name == name) + stmt = stmt.limit(1) + process_id = session.execute(stmt).scalar() + if process_id is None: + raise ValueError(f"could not find deep dive for {name=}") + return int(process_id) + + +def get_deep_dive_process( + session: Session, process_id: int) -> DeepDiveProcessRow: + stmt = sa.select(DeepDiveProcess.verify_id, DeepDiveProcess.categories_id) + stmt = stmt.where(DeepDiveProcess.id == process_id) + stmt = stmt.limit(1) + for row in session.execute(stmt): + return { + "verify": int(row.verify_id), + "categories": int(row.categories_id), + } + raise ValueError(f"{process_id=} not found") + + +def get_deep_dive_segmentation_info( + session: Session, prompt_id: int) -> DeepDiveSegmentationInfo: + stmt = sa.select( + DeepDivePrompt.categories, + DeepDivePrompt.chunk_size, + DeepDivePrompt.chunk_padding) + stmt = stmt.where(DeepDivePrompt.id == prompt_id) + stmt = stmt.limit(1) + for row in session.execute(stmt): + return { + "categories": row.categories, + "chunk_size": row.chunk_size, + "chunk_padding": row.chunk_padding, + } + raise ValueError(f"{prompt_id=} not found") + + +def get_deep_dive_prompt_info( + session: Session, prompt_id: int) -> DeepDivePromptInfo: + stmt = sa.select( + DeepDivePrompt.main_prompt, + DeepDivePrompt.post_prompt, + DeepDivePrompt.categories) + stmt = stmt.where(DeepDivePrompt.id == prompt_id) + stmt = stmt.limit(1) + for row in session.execute(stmt): + return { + "main_prompt": row.main_prompt, + "post_prompt": row.post_prompt, + "categories": row.categories, + } + raise ValueError(f"{prompt_id=} not found") def add_collection( db: DBConnector, user: UUID, name: str, - deep_dive: DeepDiveName) -> int: - verify_key, deep_dive_key = get_deep_dive_keys(deep_dive) + process_id: int) -> int: with db.get_session() as session: stmt = sa.insert(DeepDiveCollection).values( name=name, user=user, - verify_key=verify_key, - deep_dive_key=deep_dive_key) + process=process_id) stmt = stmt.returning(DeepDiveCollection.id) row_id = session.execute(stmt).scalar() if row_id is None: @@ -227,6 +280,32 @@ def add_documents( return res +@overload +def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult: + ... + + +@overload +def convert_deep_dive_result(ddr: None) -> None: + ... + + +def convert_deep_dive_result( + ddr: dict[str, int | str] | None) -> DeepDiveResult | None: + if ddr is None: + return None + if "values" not in ddr: + return { + "reason": cast(str, ddr["reason"]), + "values": { + key: int(value) + for key, value in ddr.items() + if key != "reason" + }, + } + return cast(DeepDiveResult, ddr) + + def get_documents( db: DBConnector, collection_id: int, @@ -265,7 +344,8 @@ def get_documents( "deep_dive_key": row.deep_dive_key, "is_valid": row.is_valid, "verify_reason": row.verify_reason, - "deep_dive_result": row.deep_dive_result, + "deep_dive_result": convert_deep_dive_result( + row.deep_dive_result), "error": row.error, "tag": row.tag, "tag_reason": row.tag_reason, @@ -450,14 +530,19 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]: "deep_dive_key": row.deep_dive_key, "is_valid": row.is_valid, "verify_reason": row.verify_reason, - "deep_dive_result": row.deep_dive_result, + "deep_dive_result": convert_deep_dive_result( + row.deep_dive_result), "error": row.error, "tag": row.tag, "tag_reason": row.tag_reason, } -def add_segments(db: DBConnector, doc: DocumentObj, full_text: str) -> int: +def add_segments( + db: DBConnector, + doc: DocumentObj, + full_text: str, + segmentation_info: DeepDiveSegmentationInfo) -> int: page = 0 with db.get_session() as session: collection_id = doc["deep_dive"] @@ -465,8 +550,8 @@ def add_segments(db: DBConnector, doc: DocumentObj, full_text: str) -> int: remove_segments(session, collection_id, [main_id]) for content, _ in snippify_text( full_text, - chunk_size=LLM_CHUNK_SIZE, - chunk_padding=LLM_CHUNK_PADDING): + chunk_size=segmentation_info["chunk_size"], + chunk_padding=segmentation_info["chunk_padding"]): stmt = sa.insert(DeepDiveSegment).values( main_id=main_id, page=page, @@ -517,7 +602,8 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]: "content": row.content, "is_valid": row.is_valid, "verify_reason": row.verify_reason, - "deep_dive_result": row.deep_dive_result, + "deep_dive_result": convert_deep_dive_result( + row.deep_dive_result), "error": row.error, } @@ -554,7 +640,7 @@ def get_segments( "content": row.content, "is_valid": row.is_valid, "verify_reason": row.verify_reason, - "deep_dive_result": row.deep_dive_result, + "deep_dive_result": convert_deep_dive_result(row.deep_dive_result), "error": row.error, } @@ -602,7 +688,9 @@ def remove_segments( def combine_segments( db: DBConnector, - doc: DocumentObj) -> Literal["empty", "incomplete", "done"]: + doc: DocumentObj, + categories: list[str], + ) -> Literal["empty", "incomplete", "done"]: with db.get_session() as session: collection_id = doc["deep_dive"] main_id = doc["main_id"] @@ -615,13 +703,10 @@ def combine_segments( verify_msg = "" results: DeepDiveResult = { "reason": "", - "cultural": 0, - "economic": 0, - "educational": 0, - "institutional": 0, - "legal": 0, - "political": 0, - "technological": 0, + "values": { + cat: 0 + for cat in categories + }, } for segment in get_segments(session, collection_id, main_id): no_segments = False @@ -651,13 +736,11 @@ def combine_segments( p_reason = results["reason"] results["reason"] = ( f"{p_reason}\n\n[{page=}]:\n{verify_reason}".lstrip()) - for key, prev in results.items(): - if key == "reason": - continue - prev_val: int = int(prev) # type: ignore - incoming_val: int = int(deep_dive_result[key]) # type: ignore + for key, prev in results["values"].items(): + prev_val: int = int(prev) + incoming_val: int = int(deep_dive_result["values"][key]) next_val = max(prev_val, incoming_val) - results[key] = next_val # type: ignore + results["values"][key] = next_val if no_segments: return "empty" if is_incomplete: @@ -674,13 +757,10 @@ def combine_segments( "reason": ( "Document did not pass filter! " "No interpretation performed!"), - "cultural": 0, - "economic": 0, - "educational": 0, - "institutional": 0, - "legal": 0, - "political": 0, - "technological": 0, + "values": { + cat: 0 + for cat in categories + }, } set_deep_dive(session, doc_id, results) if not is_error: @@ -712,4 +792,10 @@ def segment_stats(db: DBConnector) -> Iterable[SegmentStats]: def create_deep_dive_tables(db: DBConnector) -> None: - db.create_tables([DeepDiveCollection, DeepDiveElement, DeepDiveSegment]) + db.create_tables([ + DeepDivePrompt, + DeepDiveProcess, + DeepDiveCollection, + DeepDiveElement, + DeepDiveSegment, + ]) diff --git a/app/system/deepdive/diver.py b/app/system/deepdive/diver.py index 42c5d7515..ed45e4b5e 100644 --- a/app/system/deepdive/diver.py +++ b/app/system/deepdive/diver.py @@ -259,7 +259,7 @@ def process_segments( vres["is_hit"], vres["reason"]) else: - ddres, derror = interpret_deep_dive(text) + ddres, derror = interpret_deep_dive(text, categories) if ddres is None: derror = ( "" @@ -331,7 +331,9 @@ def interpret_verify(text: str) -> tuple[VerifyResult | None, str | None]: return (None, traceback.format_exc()) -def interpret_deep_dive(text: str) -> tuple[DeepDiveResult | None, str | None]: +def interpret_deep_dive( + text: str, + categories: list[str]) -> tuple[DeepDiveResult | None, str | None]: obj, error = parse_json(text) if obj is None: return (None, error) @@ -339,13 +341,10 @@ def interpret_deep_dive(text: str) -> tuple[DeepDiveResult | None, str | None]: return ( { "reason": f"{obj['reason']}", - "cultural": int(obj["cultural"]), - "economic": int(obj["economic"]), - "educational": int(obj["educational"]), - "institutional": int(obj["institutional"]), - "legal": int(obj["legal"]), - "political": int(obj["political"]), - "technological": int(obj["technological"]), + "values": { + key: int(obj[key]) + for key in categories + }, }, None, ) diff --git a/ui/src/api/api.ts b/ui/src/api/api.ts index 56b0f1f7d..89626ab23 100644 --- a/ui/src/api/api.ts +++ b/ui/src/api/api.ts @@ -286,7 +286,8 @@ export const DEFAULT_API: ApiProvider = { let reason: string | undefined = undefined; let scores: StatNumbers = {}; if (deep_dive_result) { - const { reason: reasonValue, ...scoresValue } = deep_dive_result; + const { reason: reasonValue, values: scoresValue } = + deep_dive_result; reason = reasonValue; scores = scoresValue; } diff --git a/ui/src/api/types.ts b/ui/src/api/types.ts index 80b30928a..bcc1707a9 100644 --- a/ui/src/api/types.ts +++ b/ui/src/api/types.ts @@ -128,13 +128,7 @@ export type CollectionOptions = { type DeepDiveResult = { reason: string; - cultural: number; - economic: number; - educational: number; - institutional: number; - legal: number; - political: number; - technological: number; + values: { [key: string]: number }; }; type ApiDocumentObj = {