Skip to content

Commit

Permalink
cluster_args
Browse files Browse the repository at this point in the history
  • Loading branch information
JosuaKrause committed Aug 8, 2024
1 parent 736b1f4 commit 4f7e2bd
Show file tree
Hide file tree
Showing 10 changed files with 272 additions and 22 deletions.
70 changes: 69 additions & 1 deletion app/api/server.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,6 +236,13 @@ def add_vec_features(
articles_dict: dict[DBName, str] = {}

def init_vec_db() -> None:
"""
Asynchronously initializes the vector databases. This can be really
slow and it could happen that qdrant is not available for a while.
If any call times out this function will repeat trying to access the
databases. Once everything is connected properly, the process queue
is started.
"""
time.sleep(60.0) # NOTE: give qdrant plenty of time...
try:
tstart = time.monotonic()
Expand Down Expand Up @@ -288,11 +295,38 @@ def init_vec_db() -> None:
th.start()

def parse_vdb(vdb_str: str) -> DBName:
    """
    Validates a raw string as an external vector database name.

    Args:
        vdb_str (str): The candidate database name.

    Raises:
        ValueError: If the string is not contained in `DBS`.

    Returns:
        DBName: The same string, typed as an external database name.
    """
    if vdb_str in DBS:
        # cast is a typing-only narrowing; the value itself is unchanged
        return cast(DBName, vdb_str)
    raise ValueError(f"db ({vdb_str}) must be one of {DBS}")

def get_articles(vdb_str: str) -> str:
"""
Converts an external vector database name into an internal vector
database name.
Args:
vdb_str (str): The external database name.
Raises:
ValueError: If the string is not a valid external vector database
name or the databases have not been loaded yet.
Returns:
str: The internal name for the given vector database.
"""
vdb = parse_vdb(vdb_str)
res = articles_dict.get(vdb)
if res:
Expand All @@ -304,12 +338,41 @@ def get_articles(vdb_str: str) -> str:
raise ValueError("vector database is not ready yet!")

def get_articles_dict() -> dict[DBName, str]:
    """
    Returns a snapshot of all vector databases loaded so far.

    Returns:
        dict[DBName, str]: A shallow copy mapping each external name to its
            internal name. Mutating the result does not affect the
            registry.
    """
    snapshot = articles_dict.copy()
    return snapshot

@server.json_post(f"{prefix}/stats")
@server.middleware(verify_readonly)
@server.middleware(maybe_session)
def _post_stats(_req: QSRH, rargs: ReqArgs) -> StatEmbed:
"""
The `/api/stats` endpoint provides document counts for semantic search
queries. If the session cookie is not provided or invalid only public
documents are considered for the stats.
@readonly
@cookie (optional)
Args:
_req (QSRH): The request.
rargs (ReqArgs): The arguments.
POST
"fields": A set of field types expected to be returned.
"filters": A dictionary of field types to lists of filter
values. The date field, if given, expects a list of
exactly two values, the start and end date
(both inclusive). If the session cookie is missing or
invalid the "status" filter gets overwritten to
include "public" documents only.
"vecdb": The vector database.
Returns:
StatEmbed: Vector database document counts.
"""
session: SessionInfo | None = rargs["meta"].get("session")
args = rargs["post"]
fields = set(args["fields"])
Expand Down Expand Up @@ -961,7 +1024,12 @@ def _post_tags_create(_req: QSRH, rargs: ReqArgs) -> AddQueue:
name: str | None = args.get("name")
bases: list[str] = list(args["bases"])
is_updating = to_bool(args.get("is_updating", True))
tag_processor(name=name, bases=bases, is_updating=is_updating)
cluster_args = args.get("cluster_args", {})
tag_processor(
name=name,
bases=bases,
is_updating=is_updating,
cluster_args=cluster_args)
return {
"enqueued": True,
}
Expand Down
1 change: 1 addition & 0 deletions app/misc/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,3 +13,4 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Miscellaneous helper functions."""
17 changes: 17 additions & 0 deletions app/misc/context.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,17 +13,34 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Helper function to determine the context of a given hit."""
import re


# Tuning knobs for the context window that is cut out around a hit.
CONTEXT_SIZE = 20
"""The desired context size in characters for both directions."""
CONTEXT_MAX_EXPAND = 10
"""The maximum expansion over the desired context size."""
# NOTE: CONTEXT_END and CONTEXT_START currently compile the same pattern
# (a `\b` word boundary); they are kept as separate names so each side can
# be tuned independently.
CONTEXT_END = re.compile(r"\b")
"""Regex to find a suitable end of a context."""
CONTEXT_START = re.compile(r"\b")
"""Regex to find a suitable start of a context."""
# Single-character horizontal ellipsis (U+2026), not three ASCII dots.
ELLIPSIS = "…"
"""The ellipsis character."""


def get_context(text: str, start: int, stop: int) -> str:
"""
Gets the context of the given hit.
Args:
text (str): The full text.
start (int): The hit start index.
stop (int): The hit end index.
Returns:
str: The hit with surrounding context.
"""
orig_start = start
orig_stop = stop
start = max(start - CONTEXT_SIZE, 0)
Expand Down
49 changes: 49 additions & 0 deletions app/misc/env.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""Handling environment variables."""
import os
from typing import Literal

Expand All @@ -25,6 +26,7 @@
"SMIND_CFG",
"UI_PATH",
]
"""Environment variables representing a file path or folder."""
EnvStr = Literal[
"APP_SECRET",
"BLOGS_DB_DIALECT",
Expand All @@ -50,17 +52,20 @@
"TANUKI",
"WRITE_TOKEN",
]
"""Environment variables representing a string."""
EnvInt = Literal[
"BLOGS_DB_PORT",
"LOGIN_DB_PORT",
"PORT",
"QDRANT_GRPC_PORT",
"QDRANT_REST_PORT",
]
"""Environment variables representing an integer."""
EnvBool = Literal[
"NO_QDRANT",
"HAS_LLAMA",
]
"""Environment variables representing a boolean value (true, false, 0, 1)."""


def _envload(key: str, default: str | None) -> str:
Expand All @@ -73,16 +78,60 @@ def _envload(key: str, default: str | None) -> str:


def envload_str(key: EnvStr, *, default: str | None = None) -> str:
    """
    Reads a string-valued environment variable.

    Args:
        key (EnvStr): The name of the environment variable.
        default (str | None, optional): Fallback value. If None, the
            environment variable is mandatory. Defaults to None.

    Returns:
        str: The value of the variable.
    """
    value = _envload(key, default=default)
    return value


def envload_path(key: EnvPath, *, default: str | None = None) -> str:
    """
    Reads an environment variable that holds a file path or folder.

    Args:
        key (EnvPath): The name of the environment variable.
        default (str | None, optional): Fallback value. If None, the
            environment variable is mandatory. Defaults to None.

    Returns:
        str: The path exactly as stored in the variable.
    """
    path = _envload(key, default=default)
    return path


def envload_int(key: EnvInt, *, default: int | None = None) -> int:
    """
    Loads an integer environment variable.

    Args:
        key (EnvInt): The variable name.
        default (int | None, optional): The default value. If None, the
            environment variable is mandatory. Defaults to None.

    Raises:
        ValueError: If the value cannot be parsed as an integer.

    Returns:
        int: The value.
    """
    # NOTE: only stringify an actual default. Formatting None would pass the
    # literal text "None" to _envload as if it were a usable default, so a
    # missing mandatory variable would surface as a confusing int("None")
    # parse error instead of being handled as "no default given".
    default_str = None if default is None else f"{default}"
    return int(_envload(key, default_str))


def envload_bool(key: EnvBool, *, default: bool | None = None) -> bool:
    """
    Loads a boolean environment variable (0, 1, true, false).

    Args:
        key (EnvBool): The variable name.
        default (bool | None, optional): The default value. If None, the
            environment variable is mandatory. Defaults to None.

    Returns:
        bool: The value.
    """
    # NOTE: only stringify an actual default. Formatting None would pass the
    # literal text "None" to _envload as if it were a usable default, so a
    # missing mandatory variable would be fed to to_bool as "None" instead
    # of being handled as "no default given".
    default_str = None if default is None else f"{default}"
    return to_bool(_envload(key, default_str))
74 changes: 74 additions & 0 deletions app/misc/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <https://www.gnu.org/licenses/>.
"""I/O helper functions that handle a slow disk (or network disk) gracefully.
"""
import contextlib
import errno
import io
Expand All @@ -26,11 +28,23 @@


# RLock is re-entrant: the thread that already holds it can acquire it again
# without deadlocking.
MAIN_LOCK = threading.RLock()
"""Lock for coordinating the wait on start when the (network) disk is not
ready yet. Network disks can take a bit to get ready after a container is
started."""
# Ten increasing steps adding up to 15.3 (presumably seconds -- confirm
# against the code that consumes this list).
STALE_FILE_RETRIES: list[float] = [0.1, 0.2, 0.5, 0.8, 1, 1.2, 1.5, 2, 3, 5]
"""Wait times for retrying reads on stale files."""
TMP_POSTFIX = ".~tmp"
"""Postfix for temporary files."""


def when_ready(fun: Callable[[], None]) -> None:
"""
Executes an I/O operation, retrying if the disk is not ready. After 120
retries (~2min) the function gives up and lets the error go through.
Args:
fun (Callable[[], None]): The I/O operation.
"""
with MAIN_LOCK:
counter = 0
while True:
Expand All @@ -46,6 +60,13 @@ def when_ready(fun: Callable[[], None]) -> None:


def fastrename(src: str, dst: str) -> None:
"""
Moves a file or folder. Source and destination cannot be the same.
Args:
src (str): The source file or folder.
dst (str): The destination file or folder.
"""
src = os.path.abspath(src)
dst = os.path.abspath(dst)
if src == dst:
Expand All @@ -71,10 +92,26 @@ def fastrename(src: str, dst: str) -> None:


def copy_file(from_file: str, to_file: str) -> None:
    """
    Duplicates a file at a new location.

    Args:
        from_file (str): Path of the file to copy.
        to_file (str): Path the copy is written to.
    """
    shutil.copy(src=from_file, dst=to_file)


def normalize_folder(folder: str) -> str:
"""
Makes the path absolute and ensures that the folder exists.
Args:
folder (str): The folder.
Returns:
str: The absolute path.
"""
res = os.path.abspath(folder)
when_ready(lambda: os.makedirs(res, mode=0o777, exist_ok=True))
if not os.path.isdir(res):
Expand All @@ -83,16 +120,44 @@ def normalize_folder(folder: str) -> str:


def normalize_file(fname: str) -> str:
    """
    Resolves a file name to an absolute path and makes sure its parent
    folder is normalized via `normalize_folder`.

    Args:
        fname (str): The file name.

    Returns:
        str: The absolute path of the file.
    """
    abs_path = os.path.abspath(fname)
    parent = os.path.dirname(abs_path)
    normalize_folder(parent)
    return abs_path


def get_mode(base: str, text: bool) -> str:
    """
    Builds the mode argument for `open`.

    Args:
        base (str): The base mode (e.g., "r" or "w").
        text (bool): True for text files; False appends the binary flag.

    Returns:
        str: The base mode, followed by "b" when the file is not text.
    """
    suffix = "" if text else "b"
    return base + suffix


def is_empty_file(fin: IO[Any]) -> bool:
"""
Checks whether the given file is empty.
Args:
fin (IO[Any]): The file handle.
Returns:
bool: True, if the file is empty.
"""
pos = fin.seek(0, io.SEEK_CUR)
size = fin.seek(0, io.SEEK_END) - pos
fin.seek(pos, io.SEEK_SET)
Expand All @@ -110,6 +175,15 @@ def ensure_folder(folder: None) -> None:


def ensure_folder(folder: str | None) -> str | None:
"""
Ensures that the given folder exists.
Args:
folder (str | None): The folder name or None.
Returns:
str | None: The folder name or None.
"""
if folder is not None and not os.path.exists(folder):
a_folder: str = folder
when_ready(lambda: os.makedirs(a_folder, mode=0o777, exist_ok=True))
Expand Down
Loading

0 comments on commit 4f7e2bd

Please sign in to comment.