14
14
# You should have received a copy of the GNU General Public License
15
15
# along with this program. If not, see <https://www.gnu.org/licenses/>.
16
16
from collections .abc import Iterable
17
- from typing import cast , get_args , Literal , TypedDict
17
+ from typing import cast , Literal , overload , TypedDict
18
18
from uuid import UUID
19
19
20
20
import sqlalchemy as sa
24
24
from app .system .db .base import (
25
25
DeepDiveCollection ,
26
26
DeepDiveElement ,
27
+ DeepDiveProcess ,
28
+ DeepDivePrompt ,
27
29
DeepDiveSegment ,
28
30
)
29
31
from app .system .db .db import DBConnector
30
32
from app .system .prep .snippify import snippify_text
31
33
32
34
33
- DeepDiveName = Literal ["circular_economy" , "circular_economy_undp" ]
34
- DEEP_DIVE_NAMES : tuple [DeepDiveName ] = get_args (DeepDiveName )
35
class DeepDiveProcessRow(TypedDict):
    """The two ids stored on a deep dive process row."""

    # Taken from `DeepDiveProcess.verify_id`.
    # NOTE(review): presumably the id of the verification prompt -- confirm.
    verify: int
    # Taken from `DeepDiveProcess.categories_id`.
    # NOTE(review): presumably the id of the category rating prompt -- confirm.
    categories: int
35
39
36
40
37
- LLM_CHUNK_SIZE = 4000
38
- LLM_CHUNK_PADDING = 2000
41
+ DeepDiveSegmentationInfo = TypedDict ('DeepDiveSegmentationInfo' , {
42
+ "categories" : list [str ] | None ,
43
+ "chunk_size" : int ,
44
+ "chunk_padding" : int ,
45
+ })
39
46
40
47
41
- def get_deep_dive_name (name : str ) -> DeepDiveName :
42
- if name not in DEEP_DIVE_NAMES :
43
- raise ValueError (f"{ name } is not a deep dive ({ DEEP_DIVE_NAMES } )" )
44
- return cast (DeepDiveName , name )
48
+ DeepDivePromptInfo = TypedDict ('DeepDivePromptInfo' , {
49
+ "main_prompt" : str ,
50
+ "post_prompt" : str | None ,
51
+ "categories" : list [str ] | None ,
52
+ })
45
53
46
54
47
55
CollectionObj = TypedDict ('CollectionObj' , {
@@ -65,13 +73,7 @@ def get_deep_dive_name(name: str) -> DeepDiveName:
65
73
66
74
class DeepDiveResult(TypedDict):
    """Outcome of a deep dive: a free-text reason plus per-category values."""

    # Explanation text accompanying the values.
    reason: str
    # Integer value per category name. The category names are no longer
    # fixed keys of this type; they are whatever the caller supplies.
    values: dict[str, int]
76
78
77
79
@@ -116,26 +118,77 @@ def get_deep_dive_name(name: str) -> DeepDiveName:
116
118
})
117
119
118
120
119
- def get_deep_dive_keys (deep_dive : DeepDiveName ) -> tuple [str , str ]:
120
- if deep_dive == "circular_economy" :
121
- return ("verify_circular_economy" , "rate_circular_economy" )
122
- if deep_dive == "circular_economy_undp" :
123
- return ("verify_circular_economy_no_acclab" , "rate_circular_economy" )
124
- raise ValueError (f"unknown { deep_dive = } " )
121
def get_deep_dives(session: Session) -> list[str]:
    """Return the `name` of every row in the DeepDiveProcess table.

    Args:
        session: The session the query is executed on.

    Returns:
        The names in the order the database yields them (the query has no
        explicit ordering).
    """
    query = sa.select(DeepDiveProcess.name)
    names: list[str] = []
    for record in session.execute(query):
        names.append(record.name)
    return names
124
+
125
+
126
def get_process_id(session: Session, name: str) -> int:
    """Look up the id of the deep dive process with the given name.

    Args:
        session: The session the query is executed on.
        name: Value compared against `DeepDiveProcess.name`.

    Returns:
        The id of the first matching process, converted to `int`.

    Raises:
        ValueError: If the query yields no id for `name`.
    """
    stmt = (
        sa.select(DeepDiveProcess.id)
        .where(DeepDiveProcess.name == name)
        .limit(1)
    )
    process_id = session.execute(stmt).scalar()
    if process_id is None:
        # No whitespace inside the braces: the `=` specifier would echo it
        # verbatim into the message.
        raise ValueError(f"could not find deep dive for {name=}")
    return int(process_id)
133
+
134
+
135
def get_deep_dive_process(
        session: Session, process_id: int) -> DeepDiveProcessRow:
    """Fetch the verify and categories ids of a deep dive process.

    Args:
        session: The session the query is executed on.
        process_id: Value compared against `DeepDiveProcess.id`.

    Returns:
        A dict with `verify` (from `verify_id`) and `categories` (from
        `categories_id`), both converted to `int`.

    Raises:
        ValueError: If no process row has the given id.
    """
    stmt = (
        sa.select(DeepDiveProcess.verify_id, DeepDiveProcess.categories_id)
        .where(DeepDiveProcess.id == process_id)
        .limit(1)
    )
    # `first()` returns the first row or None, so a missing row is an
    # explicit check rather than a loop that happens not to run.
    row = session.execute(stmt).first()
    if row is None:
        raise ValueError(f"{process_id=} not found")
    return {
        "verify": int(row.verify_id),
        "categories": int(row.categories_id),
    }
146
+
147
+
148
def get_deep_dive_segmentation_info(
        session: Session, prompt_id: int) -> DeepDiveSegmentationInfo:
    """Fetch the segmentation settings stored on a deep dive prompt.

    Args:
        session: The session the query is executed on.
        prompt_id: Value compared against `DeepDivePrompt.id`.

    Returns:
        A dict with the prompt's `categories`, `chunk_size` and
        `chunk_padding` column values, passed through unchanged.

    Raises:
        ValueError: If no prompt row has the given id.
    """
    stmt = (
        sa.select(
            DeepDivePrompt.categories,
            DeepDivePrompt.chunk_size,
            DeepDivePrompt.chunk_padding)
        .where(DeepDivePrompt.id == prompt_id)
        .limit(1)
    )
    # `first()` returns the first row or None, so a missing row is an
    # explicit check rather than a loop that happens not to run.
    row = session.execute(stmt).first()
    if row is None:
        raise ValueError(f"{prompt_id=} not found")
    return {
        "categories": row.categories,
        "chunk_size": row.chunk_size,
        "chunk_padding": row.chunk_padding,
    }
163
+
164
+
165
def get_deep_dive_prompt_info(
        session: Session, prompt_id: int) -> DeepDivePromptInfo:
    """Fetch the prompt texts and category list of a deep dive prompt.

    Args:
        session: The session the query is executed on.
        prompt_id: Value compared against `DeepDivePrompt.id`.

    Returns:
        A dict with the prompt's `main_prompt`, `post_prompt` and
        `categories` column values, passed through unchanged.

    Raises:
        ValueError: If no prompt row has the given id.
    """
    stmt = (
        sa.select(
            DeepDivePrompt.main_prompt,
            DeepDivePrompt.post_prompt,
            DeepDivePrompt.categories)
        .where(DeepDivePrompt.id == prompt_id)
        .limit(1)
    )
    # `first()` returns the first row or None, so a missing row is an
    # explicit check rather than a loop that happens not to run.
    row = session.execute(stmt).first()
    if row is None:
        raise ValueError(f"{prompt_id=} not found")
    return {
        "main_prompt": row.main_prompt,
        "post_prompt": row.post_prompt,
        "categories": row.categories,
    }
125
180
126
181
127
182
def add_collection (
128
183
db : DBConnector ,
129
184
user : UUID ,
130
185
name : str ,
131
- deep_dive : DeepDiveName ) -> int :
132
- verify_key , deep_dive_key = get_deep_dive_keys (deep_dive )
186
+ process_id : int ) -> int :
133
187
with db .get_session () as session :
134
188
stmt = sa .insert (DeepDiveCollection ).values (
135
189
name = name ,
136
190
user = user ,
137
- verify_key = verify_key ,
138
- deep_dive_key = deep_dive_key )
191
+ process = process_id )
139
192
stmt = stmt .returning (DeepDiveCollection .id )
140
193
row_id = session .execute (stmt ).scalar ()
141
194
if row_id is None :
@@ -227,6 +280,32 @@ def add_documents(
227
280
return res
228
281
229
282
283
@overload
def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult:
    ...


@overload
def convert_deep_dive_result(ddr: None) -> None:
    ...


def convert_deep_dive_result(
        ddr: dict[str, int | str] | None) -> DeepDiveResult | None:
    """Normalize a stored deep dive result to the `reason` / `values` shape.

    Args:
        ddr: A stored result dict, or None.

    Returns:
        None if `ddr` is None. If `ddr` already has a "values" key it is
        returned as is. Otherwise a new dict is built whose "reason" is
        `ddr["reason"]` and whose "values" maps every other key to its
        value converted with `int`.

    Raises:
        KeyError: If `ddr` has neither a "values" nor a "reason" key.
    """
    if ddr is None:
        return None
    if "values" in ddr:
        # Already in the nested layout; nothing to convert.
        return cast(DeepDiveResult, ddr)
    # Flat layout: every key other than "reason" is a category value.
    reason = cast(str, ddr["reason"])
    scores: dict[str, int] = {}
    for name, raw in ddr.items():
        if name == "reason":
            continue
        scores[name] = int(raw)
    return {
        "reason": reason,
        "values": scores,
    }
307
+
308
+
230
309
def get_documents (
231
310
db : DBConnector ,
232
311
collection_id : int ,
@@ -265,7 +344,8 @@ def get_documents(
265
344
"deep_dive_key" : row .deep_dive_key ,
266
345
"is_valid" : row .is_valid ,
267
346
"verify_reason" : row .verify_reason ,
268
- "deep_dive_result" : row .deep_dive_result ,
347
+ "deep_dive_result" : convert_deep_dive_result (
348
+ row .deep_dive_result ),
269
349
"error" : row .error ,
270
350
"tag" : row .tag ,
271
351
"tag_reason" : row .tag_reason ,
@@ -450,23 +530,28 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]:
450
530
"deep_dive_key" : row .deep_dive_key ,
451
531
"is_valid" : row .is_valid ,
452
532
"verify_reason" : row .verify_reason ,
453
- "deep_dive_result" : row .deep_dive_result ,
533
+ "deep_dive_result" : convert_deep_dive_result (
534
+ row .deep_dive_result ),
454
535
"error" : row .error ,
455
536
"tag" : row .tag ,
456
537
"tag_reason" : row .tag_reason ,
457
538
}
458
539
459
540
460
- def add_segments (db : DBConnector , doc : DocumentObj , full_text : str ) -> int :
541
+ def add_segments (
542
+ db : DBConnector ,
543
+ doc : DocumentObj ,
544
+ full_text : str ,
545
+ segmentation_info : DeepDiveSegmentationInfo ) -> int :
461
546
page = 0
462
547
with db .get_session () as session :
463
548
collection_id = doc ["deep_dive" ]
464
549
main_id = doc ["main_id" ]
465
550
remove_segments (session , collection_id , [main_id ])
466
551
for content , _ in snippify_text (
467
552
full_text ,
468
- chunk_size = LLM_CHUNK_SIZE ,
469
- chunk_padding = LLM_CHUNK_PADDING ):
553
+ chunk_size = segmentation_info [ "chunk_size" ] ,
554
+ chunk_padding = segmentation_info [ "chunk_padding" ] ):
470
555
stmt = sa .insert (DeepDiveSegment ).values (
471
556
main_id = main_id ,
472
557
page = page ,
@@ -517,7 +602,8 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]:
517
602
"content" : row .content ,
518
603
"is_valid" : row .is_valid ,
519
604
"verify_reason" : row .verify_reason ,
520
- "deep_dive_result" : row .deep_dive_result ,
605
+ "deep_dive_result" : convert_deep_dive_result (
606
+ row .deep_dive_result ),
521
607
"error" : row .error ,
522
608
}
523
609
@@ -554,7 +640,7 @@ def get_segments(
554
640
"content" : row .content ,
555
641
"is_valid" : row .is_valid ,
556
642
"verify_reason" : row .verify_reason ,
557
- "deep_dive_result" : row .deep_dive_result ,
643
+ "deep_dive_result" : convert_deep_dive_result ( row .deep_dive_result ) ,
558
644
"error" : row .error ,
559
645
}
560
646
@@ -602,7 +688,9 @@ def remove_segments(
602
688
603
689
def combine_segments (
604
690
db : DBConnector ,
605
- doc : DocumentObj ) -> Literal ["empty" , "incomplete" , "done" ]:
691
+ doc : DocumentObj ,
692
+ categories : list [str ],
693
+ ) -> Literal ["empty" , "incomplete" , "done" ]:
606
694
with db .get_session () as session :
607
695
collection_id = doc ["deep_dive" ]
608
696
main_id = doc ["main_id" ]
@@ -615,13 +703,10 @@ def combine_segments(
615
703
verify_msg = ""
616
704
results : DeepDiveResult = {
617
705
"reason" : "" ,
618
- "cultural" : 0 ,
619
- "economic" : 0 ,
620
- "educational" : 0 ,
621
- "institutional" : 0 ,
622
- "legal" : 0 ,
623
- "political" : 0 ,
624
- "technological" : 0 ,
706
+ "values" : {
707
+ cat : 0
708
+ for cat in categories
709
+ },
625
710
}
626
711
for segment in get_segments (session , collection_id , main_id ):
627
712
no_segments = False
@@ -651,13 +736,11 @@ def combine_segments(
651
736
p_reason = results ["reason" ]
652
737
results ["reason" ] = (
653
738
f"{ p_reason } \n \n [{ page = } ]:\n { verify_reason } " .lstrip ())
654
- for key , prev in results .items ():
655
- if key == "reason" :
656
- continue
657
- prev_val : int = int (prev ) # type: ignore
658
- incoming_val : int = int (deep_dive_result [key ]) # type: ignore
739
+ for key , prev in results ["values" ].items ():
740
+ prev_val : int = int (prev )
741
+ incoming_val : int = int (deep_dive_result ["values" ][key ])
659
742
next_val = max (prev_val , incoming_val )
660
- results [key ] = next_val # type: ignore
743
+ results ["values" ][ key ] = next_val
661
744
if no_segments :
662
745
return "empty"
663
746
if is_incomplete :
@@ -674,13 +757,10 @@ def combine_segments(
674
757
"reason" : (
675
758
"Document did not pass filter! "
676
759
"No interpretation performed!" ),
677
- "cultural" : 0 ,
678
- "economic" : 0 ,
679
- "educational" : 0 ,
680
- "institutional" : 0 ,
681
- "legal" : 0 ,
682
- "political" : 0 ,
683
- "technological" : 0 ,
760
+ "values" : {
761
+ cat : 0
762
+ for cat in categories
763
+ },
684
764
}
685
765
set_deep_dive (session , doc_id , results )
686
766
if not is_error :
@@ -712,4 +792,10 @@ def segment_stats(db: DBConnector) -> Iterable[SegmentStats]:
712
792
713
793
714
794
def create_deep_dive_tables(db: DBConnector) -> None:
    """Create all deep dive tables through `db.create_tables`.

    Args:
        db: The connector whose `create_tables` method is called.
    """
    # NOTE(review): prompt and process are listed before the tables that
    # presumably reference them -- confirm whether the order matters.
    tables = [
        DeepDivePrompt,
        DeepDiveProcess,
        DeepDiveCollection,
        DeepDiveElement,
        DeepDiveSegment,
    ]
    db.create_tables(tables)
0 commit comments