Skip to content

Commit

Permalink
more general
Browse files Browse the repository at this point in the history
  • Loading branch information
JosuaKrause committed Aug 7, 2024
1 parent a55315f commit 736b1f4
Show file tree
Hide file tree
Showing 9 changed files with 128 additions and 92 deletions.
8 changes: 2 additions & 6 deletions app/api/response_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,7 @@
from typing import TypedDict

from app.system.autotag.autotag import TagClusterEntry
from app.system.deepdive.collection import (
DeepDiveName,
DocumentObj,
SegmentStats,
)
from app.system.deepdive.collection import DocumentObj, SegmentStats
from app.system.smind.api import QueueStat
from app.system.smind.vec import DBName, VecDBStat
from app.system.workqueues.queue import ProcessError, ProcessQueueStats
Expand All @@ -46,7 +42,7 @@
"has_llm": bool,
"vecdb_ready": bool,
"vecdbs": list[DBName],
"deepdives": list[DeepDiveName],
"deepdives": list[str],
"error": list[str] | None,
})
StatsResponse = TypedDict('StatsResponse', {
Expand Down
2 changes: 0 additions & 2 deletions app/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -86,9 +86,7 @@
add_collection,
add_documents,
CollectionOptions,
DEEP_DIVE_NAMES,
get_collections,
get_deep_dive_name,
get_documents,
requeue,
requeue_error,
Expand Down
39 changes: 31 additions & 8 deletions app/system/deepdive/collection.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,20 +281,37 @@ def add_documents(


@overload
def convert_deep_dive_result(ddr: dict[str, int | str]) -> DeepDiveResult:
def convert_deep_dive_result(
ddr: dict[str, int | str],
*,
categories: list[str] | None) -> DeepDiveResult:
...


@overload
def convert_deep_dive_result(ddr: None) -> None:
def convert_deep_dive_result(
ddr: None, *, categories: list[str] | None) -> None:
...


def convert_deep_dive_result(
ddr: dict[str, int | str] | None) -> DeepDiveResult | None:
ddr: dict[str, int | str] | None,
*,
categories: list[str] | None) -> DeepDiveResult | None:
if ddr is None:
return None
if "values" not in ddr:
if categories is not None:
if "reason" in categories:
raise ValueError(
"must use 'values' key when 'reason' is a category")
return {
"reason": cast(str, ddr["reason"]),
"values": {
cat: int(ddr[cat])
for cat in categories
},
}
return {
"reason": cast(str, ddr["reason"]),
"values": {
Expand All @@ -303,7 +320,12 @@ def convert_deep_dive_result(
if key != "reason"
},
}
return cast(DeepDiveResult, ddr)
res = cast(DeepDiveResult, ddr)
if categories is not None:
missing = set(categories).difference(res["values"].keys())
if missing:
raise ValueError(f"categories {missing} are missing in {ddr}")
return res


def get_documents(
Expand Down Expand Up @@ -345,7 +367,7 @@ def get_documents(
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
row.deep_dive_result, categories=None),
"error": row.error,
"tag": row.tag,
"tag_reason": row.tag_reason,
Expand Down Expand Up @@ -531,7 +553,7 @@ def get_documents_in_queue(db: DBConnector) -> Iterable[DocumentObj]:
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
row.deep_dive_result, categories=None),
"error": row.error,
"tag": row.tag,
"tag_reason": row.tag_reason,
Expand Down Expand Up @@ -603,7 +625,7 @@ def get_segments_in_queue(db: DBConnector) -> Iterable[SegmentObj]:
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result),
row.deep_dive_result, categories=None),
"error": row.error,
}

Expand Down Expand Up @@ -640,7 +662,8 @@ def get_segments(
"content": row.content,
"is_valid": row.is_valid,
"verify_reason": row.verify_reason,
"deep_dive_result": convert_deep_dive_result(row.deep_dive_result),
"deep_dive_result": convert_deep_dive_result(
row.deep_dive_result, categories=None),
"error": row.error,
}

Expand Down
151 changes: 82 additions & 69 deletions app/system/deepdive/diver.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,10 @@
import threading
import time
import traceback
from typing import Literal

from scattermind.api.api import ScattermindAPI
from scattermind.system.names import GNamespace
from scattermind.system.response import TASK_COMPLETE
from scattermind.system.torch_util import tensor_to_str

Expand All @@ -28,6 +30,8 @@
from app.system.deepdive.collection import (
add_segments,
combine_segments,
convert_deep_dive_result,
DeepDivePromptInfo,
DeepDiveResult,
DocumentObj,
get_documents_in_queue,
Expand Down Expand Up @@ -179,6 +183,8 @@ def process_segments(
page = segment["page"]
full_text = segment["content"]
is_verify = segment["is_valid"] is None
prompt_info: DeepDivePromptInfo = segment["prompt_info"]
# FIXME
if is_verify:
sp_key = segment["verify_key"]
elif segment["is_valid"] is True:
Expand Down Expand Up @@ -207,76 +213,92 @@ def process_segments(
log_diver(
f"processing segment {main_id}@{page} ({seg_id}): "
f"llm ({sp_key}) size={len(full_text)}")
task_id = smind.enqueue_task(
ns,
{
"prompt": full_text,
"system_prompt_key": sp_key,
})
try:
for _, result in smind.wait_for([task_id], timeout=LLM_TIMEOUT):
if result["status"] not in TASK_COMPLETE:
log_diver(
f"processing segment {main_id}@{page}: "
f"llm timed out ({sp_key})")
llm_out, llm_error = llm_response(smind, ns, full_text, prompt_info)
if llm_out is not None:
error_msg = (
f"ERROR: could not interpret model output:\n{llm_out}")
# prompt_info["categories"] is None is_verify
if is_verify:
vres, verror = interpret_verify(llm_out)
if vres is None:
verror = (
""
if verror is None
else f"\nSTACKTRACE: {verror}")
retry_err(
set_error_segment,
db,
seg_id,
f"llm timed out for {main_id}@{page}")
continue
res = result["result"]
if res is None:
log_diver(
f"processing segment {main_id}@{page}: "
f"llm error ({sp_key})")
f"{error_msg}{verror}")
else:
retry_err(
set_verify_segment,
db,
seg_id,
vres["is_hit"],
vres["reason"])
else:
ddres, derror = interpret_deep_dive(
llm_out, prompt_info["categories"])
if ddres is None:
derror = (
""
if derror is None
else f"\nSTACKTRACE: {derror}")
retry_err(
set_error_segment,
db,
seg_id,
f"error in task: {result}")
continue
text = tensor_to_str(res["response"])
error_msg = (
f"ERROR: could not interpret model output:\n{text}")
if is_verify:
vres, verror = interpret_verify(text)
if vres is None:
verror = (
""
if verror is None
else f"\nSTACKTRACE: {verror}")
retry_err(
set_error_segment,
db,
seg_id,
f"{error_msg}{verror}")
else:
retry_err(
set_verify_segment,
db,
seg_id,
vres["is_hit"],
vres["reason"])
f"{error_msg}{derror}")
else:
ddres, derror = interpret_deep_dive(text, categories)
if ddres is None:
derror = (
""
if derror is None
else f"\nSTACKTRACE: {derror}")
retry_err(
set_error_segment,
db,
seg_id,
f"{error_msg}{derror}")
else:
retry_err(set_deep_dive_segment, db, seg_id, ddres)
finally:
smind.clear_task(task_id)
retry_err(set_deep_dive_segment, db, seg_id, ddres)
elif llm_error == "timeout":
log_diver(
f"processing segment {main_id}@{page}: "
f"llm timed out ({sp_key})")
retry_err(
set_error_segment,
db,
seg_id,
f"llm timed out for {main_id}@{page}")
elif llm_error == "missing":
log_diver(
f"processing segment {main_id}@{page}: "
f"llm error ({sp_key})")
retry_err(
set_error_segment,
db,
seg_id,
f"error in task: {llm_out}")
else:
raise ValueError(f"unexpected error: {llm_out=} {llm_error=}")
return len(segments)


def llm_response(
        smind: ScattermindAPI,
        ns: GNamespace,
        full_text: str,
        prompt_info: DeepDivePromptInfo,
        ) -> tuple[str | None, Literal["timeout", "missing", "okay"]]:
    """
    Submit one LLM task and report the outcome of its first result.

    The task input consists of the full text as ``prompt`` together with the
    ``main_prompt`` and ``post_prompt`` entries of ``prompt_info``.

    Args:
        smind: The API used to enqueue the task and wait for its result.
        ns: The namespace the task is enqueued in.
        full_text: The text passed as the ``prompt`` input.
        prompt_info: Provides the ``main_prompt`` and ``post_prompt`` inputs.

    Returns:
        A ``(text, state)`` pair:
        ``(response text, "okay")`` if the task completed with a result,
        ``(None, "timeout")`` if the task status is not in ``TASK_COMPLETE``,
        ``(None, "missing")`` if the task completed without a result or if
        waiting yielded no result at all.
    """
    task_input = {
        "prompt": full_text,
        "main_prompt": prompt_info["main_prompt"],
        "post_prompt": prompt_info["post_prompt"],
    }
    task_id = smind.enqueue_task(ns, task_input)
    # NOTE(review): auto_clear=True presumably makes wait_for clear the task
    # once it has been yielded -- confirm against the scattermind API.
    for _, outcome in smind.wait_for(
            [task_id], timeout=LLM_TIMEOUT, auto_clear=True):
        # only a single task was enqueued so the first outcome is decisive
        if outcome["status"] in TASK_COMPLETE:
            task_result = outcome["result"]
            if task_result is not None:
                return (tensor_to_str(task_result["response"]), "okay")
            return (None, "missing")
        return (None, "timeout")
    # wait_for did not yield anything for the task
    return (None, "missing")


LP = r"{"
RP = r"}"

Expand Down Expand Up @@ -338,15 +360,6 @@ def interpret_deep_dive(
if obj is None:
return (None, error)
try:
return (
{
"reason": f"{obj['reason']}",
"values": {
key: int(obj[key])
for key in categories
},
},
None,
)
except KeyError:
return (convert_deep_dive_result(obj, categories=categories), None)
except (KeyError, ValueError):
return (None, traceback.format_exc())
4 changes: 3 additions & 1 deletion ui/src/App.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -147,7 +147,7 @@ class App extends PureComponent<AppProps, AppState> {
userId: undefined,
userName: undefined,
isCollapsed: +(localStorage.getItem('menuCollapse') ?? 0) > 0,
dbs: ['main', 'test', 'rave_ce'],
dbs: [],
};
this.apiActions = new ApiActions(undefined);

Expand Down Expand Up @@ -194,9 +194,11 @@ class App extends PureComponent<AppProps, AppState> {
this.setState(
{
dbStart: true,
dbs: JSON.parse(localStorage.getItem('pageLoadDbs') ?? '[]'),
},
() => {
this.apiActions.vecDBs((vecdbs) => {
localStorage.setItem('pageLoadDbs', JSON.stringify(vecdbs));
this.setState({ dbReady: true, dbs: vecdbs });
});
},
Expand Down
8 changes: 6 additions & 2 deletions ui/src/api/ApiActions.ts
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ export default class ApiActions {

async search(
query: string,
vecdb: DBName,
vecdb: Readonly<DBName>,
filters: SearchFilters,
page: number,
cb: ResultCallback,
Expand Down Expand Up @@ -132,7 +132,11 @@ export default class ApiActions {
});
}

async stats(vecdb: DBName, filters: SearchFilters, cb: StatCallback) {
async stats(
vecdb: Readonly<DBName>,
filters: SearchFilters,
cb: StatCallback,
) {
this.statNum += 1;
const statNum = this.statNum;
const { doc_count, fields } = await this.api.stats(
Expand Down
4 changes: 2 additions & 2 deletions ui/src/api/types.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,8 @@
* You should have received a copy of the GNU General Public License
* along with this program. If not, see <https://www.gnu.org/licenses/>.
*/
export type DBName = 'main' | 'test' | 'rave_ce';
export type DeepDiveName = 'circular_economy' | 'circular_economy_undp';
export type DBName = string & { _dbName: void };
export type DeepDiveName = string & { _deepDiveName: void };

export type VersionResponse = {
app_name: string;
Expand Down
2 changes: 1 addition & 1 deletion ui/src/search/Search.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -741,7 +741,7 @@ class Search extends PureComponent<SearchProps, SearchState> {
Database:{' '}
<Select
onChange={this.onDBChange}
value={db}>
value={`${db}`}>
{dbs.map((db) => (
<Option
key={db}
Expand Down
2 changes: 1 addition & 1 deletion ui/src/search/SearchStateSlice.ts
Original file line number Diff line number Diff line change
Expand Up @@ -81,7 +81,7 @@ const searchStateSlice = createSlice<SearchState, SearchReducers, string>({
setSearch: (state, action) => {
const { db, query, filters, page } = action.payload;
if (state.db !== db) {
localStorage.setItem('vecdb', db);
localStorage.setItem('vecdb', `${db}`);
state.db = db;
}
state.query = query;
Expand Down

0 comments on commit 736b1f4

Please sign in to comment.