Skip to content

Commit 736b1f4

Browse files
committed
more general
1 parent a55315f commit 736b1f4

File tree

9 files changed

+128
-92
lines changed

9 files changed

+128
-92
lines changed

app/api/response_types.py

Lines changed: 2 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,7 @@
1616
from typing import TypedDict
1717

1818
from app.system.autotag.autotag import TagClusterEntry
19-
from app.system.deepdive.collection import (
20-
DeepDiveName,
21-
DocumentObj,
22-
SegmentStats,
23-
)
19+
from app.system.deepdive.collection import DocumentObj, SegmentStats
2420
from app.system.smind.api import QueueStat
2521
from app.system.smind.vec import DBName, VecDBStat
2622
from app.system.workqueues.queue import ProcessError, ProcessQueueStats
@@ -46,7 +42,7 @@
4642
"has_llm": bool,
4743
"vecdb_ready": bool,
4844
"vecdbs": list[DBName],
49-
"deepdives": list[DeepDiveName],
45+
"deepdives": list[str],
5046
"error": list[str] | None,
5147
})
5248
StatsResponse = TypedDict('StatsResponse', {

app/api/server.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,9 +86,7 @@
8686
add_collection,
8787
add_documents,
8888
CollectionOptions,
89-
DEEP_DIVE_NAMES,
9089
get_collections,
91-
get_deep_dive_name,
9290
get_documents,
9391
requeue,
9492
requeue_error,

app/system/deepdive/collection.py

Lines changed: 31 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -281,20 +281,37 @@ def add_documents(
281281

282282

283283
@overload
284-
def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult:
284+
def convert_deep_dive_result(
285+
ddr: dict[str, int | str],
286+
*,
287+
categories: list[str] | None) -> DeepDiveResult:
285288
...
286289

287290

288291
@overload
289-
def convert_deep_dive_result(ddr: None) -> None:
292+
def convert_deep_dive_result(
293+
ddr: None, *, categories: list[str] | None) -> None:
290294
...
291295

292296

293297
def convert_deep_dive_result(
294-
ddr: dict[str, int | str] | None) -> DeepDiveResult | None:
298+
ddr: dict[str, int | str] | None,
299+
*,
300+
categories: list[str] | None) -> DeepDiveResult | None:
295301
if ddr is None:
296302
return None
297303
if "values" not in ddr:
304+
if categories is not None:
305+
if "reason" in categories:
306+
raise ValueError(
307+
"must use 'values' key when 'reason' is a category")
308+
return {
309+
"reason": cast(str, ddr["reason"]),
310+
"values": {
311+
cat: int(ddr[cat])
312+
for cat in categories
313+
},
314+
}
298315
return {
299316
"reason": cast(str, ddr["reason"]),
300317
"values": {
@@ -303,7 +320,12 @@ def convert_deep_dive_result(
303320
if key != "reason"
304321
},
305322
}
306-
return cast(DeepDiveResult, ddr)
323+
res = cast(DeepDiveResult, ddr)
324+
if categories is not None:
325+
missing = set(categories).difference(res["values"].keys())
326+
if missing:
327+
raise ValueError(f"categories {missing} are missing in {ddr}")
328+
return res
307329

308330

309331
def get_documents(
@@ -345,7 +367,7 @@ def get_documents(
345367
"is_valid": row.is_valid,
346368
"verify_reason": row.verify_reason,
347369
"deep_dive_result": convert_deep_dive_result(
348-
row.deep_dive_result),
370+
row.deep_dive_result, categories=None),
349371
"error": row.error,
350372
"tag": row.tag,
351373
"tag_reason": row.tag_reason,
@@ -531,7 +553,7 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]:
531553
"is_valid": row.is_valid,
532554
"verify_reason": row.verify_reason,
533555
"deep_dive_result": convert_deep_dive_result(
534-
row.deep_dive_result),
556+
row.deep_dive_result, categories=None),
535557
"error": row.error,
536558
"tag": row.tag,
537559
"tag_reason": row.tag_reason,
@@ -603,7 +625,7 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]:
603625
"is_valid": row.is_valid,
604626
"verify_reason": row.verify_reason,
605627
"deep_dive_result": convert_deep_dive_result(
606-
row.deep_dive_result),
628+
row.deep_dive_result, categories=None),
607629
"error": row.error,
608630
}
609631

@@ -640,7 +662,8 @@ def get_segments(
640662
"content": row.content,
641663
"is_valid": row.is_valid,
642664
"verify_reason": row.verify_reason,
643-
"deep_dive_result": convert_deep_dive_result(row.deep_dive_result),
665+
"deep_dive_result": convert_deep_dive_result(
666+
row.deep_dive_result, categories=None),
644667
"error": row.error,
645668
}
646669

app/system/deepdive/diver.py

Lines changed: 82 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,10 @@
1818
import threading
1919
import time
2020
import traceback
21+
from typing import Literal
2122

2223
from scattermind.api.api import ScattermindAPI
24+
from scattermind.system.names import GNamespace
2325
from scattermind.system.response import TASK_COMPLETE
2426
from scattermind.system.torch_util import tensor_to_str
2527

@@ -28,6 +30,8 @@
2830
from app.system.deepdive.collection import (
2931
add_segments,
3032
combine_segments,
33+
convert_deep_dive_result,
34+
DeepDivePromptInfo,
3135
DeepDiveResult,
3236
DocumentObj,
3337
get_documents_in_queue,
@@ -179,6 +183,8 @@ def process_segments(
179183
page = segment["page"]
180184
full_text = segment["content"]
181185
is_verify = segment["is_valid"] is None
186+
prompt_info: DeepDivePromptInfo = segment["prompt_info"]
187+
# FIXME
182188
if is_verify:
183189
sp_key = segment["verify_key"]
184190
elif segment["is_valid"] is True:
@@ -207,76 +213,92 @@ def process_segments(
207213
log_diver(
208214
f"processing segment {main_id}@{page} ({seg_id}): "
209215
f"llm ({sp_key}) size={len(full_text)}")
210-
task_id = smind.enqueue_task(
211-
ns,
212-
{
213-
"prompt": full_text,
214-
"system_prompt_key": sp_key,
215-
})
216-
try:
217-
for _, result in smind.wait_for([task_id], timeout=LLM_TIMEOUT):
218-
if result["status"] not in TASK_COMPLETE:
219-
log_diver(
220-
f"processing segment {main_id}@{page}: "
221-
f"llm timed out ({sp_key})")
216+
llm_out, llm_error = llm_response(smind, ns, full_text, prompt_info)
217+
if llm_out is not None:
218+
error_msg = (
219+
f"ERROR: could not interpret model output:\n{llm_out}")
220+
# NOTE(review): prompt_info["categories"] is presumably None exactly when is_verify -- confirm
221+
if is_verify:
222+
vres, verror = interpret_verify(llm_out)
223+
if vres is None:
224+
verror = (
225+
""
226+
if verror is None
227+
else f"\nSTACKTRACE: {verror}")
222228
retry_err(
223229
set_error_segment,
224230
db,
225231
seg_id,
226-
f"llm timed out for {main_id}@{page}")
227-
continue
228-
res = result["result"]
229-
if res is None:
230-
log_diver(
231-
f"processing segment {main_id}@{page}: "
232-
f"llm error ({sp_key})")
232+
f"{error_msg}{verror}")
233+
else:
234+
retry_err(
235+
set_verify_segment,
236+
db,
237+
seg_id,
238+
vres["is_hit"],
239+
vres["reason"])
240+
else:
241+
ddres, derror = interpret_deep_dive(
242+
llm_out, prompt_info["categories"])
243+
if ddres is None:
244+
derror = (
245+
""
246+
if derror is None
247+
else f"\nSTACKTRACE: {derror}")
233248
retry_err(
234249
set_error_segment,
235250
db,
236251
seg_id,
237-
f"error in task: {result}")
238-
continue
239-
text = tensor_to_str(res["response"])
240-
error_msg = (
241-
f"ERROR: could not interpret model output:\n{text}")
242-
if is_verify:
243-
vres, verror = interpret_verify(text)
244-
if vres is None:
245-
verror = (
246-
""
247-
if verror is None
248-
else f"\nSTACKTRACE: {verror}")
249-
retry_err(
250-
set_error_segment,
251-
db,
252-
seg_id,
253-
f"{error_msg}{verror}")
254-
else:
255-
retry_err(
256-
set_verify_segment,
257-
db,
258-
seg_id,
259-
vres["is_hit"],
260-
vres["reason"])
252+
f"{error_msg}{derror}")
261253
else:
262-
ddres, derror = interpret_deep_dive(text, categories)
263-
if ddres is None:
264-
derror = (
265-
""
266-
if derror is None
267-
else f"\nSTACKTRACE: {derror}")
268-
retry_err(
269-
set_error_segment,
270-
db,
271-
seg_id,
272-
f"{error_msg}{derror}")
273-
else:
274-
retry_err(set_deep_dive_segment, db, seg_id, ddres)
275-
finally:
276-
smind.clear_task(task_id)
254+
retry_err(set_deep_dive_segment, db, seg_id, ddres)
255+
elif llm_error == "timeout":
256+
log_diver(
257+
f"processing segment {main_id}@{page}: "
258+
f"llm timed out ({sp_key})")
259+
retry_err(
260+
set_error_segment,
261+
db,
262+
seg_id,
263+
f"llm timed out for {main_id}@{page}")
264+
elif llm_error == "missing":
265+
log_diver(
266+
f"processing segment {main_id}@{page}: "
267+
f"llm error ({sp_key})")
268+
retry_err(
269+
set_error_segment,
270+
db,
271+
seg_id,
272+
f"error in task: {llm_out}")
273+
else:
274+
raise ValueError(f"unexpected error: {llm_out=} {llm_error=}")
277275
return len(segments)
278276

279277

278+
def llm_response(
    smind: ScattermindAPI,
    ns: GNamespace,
    full_text: str,
    prompt_info: DeepDivePromptInfo,
) -> tuple[str | None, Literal["timeout", "missing", "okay"]]:
    """
    Run a single LLM task for one piece of text and return its response.

    Args:
        smind: The scattermind API used to enqueue and await the task.
        ns: The namespace the task is enqueued in.
        full_text: The text sent as the "prompt" input of the task.
        prompt_info: Supplies the "main_prompt" and "post_prompt" inputs.

    Returns:
        `(text, "okay")` when the task completed and produced a result,
        `(None, "timeout")` when the task status is not in `TASK_COMPLETE`,
        and `(None, "missing")` when the completed task carries no result
        or when waiting yielded nothing at all.
    """
    task_id = smind.enqueue_task(
        ns,
        {
            "prompt": full_text,
            "main_prompt": prompt_info["main_prompt"],
            "post_prompt": prompt_info["post_prompt"],
        })
    # NOTE(review): auto_clear=True presumably makes wait_for release the
    # task itself, so no explicit clear_task call is made here -- confirm.
    pending = smind.wait_for([task_id], timeout=LLM_TIMEOUT, auto_clear=True)
    for _, outcome in pending:
        # Exactly one task was enqueued, so only the first outcome is used.
        if outcome["status"] not in TASK_COMPLETE:
            return (None, "timeout")
        payload = outcome["result"]
        if payload is not None:
            return (tensor_to_str(payload["response"]), "okay")
        break
    return (None, "missing")
300+
301+
280302
LP = r"{"
281303
RP = r"}"
282304

@@ -338,15 +360,6 @@ def interpret_deep_dive(
338360
if obj is None:
339361
return (None, error)
340362
try:
341-
return (
342-
{
343-
"reason": f"{obj['reason']}",
344-
"values": {
345-
key: int(obj[key])
346-
for key in categories
347-
},
348-
},
349-
None,
350-
)
351-
except KeyError:
363+
return (convert_deep_dive_result(obj, categories=categories), None)
364+
except (KeyError, ValueError):
352365
return (None, traceback.format_exc())

ui/src/App.tsx

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -147,7 +147,7 @@ class App extends PureComponent<AppProps, AppState> {
147147
userId: undefined,
148148
userName: undefined,
149149
isCollapsed: +(localStorage.getItem('menuCollapse') ?? 0) > 0,
150-
dbs: ['main', 'test', 'rave_ce'],
150+
dbs: [],
151151
};
152152
this.apiActions = new ApiActions(undefined);
153153

@@ -194,9 +194,11 @@ class App extends PureComponent<AppProps, AppState> {
194194
this.setState(
195195
{
196196
dbStart: true,
197+
dbs: JSON.parse(localStorage.getItem('pageLoadDbs') ?? '[]'),
197198
},
198199
() => {
199200
this.apiActions.vecDBs((vecdbs) => {
201+
localStorage.setItem('pageLoadDbs', JSON.stringify(vecdbs));
200202
this.setState({ dbReady: true, dbs: vecdbs });
201203
});
202204
},

ui/src/api/ApiActions.ts

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -88,7 +88,7 @@ export default class ApiActions {
8888

8989
async search(
9090
query: string,
91-
vecdb: DBName,
91+
vecdb: Readonly<DBName>,
9292
filters: SearchFilters,
9393
page: number,
9494
cb: ResultCallback,
@@ -132,7 +132,11 @@ export default class ApiActions {
132132
});
133133
}
134134

135-
async stats(vecdb: DBName, filters: SearchFilters, cb: StatCallback) {
135+
async stats(
136+
vecdb: Readonly<DBName>,
137+
filters: SearchFilters,
138+
cb: StatCallback,
139+
) {
136140
this.statNum += 1;
137141
const statNum = this.statNum;
138142
const { doc_count, fields } = await this.api.stats(

ui/src/api/types.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,8 +15,8 @@
1515
* You should have received a copy of the GNU General Public License
1616
* along with this program. If not, see <https://www.gnu.org/licenses/>.
1717
*/
18-
export type DBName = 'main' | 'test' | 'rave_ce';
19-
export type DeepDiveName = 'circular_economy' | 'circular_economy_undp';
18+
export type DBName = string & { _dbName: void };
19+
export type DeepDiveName = string & { _deepDiveName: void };
2020

2121
export type VersionResponse = {
2222
app_name: string;

ui/src/search/Search.tsx

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -741,7 +741,7 @@ class Search extends PureComponent<SearchProps, SearchState> {
741741
Database:{' '}
742742
<Select
743743
onChange={this.onDBChange}
744-
value={db}>
744+
value={`${db}`}>
745745
{dbs.map((db) => (
746746
<Option
747747
key={db}

ui/src/search/SearchStateSlice.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -81,7 +81,7 @@ const searchStateSlice = createSlice<SearchState, SearchReducers, string>({
8181
setSearch: (state, action) => {
8282
const { db, query, filters, page } = action.payload;
8383
if (state.db !== db) {
84-
localStorage.setItem('vecdb', db);
84+
localStorage.setItem('vecdb', `${db}`);
8585
state.db = db;
8686
}
8787
state.query = query;

0 commit comments

Comments
 (0)