feat: sv_lookup endpoint, remesh query param, reuse initial meshes for clone pcgs
akhileshh committed Jan 28, 2025
1 parent 0a3232a commit 69cd97c
Showing 4 changed files with 66 additions and 21 deletions.
28 changes: 20 additions & 8 deletions pychunkedgraph/app/segmentation/common.py
@@ -189,11 +189,13 @@ def handle_find_minimal_covering_nodes(table_id, is_binary=True):
): # Process from higher layers to lower layers
if len(node_queue[layer]) == 0:
continue

current_nodes = list(node_queue[layer])

# Call handle_roots to find parents
parents = cg.get_roots(current_nodes, stop_layer=layer + 1, time_stamp=timestamp)
parents = cg.get_roots(
current_nodes, stop_layer=layer + 1, time_stamp=timestamp
)
unique_parents = np.unique(parents)
parent_layers = np.array(
[cg.get_chunk_layer(parent) for parent in unique_parents]
@@ -312,7 +314,11 @@ def str2bool(v):


def publish_edit(
table_id: str, user_id: str, result: GraphEditOperation.Result, is_priority=True
table_id: str,
user_id: str,
result: GraphEditOperation.Result,
is_priority=True,
remesh: bool = True,
):
import pickle

@@ -322,6 +328,7 @@
"table_id": table_id,
"user_id": user_id,
"remesh_priority": "true" if is_priority else "false",
"remesh": "true" if remesh else "false",
}
payload = {
"operation_id": int(result.operation_id),
@@ -343,6 +350,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):

nodes = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
chebyshev_distance = request.args.get("chebyshev_distance", 3, type=int)

current_app.logger.debug(nodes)
@@ -391,7 +399,7 @@ def handle_merge(table_id, allow_same_segment_merge=False):
current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

if len(ret.new_lvl2_ids) > 0:
publish_edit(table_id, user_id, ret, is_priority=is_priority)
publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

return ret

@@ -405,6 +413,7 @@ def handle_split(table_id):

data = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
mincut = request.args.get("mincut", True, type=str2bool)

current_app.logger.debug(data)
@@ -457,7 +466,7 @@ def handle_split(table_id):
current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

if len(ret.new_lvl2_ids) > 0:
publish_edit(table_id, user_id, ret, is_priority=is_priority)
publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

return ret

@@ -470,6 +479,7 @@ def handle_undo(table_id):

data = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
user_id = str(g.auth_user.get("id", current_app.user_id))

current_app.logger.debug(data)
@@ -489,7 +499,7 @@ def handle_undo(table_id):
current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

if ret.new_lvl2_ids.size > 0:
publish_edit(table_id, user_id, ret, is_priority=is_priority)
publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

return ret

@@ -502,6 +512,7 @@ def handle_redo(table_id):

data = json.loads(request.data)
is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
user_id = str(g.auth_user.get("id", current_app.user_id))

current_app.logger.debug(data)
@@ -521,7 +532,7 @@ def handle_redo(table_id):
current_app.logger.debug(("lvl2_nodes:", ret.new_lvl2_ids))

if ret.new_lvl2_ids.size > 0:
publish_edit(table_id, user_id, ret, is_priority=is_priority)
publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

return ret

@@ -536,6 +547,7 @@ def handle_rollback(table_id):
target_user_id = request.args["user_id"]

is_priority = request.args.get("priority", True, type=str2bool)
remesh = request.args.get("remesh", True, type=str2bool)
skip_operation_ids = np.array(
json.loads(request.args.get("skip_operation_ids", "[]")), dtype=np.uint64
)
@@ -562,7 +574,7 @@ def handle_rollback(table_id):
raise cg_exceptions.BadRequest(str(e))

if ret.new_lvl2_ids.size > 0:
publish_edit(table_id, user_id, ret, is_priority=is_priority)
publish_edit(table_id, user_id, ret, is_priority=is_priority, remesh=remesh)

return user_operations
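
Usage note (not part of the diff): with this change every edit endpoint (merge, split, undo, redo, rollback) accepts a remesh query parameter that defaults to true. A minimal client-side sketch follows; the host, table id, route prefix, and node values are placeholders, and the [supervoxel_id, x, y, z] payload layout is assumed from the handler code above.

import requests

HOST = "https://example.com"   # placeholder deployment URL
TABLE = "my_table"             # placeholder chunkedgraph table id

# Two picked points, each as [supervoxel_id, x, y, z] (layout assumed).
nodes = [
    [79610580001142528, 102400, 51200, 2048],
    [79610580001142529, 102432, 51200, 2048],
]

resp = requests.post(
    f"{HOST}/segmentation/api/v1/table/{TABLE}/merge",
    params={"priority": "false", "remesh": "false"},  # suppress remeshing for this edit
    json=nodes,
)
resp.raise_for_status()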

19 changes: 19 additions & 0 deletions pychunkedgraph/app/segmentation/v1/routes.py
@@ -15,6 +15,7 @@
)

from pychunkedgraph.app import common as app_common
from pychunkedgraph.app import app_utils
from pychunkedgraph.app.app_utils import (
jsonify_with_kwargs,
remap_public,
@@ -626,3 +627,21 @@ def valid_nodes(table_id):
resp = common.valid_nodes(table_id, is_binary=is_binary)

return jsonify_with_kwargs(resp, int64_as_str=int64_as_str)


@bp.route("/table/<table_id>/supervoxel_lookup", methods=["POST"])
@auth_requires_permission("admin")
@remap_public(edit=False)
def handle_supervoxel_lookup(table_id):
int64_as_str = request.args.get("int64_as_str", default=False, type=toboolean)

nodes = json.loads(request.data)
cg = app_utils.get_cg(table_id)
node_ids = []
coords = []
for node in nodes:
node_ids.append(node[0])
coords.append(np.array(node[1:]) / cg.segmentation_resolution)

atomic_ids = app_utils.handle_supervoxel_id_lookup(cg, coords, node_ids)
return jsonify_with_kwargs(atomic_ids, int64_as_str=int64_as_str)
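
A hedged usage sketch for the new supervoxel_lookup endpoint (not part of the diff). The host, table id, route prefix, and coordinate values are placeholders; the [node_id, x, y, z] payload layout is taken from the handler above, and the coordinates appear to be in world units, since the handler divides them by the segmentation resolution.

import requests

HOST = "https://example.com"   # placeholder deployment URL
TABLE = "my_table"             # placeholder chunkedgraph table id

payload = [
    [648518346349538235, 358400, 204800, 40960],  # node_id, x, y, z
    [648518346349538236, 358432, 204800, 40960],
]

resp = requests.post(
    f"{HOST}/segmentation/api/v1/table/{TABLE}/supervoxel_lookup",
    params={"int64_as_str": "true"},
    json=payload,
)
print(resp.json())  # supervoxel id resolved for each input node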
19 changes: 17 additions & 2 deletions pychunkedgraph/meshing/meshing_batch.py
@@ -1,8 +1,12 @@
import argparse, os
import numpy as np
from cloudvolume import CloudVolume
from cloudfiles import CloudFiles
from taskqueue import TaskQueue, LocalTaskQueue
import argparse

from pychunkedgraph.graph.chunkedgraph import ChunkedGraph # noqa
import numpy as np
from pychunkedgraph.meshing.meshing_sqs import MeshTask
from pychunkedgraph.meshing import meshgen_utils # noqa

if __name__ == "__main__":
parser = argparse.ArgumentParser()
@@ -13,11 +17,22 @@
parser.add_argument('--layer', type=int)
parser.add_argument('--mip', type=int)
parser.add_argument('--skip_cache', action='store_true')
parser.add_argument('--overwrite', type=bool, default=False)

args = parser.parse_args()
cache = not args.skip_cache

cg = ChunkedGraph(graph_id=args.cg_name)
cv = CloudVolume(
f"graphene://https://localhost/segmentation/table/dummy",
info=meshgen_utils.get_json_info(cg),
)
dst = os.path.join(
cv.cloudpath, cv.mesh.meta.mesh_path, "initial", str(args.layer)
)
cf = CloudFiles(dst)
if len(list(cf.list())) > 0 and not args.overwrite:
raise ValueError(f"Destination {dst} is not empty. Use `--overwrite true` to proceed anyway.")

chunks_arr = []
for x in range(args.chunk_start[0],args.chunk_end[0]):
Expand Down
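
A side note on the new flag (not from the diff): because --overwrite is declared with type=bool, argparse simply calls bool() on the raw string, so any non-empty value, including "false", enables overwriting. A tiny sketch of that behaviour:

# argparse's type=bool calls bool() on the raw string value;
# any non-empty string is truthy, so "--overwrite false" still means True.
assert bool("false") is True
assert bool("") is False

If stricter parsing is wanted, the str2bool helper already used for the query parameters in common.py would be a natural fit.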
21 changes: 10 additions & 11 deletions workers/mesh_worker.py
@@ -22,6 +22,10 @@ def callback(payload):
op_id = int(data["operation_id"])
l2ids = np.array(data["new_lvl2_ids"], dtype=basetypes.NODE_ID)
table_id = payload.attributes["table_id"]
remesh = payload.attributes["remesh"]

if remesh == "false":
return

try:
cg = PCG_CACHE[table_id]
@@ -37,9 +41,12 @@
)

try:
mesh_dir = cg.meta.dataset_info["mesh"]
mesh_meta = cg.meta.dataset_info["mesh_metadata"]
cv_unsharded_mesh_dir = mesh_meta.get("unsharded_mesh_dir", "dynamic")
mesh_meta = cg.meta.custom_data["mesh"]
mesh_dir = mesh_meta["dir"]
layer = mesh_meta["max_layer"]
mip = mesh_meta["mip"]
err = mesh_meta["max_error"]
cv_unsharded_mesh_dir = mesh_meta.get("dynamic_mesh_dir", "dynamic")
except KeyError:
logging.warning(f"No metadata found for {cg.graph_id}; ignoring...")
return
@@ -48,14 +55,6 @@ def callback(payload):
cg.meta.data_source.WATERSHED, mesh_dir, cv_unsharded_mesh_dir
)

try:
mesh_data = cg.meta.custom_data["mesh"]
layer = mesh_data["max_layer"]
mip = mesh_data["mip"]
err = mesh_data["max_error"]
except KeyError:
return


logging.log(INFO_HIGH, f"remeshing {l2ids}; graph {table_id} operation {op_id}.")
meshgen.remeshing(
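
For orientation, a minimal sketch (not part of the commit) of the message contract after this change: publish_edit stringifies the flag into the pub/sub message attributes, and the worker compares it against the literal string "false" before doing any work. The values below are placeholders.

attributes = {
    "table_id": "my_table",       # placeholder table id
    "user_id": "42",
    "remesh_priority": "true",    # existing priority flag
    "remesh": "false",            # new flag added by this commit
}

def should_remesh(attrs: dict) -> bool:
    # The worker reads payload.attributes["remesh"] directly and returns early
    # when it equals "false"; the default here is only for this sketch.
    return attrs.get("remesh", "true") != "false"

assert should_remesh(attributes) is False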
