Skip to content

Commit

Permalink
update llm format
Browse files Browse the repository at this point in the history
  • Loading branch information
JosuaKrause committed Aug 7, 2024
1 parent 8cb227c commit a55315f
Show file tree
Hide file tree
Showing 5 changed files with 162 additions and 74 deletions.
12 changes: 10 additions & 2 deletions app/system/db/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,21 +153,29 @@ class QueryLog(Base): # pylint: disable=too-few-public-methods
# deep dives


# Default chunking settings for feeding documents to the LLM: snippets of
# `chunk_size` characters with `chunk_padding` passed to the snippifier.
LLM_CHUNK_SIZE = 4000
LLM_CHUNK_PADDING = 1000


class DeepDivePrompt(Base): # pylint: disable=too-few-public-methods
    """A named LLM prompt with optional categories and chunking settings."""
    __tablename__ = "deep_dive_prompt"

    # Surrogate primary key.
    id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
    # NOTE(review): `name` is assigned twice below; the second assignment
    # (unique=True) wins. The first looks like a superseded duplicate from
    # the diff — TODO confirm and remove.
    name = sa.Column(sa.Text(), nullable=False)
    name = sa.Column(sa.Text(), nullable=False, unique=True)
    main_prompt = sa.Column(sa.Text(), nullable=False)
    # Optional follow-up prompt text.
    post_prompt = sa.Column(sa.Text(), nullable=True)
    # Optional categories stored as text; the decoding into a list is not
    # visible here — TODO confirm the encoding.
    categories = sa.Column(sa.Text(), nullable=True)
    # Per-prompt overrides of the module-level chunking defaults.
    chunk_size = sa.Column(
        sa.Integer, nullable=False, default=LLM_CHUNK_SIZE)
    chunk_padding = sa.Column(
        sa.Integer, nullable=False, default=LLM_CHUNK_PADDING)


class DeepDiveProcess(Base): # pylint: disable=too-few-public-methods
__tablename__ = "deep_dive_process"

id = sa.Column(sa.Integer, primary_key=True, autoincrement=True)
name = sa.Column(sa.Text(), nullable=False)
name = sa.Column(sa.Text(), nullable=False, unique=True)
verify_id = sa.Column(
sa.Integer,
sa.ForeignKey(
Expand Down
196 changes: 141 additions & 55 deletions app/system/deepdive/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
from collections.abc import Iterable
from typing import cast, get_args, Literal, TypedDict
from typing import cast, Literal, overload, TypedDict
from uuid import UUID

import sqlalchemy as sa
Expand All @@ -24,24 +24,32 @@
from app.system.db.base import (
DeepDiveCollection,
DeepDiveElement,
DeepDiveProcess,
DeepDivePrompt,
DeepDiveSegment,
)
from app.system.db.db import DBConnector
from app.system.prep.snippify import snippify_text


# All known deep dive identifiers. The Literal keeps call sites type safe;
# DEEP_DIVE_NAMES mirrors it at runtime for validation.
DeepDiveName = Literal["circular_economy", "circular_economy_undp"]
# get_args returns one entry per Literal argument, so the annotation must be
# a variadic tuple (tuple[DeepDiveName, ...]), not a 1-tuple.
DEEP_DIVE_NAMES: tuple[DeepDiveName, ...] = get_args(DeepDiveName)
# Ids read from a deep_dive_process row: "verify" is verify_id and
# "categories" is categories_id — presumably foreign keys to the prompt
# table; confirm against the schema.
DeepDiveProcessRow = TypedDict('DeepDiveProcessRow', {
    "verify": int,
    "categories": int,
})


# Module-level chunking constants (see also the defaults on DeepDivePrompt).
LLM_CHUNK_SIZE = 4000
LLM_CHUNK_PADDING = 2000
# Settings controlling how a document is segmented before LLM processing.
DeepDiveSegmentationInfo = TypedDict('DeepDiveSegmentationInfo', {
    "categories": list[str] | None,
    "chunk_size": int,
    "chunk_padding": int,
})


def get_deep_dive_name(name: str) -> DeepDiveName:
    """Validate *name* against the known deep dives and narrow its type.

    Raises ValueError when *name* is not a registered deep dive.
    """
    if name in DEEP_DIVE_NAMES:
        return cast(DeepDiveName, name)
    raise ValueError(f"{name} is not a deep dive ({DEEP_DIVE_NAMES})")
# Prompt payload loaded from a DeepDivePrompt row: the main prompt, an
# optional post prompt, and optional categories.
DeepDivePromptInfo = TypedDict('DeepDivePromptInfo', {
    "main_prompt": str,
    "post_prompt": str | None,
    "categories": list[str] | None,
})


CollectionObj = TypedDict('CollectionObj', {
Expand All @@ -65,13 +73,7 @@ def get_deep_dive_name(name: str) -> DeepDiveName:

# Result of a deep dive interpretation: a free-text reason plus per-category
# integer scores.
# NOTE(review): the fixed per-category fields below overlap with the generic
# "values" mapping; the diff suggests the flat fields are the superseded
# legacy layout (convert_deep_dive_result lifts them into "values") — TODO
# confirm and drop the flat fields.
DeepDiveResult = TypedDict('DeepDiveResult', {
    "reason": str,
    "cultural": int,
    "economic": int,
    "educational": int,
    "institutional": int,
    "legal": int,
    "political": int,
    "technological": int,
    "values": dict[str, int],
})


Expand Down Expand Up @@ -116,26 +118,77 @@ def get_deep_dive_name(name: str) -> DeepDiveName:
})


def get_deep_dive_keys(deep_dive: DeepDiveName) -> tuple[str, str]:
if deep_dive == "circular_economy":
return ("verify_circular_economy", "rate_circular_economy")
if deep_dive == "circular_economy_undp":
return ("verify_circular_economy_no_acclab", "rate_circular_economy")
raise ValueError(f"unknown {deep_dive=}")
def get_deep_dives(session: Session) -> list[str]:
    """Return the names of all deep dive processes in the database."""
    rows = session.execute(sa.select(DeepDiveProcess.name))
    return [cur.name for cur in rows]


def get_process_id(session: Session, name: str) -> int:
    """Resolve the deep dive process *name* to its database id.

    Raises ValueError if no process with that name exists.
    """
    query = sa.select(DeepDiveProcess.id).where(
        DeepDiveProcess.name == name).limit(1)
    res = session.execute(query).scalar()
    if res is None:
        raise ValueError(f"could not find deep dive for {name=}")
    return int(res)


def get_deep_dive_process(
        session: Session, process_id: int) -> DeepDiveProcessRow:
    """Load the verify_id and categories_id references of a process row.

    Raises ValueError if *process_id* does not exist.
    """
    query = sa.select(
        DeepDiveProcess.verify_id,
        DeepDiveProcess.categories_id).where(
        DeepDiveProcess.id == process_id).limit(1)
    row = session.execute(query).first()
    if row is None:
        raise ValueError(f"{process_id=} not found")
    return {
        "verify": int(row.verify_id),
        "categories": int(row.categories_id),
    }


def get_deep_dive_segmentation_info(
        session: Session, prompt_id: int) -> DeepDiveSegmentationInfo:
    """Fetch categories and chunking settings for the prompt *prompt_id*.

    Raises ValueError if the prompt does not exist.
    """
    query = sa.select(
        DeepDivePrompt.categories,
        DeepDivePrompt.chunk_size,
        DeepDivePrompt.chunk_padding).where(
        DeepDivePrompt.id == prompt_id).limit(1)
    row = session.execute(query).first()
    if row is None:
        raise ValueError(f"{prompt_id=} not found")
    # NOTE(review): `categories` is a Text column returned verbatim, yet the
    # TypedDict declares list[str] | None — TODO confirm where it is decoded.
    return {
        "categories": row.categories,
        "chunk_size": row.chunk_size,
        "chunk_padding": row.chunk_padding,
    }


def get_deep_dive_prompt_info(
        session: Session, prompt_id: int) -> DeepDivePromptInfo:
    """Fetch the prompt texts and categories for *prompt_id*.

    Raises ValueError if the prompt does not exist.
    """
    query = sa.select(
        DeepDivePrompt.main_prompt,
        DeepDivePrompt.post_prompt,
        DeepDivePrompt.categories).where(
        DeepDivePrompt.id == prompt_id).limit(1)
    row = session.execute(query).first()
    if row is None:
        raise ValueError(f"{prompt_id=} not found")
    return {
        "main_prompt": row.main_prompt,
        "post_prompt": row.post_prompt,
        "categories": row.categories,
    }


def add_collection(
db: DBConnector,
user: UUID,
name: str,
deep_dive: DeepDiveName) -> int:
verify_key, deep_dive_key = get_deep_dive_keys(deep_dive)
process_id: int) -> int:
with db.get_session() as session:
stmt = sa.insert(DeepDiveCollection).values(
name=name,
user=user,
verify_key=verify_key,
deep_dive_key=deep_dive_key)
process=process_id)
stmt = stmt.returning(DeepDiveCollection.id)
row_id = session.execute(stmt).scalar()
if row_id is None:
Expand Down Expand Up @@ -227,6 +280,32 @@ def add_documents(
return res


@overload
def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult:
    ...


@overload
def convert_deep_dive_result(ddr: None) -> None:
    ...


def convert_deep_dive_result(
        ddr: dict[str, int | str] | None) -> DeepDiveResult | None:
    """Normalize a stored deep dive result into the current schema.

    Legacy rows keep category scores as top-level keys next to "reason";
    current rows nest them under "values". None passes through unchanged.
    """
    if ddr is None:
        return None
    if "values" in ddr:
        # Already in the current layout.
        return cast(DeepDiveResult, ddr)
    # Legacy layout: lift every non-reason key into the "values" mapping.
    scores = {
        key: int(value)
        for key, value in ddr.items()
        if key != "reason"
    }
    return {
        "reason": cast(str, ddr["reason"]),
        "values": scores,
    }


def get_documents(
db: DBConnector,
collection_id: int,
Expand Down Expand Up @@ -265,7 +344,8 @@ def get_documents(
"deep_dive_key": row.deep_dive_key,
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": row.deep_dive_result,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
"error": row.error,
"tag": row.tag,
"tag_reason": row.tag_reason,
Expand Down Expand Up @@ -450,23 +530,28 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]:
"deep_dive_key": row.deep_dive_key,
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": row.deep_dive_result,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
"error": row.error,
"tag": row.tag,
"tag_reason": row.tag_reason,
}


def add_segments(db: DBConnector, doc: DocumentObj, full_text: str) -> int:
def add_segments(
db: DBConnector,
doc: DocumentObj,
full_text: str,
segmentation_info: DeepDiveSegmentationInfo) -> int:
page = 0
with db.get_session() as session:
collection_id = doc["deep_dive"]
main_id = doc["main_id"]
remove_segments(session, collection_id, [main_id])
for content, _ in snippify_text(
full_text,
chunk_size=LLM_CHUNK_SIZE,
chunk_padding=LLM_CHUNK_PADDING):
chunk_size=segmentation_info["chunk_size"],
chunk_padding=segmentation_info["chunk_padding"]):
stmt = sa.insert(DeepDiveSegment).values(
main_id=main_id,
page=page,
Expand Down Expand Up @@ -517,7 +602,8 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]:
"content": row.content,
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": row.deep_dive_result,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
"error": row.error,
}

Expand Down Expand Up @@ -554,7 +640,7 @@ def get_segments(
"content": row.content,
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": row.deep_dive_result,
"deep_dive_result": convert_deep_dive_result(row.deep_dive_result),
"error": row.error,
}

Expand Down Expand Up @@ -602,7 +688,9 @@ def remove_segments(

def combine_segments(
db: DBConnector,
doc: DocumentObj) -> Literal["empty", "incomplete", "done"]:
doc: DocumentObj,
categories: list[str],
) -> Literal["empty", "incomplete", "done"]:
with db.get_session() as session:
collection_id = doc["deep_dive"]
main_id = doc["main_id"]
Expand All @@ -615,13 +703,10 @@ def combine_segments(
verify_msg = ""
results: DeepDiveResult = {
"reason": "",
"cultural": 0,
"economic": 0,
"educational": 0,
"institutional": 0,
"legal": 0,
"political": 0,
"technological": 0,
"values": {
cat: 0
for cat in categories
},
}
for segment in get_segments(session, collection_id, main_id):
no_segments = False
Expand Down Expand Up @@ -651,13 +736,11 @@ def combine_segments(
p_reason = results["reason"]
results["reason"] = (
f"{p_reason}\n\n[{page=}]:\n{verify_reason}".lstrip())
for key, prev in results.items():
if key == "reason":
continue
prev_val: int = int(prev) # type: ignore
incoming_val: int = int(deep_dive_result[key]) # type: ignore
for key, prev in results["values"].items():
prev_val: int = int(prev)
incoming_val: int = int(deep_dive_result["values"][key])
next_val = max(prev_val, incoming_val)
results[key] = next_val # type: ignore
results["values"][key] = next_val
if no_segments:
return "empty"
if is_incomplete:
Expand All @@ -674,13 +757,10 @@ def combine_segments(
"reason": (
"Document did not pass filter! "
"No interpretation performed!"),
"cultural": 0,
"economic": 0,
"educational": 0,
"institutional": 0,
"legal": 0,
"political": 0,
"technological": 0,
"values": {
cat: 0
for cat in categories
},
}
set_deep_dive(session, doc_id, results)
if not is_error:
Expand Down Expand Up @@ -712,4 +792,10 @@ def segment_stats(db: DBConnector) -> Iterable[SegmentStats]:


def create_deep_dive_tables(db: DBConnector) -> None:
    """Create all deep dive tables via the connector."""
    # NOTE(review): both calls below are present in this view; the first
    # (three tables) appears to be the superseded pre-diff body and the
    # second (five tables, including prompt and process) the replacement —
    # TODO confirm only the second call should remain.
    db.create_tables([DeepDiveCollection, DeepDiveElement, DeepDiveSegment])
    db.create_tables([
        DeepDivePrompt,
        DeepDiveProcess,
        DeepDiveCollection,
        DeepDiveElement,
        DeepDiveSegment,
    ])
17 changes: 8 additions & 9 deletions app/system/deepdive/diver.py
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ def process_segments(
vres["is_hit"],
vres["reason"])
else:
ddres, derror = interpret_deep_dive(text)
ddres, derror = interpret_deep_dive(text, categories)
if ddres is None:
derror = (
""
Expand Down Expand Up @@ -331,21 +331,20 @@ def interpret_verify(text: str) -> tuple[VerifyResult | None, str | None]:
return (None, traceback.format_exc())


def interpret_deep_dive(text: str) -> tuple[DeepDiveResult | None, str | None]:
def interpret_deep_dive(
text: str,
categories: list[str]) -> tuple[DeepDiveResult | None, str | None]:
obj, error = parse_json(text)
if obj is None:
return (None, error)
try:
return (
{
"reason": f"{obj['reason']}",
"cultural": int(obj["cultural"]),
"economic": int(obj["economic"]),
"educational": int(obj["educational"]),
"institutional": int(obj["institutional"]),
"legal": int(obj["legal"]),
"political": int(obj["political"]),
"technological": int(obj["technological"]),
"values": {
key: int(obj[key])
for key in categories
},
},
None,
)
Expand Down
Loading

0 comments on commit a55315f

Please sign in to comment.