From d84b76f8158d631dd9c36039cea728646422a336 Mon Sep 17 00:00:00 2001 From: Anne Haley Date: Wed, 22 Feb 2023 12:01:57 -0500 Subject: [PATCH] Fix multidomain particle use cases and add progress bars to swcc (#283) * Add upload and download progress bar printout * Fix upload & download of particles-only project * Fix rendering of multiple particle sets for multiple domains * Lint fix * Remove anatomies created for particle keys and name keys --- .../migrations/0024_world_local_nullable.py | 23 +++++++ shapeworks_cloud/core/models.py | 4 +- shapeworks_cloud/core/serializers.py | 4 +- swcc/swcc/models/project.py | 67 ++++++++++++------- swcc/swcc/models/utils.py | 17 +++++ web/shapeworks/src/components/ShapeViewer.vue | 6 +- 6 files changed, 91 insertions(+), 30 deletions(-) create mode 100644 shapeworks_cloud/core/migrations/0024_world_local_nullable.py diff --git a/shapeworks_cloud/core/migrations/0024_world_local_nullable.py b/shapeworks_cloud/core/migrations/0024_world_local_nullable.py new file mode 100644 index 00000000..1c8aa487 --- /dev/null +++ b/shapeworks_cloud/core/migrations/0024_world_local_nullable.py @@ -0,0 +1,23 @@ +# Generated by Django 3.2.17 on 2023-02-17 19:48 + +from django.db import migrations +import s3_file_field.fields + + +class Migration(migrations.Migration): + dependencies = [ + ('core', '0023_task_progress'), + ] + + operations = [ + migrations.AlterField( + model_name='optimizedparticles', + name='local', + field=s3_file_field.fields.S3FileField(null=True), + ), + migrations.AlterField( + model_name='optimizedparticles', + name='world', + field=s3_file_field.fields.S3FileField(null=True), + ), + ] diff --git a/shapeworks_cloud/core/models.py b/shapeworks_cloud/core/models.py index 1e94e7e7..65ba9bbd 100644 --- a/shapeworks_cloud/core/models.py +++ b/shapeworks_cloud/core/models.py @@ -174,8 +174,8 @@ class GroomedMesh(TimeStampedModel, models.Model): class OptimizedParticles(TimeStampedModel, models.Model): project = models.ForeignKey(Project, on_delete=models.CASCADE) - world = S3FileField() - local = S3FileField() + world = S3FileField(null=True) + local = S3FileField(null=True) transform = S3FileField(null=True) groomed_segmentation = models.ForeignKey( diff --git a/shapeworks_cloud/core/serializers.py b/shapeworks_cloud/core/serializers.py index a36b3fd9..0923352f 100644 --- a/shapeworks_cloud/core/serializers.py +++ b/shapeworks_cloud/core/serializers.py @@ -159,8 +159,8 @@ class Meta: class OptimizedParticlesSerializer(serializers.ModelSerializer): - world = S3FileSerializerField() - local = S3FileSerializerField() + world = S3FileSerializerField(required=False, allow_null=True) + local = S3FileSerializerField(required=False, allow_null=True) transform = S3FileSerializerField(required=False, allow_null=True) constraints = S3FileSerializerField(required=False, allow_null=True) diff --git a/swcc/swcc/models/project.py b/swcc/swcc/models/project.py index 4ad2cd8a..7a03c813 100644 --- a/swcc/swcc/models/project.py +++ b/swcc/swcc/models/project.py @@ -40,7 +40,7 @@ Segmentation, ) from .subject import Subject -from .utils import FileIO, shape_file_type +from .utils import FileIO, print_progress_bar, shape_file_type class ProjectFileIO(BaseModel, FileIO): @@ -69,8 +69,15 @@ def load_data_from_json(self, file, create): contents = json.load(open(file)) data = self.interpret_data(contents['data']) if create: + print(f'Uploading files for {len(data)} subjects...') + i = 0 + total_progress_steps = len(data) + print_progress_bar(i, total_progress_steps) for 
[subject, objects_by_domain] in data: + i += 1 self.create_objects_for_subject(subject, objects_by_domain) + print_progress_bar(i, total_progress_steps) + print() return data def interpret_data(self, input_data): @@ -85,18 +92,15 @@ def interpret_data(self, input_data): entry_values: Dict = {p: [] for p in expected_key_prefixes} entry_values['anatomy_ids'] = [] for key in entry.keys(): - prefixes = [p for p in expected_key_prefixes if key.startswith(p)] - if len(prefixes) > 0: - entry_values[prefixes[0]].append(entry[key]) - anatomy_id = 'anatomy' + key.replace(prefixes[0], '') - if anatomy_id not in entry_values['anatomy_ids'] and prefixes[0] in [ - 'shape', - 'mesh', - 'segmentation', - 'contour', - 'image', - ]: - entry_values['anatomy_ids'].append(anatomy_id) + if key != 'name': + prefixes = [p for p in expected_key_prefixes if key.startswith(p)] + if len(prefixes) > 0: + entry_values[prefixes[0]].append(entry[key]) + anatomy_id = 'anatomy' + key.replace(prefixes[0], '').replace( + '_particles', '' + ) + if anatomy_id not in entry_values['anatomy_ids']: + entry_values['anatomy_ids'].append(anatomy_id) objects_by_domain = {} for index, anatomy_id in enumerate(entry_values['anatomy_ids']): objects_by_domain[anatomy_id] = { @@ -225,34 +229,47 @@ def relative_download(file, resolve): relative_download(self.project.file, '') data = self.load_data(create=False) + print(f'Downloading files for {len(data)} subjects...') + i = 0 + total_progress_steps = len(data) + 9 # 9 download mappings to evaluate + print_progress_bar(i, total_progress_steps) download_mappings: Dict[str, List] = { - 'mesh': [{'set': list(self.project.dataset.meshes), 'attr': 'file'}], - 'segmentation': [{'set': list(self.project.dataset.segmentations), 'attr': 'file'}], - 'contour': [{'set': list(self.project.dataset.contours), 'attr': 'file'}], - 'image': [{'set': list(self.project.dataset.images), 'attr': 'file'}], + 'mesh': [{'set': self.project.dataset.meshes, 'attr': 'file'}], + 'segmentation': [{'set': self.project.dataset.segmentations, 'attr': 'file'}], + 'contour': [{'set': self.project.dataset.contours, 'attr': 'file'}], + 'image': [{'set': self.project.dataset.images, 'attr': 'file'}], 'groomed': [ - {'set': list(self.project.groomed_meshes), 'attr': 'file'}, - {'set': list(self.project.groomed_segmentations), 'attr': 'file'}, + {'set': self.project.groomed_meshes, 'attr': 'file'}, + {'set': self.project.groomed_segmentations, 'attr': 'file'}, ], - 'local': [{'set': list(self.project.particles), 'attr': 'local'}], - 'world': [{'set': list(self.project.particles), 'attr': 'world'}], - 'landmarks': [{'set': list(self.project.dataset.landmarks), 'attr': 'file'}], - 'constraints': [{'set': list(self.project.dataset.constraints), 'attr': 'file'}], + 'local': [{'set': self.project.particles, 'attr': 'local'}], + 'world': [{'set': self.project.particles, 'attr': 'world'}], + 'landmarks': [{'set': self.project.dataset.landmarks, 'attr': 'file'}], + 'constraints': [{'set': self.project.dataset.constraints, 'attr': 'file'}], } + download_mappings_evaluated = {} + for k, v in download_mappings.items(): + i += 1 + print_progress_bar(i, total_progress_steps) + download_mappings_evaluated[k] = [dict(s, **{'set': list(s['set'])}) for s in v] + for [_s, objects_by_domain] in data: + i += 1 for _a, objects in objects_by_domain.items(): for key, value in objects.items(): if key == 'shape': key = shape_file_type(Path(value)).__name__.lower() match_name = Path(value).name - if key in download_mappings: - for mapping in 
download_mappings[key]: + if key in download_mappings_evaluated: + for mapping in download_mappings_evaluated[key]: for x in mapping['set']: attr = getattr(x, mapping['attr']) if str(attr) == match_name: relative_download(attr, value) + print_progress_bar(i, total_progress_steps) + print() def load_analysis_from_json(self, file_path): project_root = Path(str(self.project.file.path)).parent diff --git a/swcc/swcc/models/utils.py b/swcc/swcc/models/utils.py index c1106042..44f45044 100644 --- a/swcc/swcc/models/utils.py +++ b/swcc/swcc/models/utils.py @@ -130,3 +130,20 @@ def get_config_value(filename: str, key: str) -> Optional[Any]: return config.get(key) return None + + +# Progress Bar Printing Function +# stackoverflow.com/questions/3173320/text-progress-bar-in-terminal-with-block-characters +def print_progress_bar( + iteration, + total, + prefix='Progress:', + suffix='Complete.', + decimals=1, + length=30, + fill='█', +): + percent = ('{0:.' + str(decimals) + 'f}').format(100 * (iteration / float(total))) + filled_length = int(length * iteration // total) + bar = fill * filled_length + '-' * (length - filled_length) + print(f'\r{prefix} |{bar}| {percent}% {suffix}', end='\r') diff --git a/web/shapeworks/src/components/ShapeViewer.vue b/web/shapeworks/src/components/ShapeViewer.vue index 7d827c61..3bfb47c9 100644 --- a/web/shapeworks/src/components/ShapeViewer.vue +++ b/web/shapeworks/src/components/ShapeViewer.vue @@ -561,7 +561,11 @@ export default { this.addShapes(renderer, label, shapes.map(({shape}) => shape)); const points = shapes.map(({points}) => points) - if(points.length > 0 && points[0].getNumberOfPoints() > 0) this.addPoints(renderer, points[0]); + points.forEach((pointSet) => { + if(pointSet.getNumberOfPoints() > 0) { + this.addPoints(renderer, pointSet); + } + }) const camera = vtkCamera.newInstance(); renderer.setActiveCamera(camera);
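
The two sketches below are not part of the patch; they are minimal, hand-written
illustrations of the behaviour the patch introduces, written against the code shown
above, with every assumed name called out in the comments.

First, the progress-bar helper added to swcc/swcc/models/utils.py is driven the same
way in both the upload and download paths: draw the empty bar once, redraw it after
each unit of work, then print a newline to move past the carriage-return-terminated
bar. A standalone usage sketch, assuming the swcc package is installed so the helper
is importable under the path below, and using time.sleep as a stand-in for the real
per-subject work:

    import time

    # Assumed import path, based on the swcc/swcc/models/utils.py layout above.
    from swcc.models.utils import print_progress_bar

    subjects = ['subject_1', 'subject_2', 'subject_3']  # placeholder work items
    total_progress_steps = len(subjects)

    print_progress_bar(0, total_progress_steps)      # draw the empty bar first
    for i, subject in enumerate(subjects, start=1):
        time.sleep(0.5)                              # stand-in for the real upload
        print_progress_bar(i, total_progress_steps)  # redraw after each step
    print()                                          # move past the '\r'-terminated bar

Second, the reworked key handling in interpret_data skips 'name' keys and strips
'_particles' before deriving an anatomy id, so particle columns no longer create
extra anatomies. A self-contained illustration; the entry keys, file names and the
three prefixes are made up, and prefix_list stands in for expected_key_prefixes,
which is defined elsewhere in project.py:

    entry = {
        'name': 'subject_1',
        'mesh_femur': 'subject_1_femur.vtk',
        'local_particles_femur': 'subject_1_femur_local.particles',
        'world_particles_femur': 'subject_1_femur_world.particles',
    }
    prefix_list = ['mesh', 'local', 'world']  # stand-in for expected_key_prefixes

    anatomy_ids = []
    for key in entry:
        if key == 'name':
            continue  # name keys no longer produce an anatomy id
        prefixes = [p for p in prefix_list if key.startswith(p)]
        if prefixes:
            # stripping '_particles' maps particle keys onto the same domain
            anatomy_id = 'anatomy' + key.replace(prefixes[0], '').replace('_particles', '')
            if anatomy_id not in anatomy_ids:
                anatomy_ids.append(anatomy_id)

    print(anatomy_ids)  # ['anatomy_femur'] -- one domain, not three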