
Commit

Merge branch 'main' into issue_XyData_documentation
ayushjariyal authored Feb 17, 2025
2 parents 37abafa + c4dfada commit 36daf22
Showing 13 changed files with 91 additions and 159 deletions.
2 changes: 1 addition & 1 deletion .docker/requirements.txt
@@ -1,4 +1,4 @@
docker~=7.0.0
pytest~=8.2.0
requests~=2.32.0
pytest-docker~=3.1.0
pytest-docker~=3.2.0
4 changes: 2 additions & 2 deletions .docker/tests/test_aiida.py
@@ -6,7 +6,7 @@


def test_correct_python_version_installed(aiida_exec, python_version):
info = json.loads(aiida_exec('mamba list --json --full-name python').decode())[0]
info = json.loads(aiida_exec('mamba list --json --full-name python', ignore_stderr=True).decode())[0]
assert info['name'] == 'python'
assert parse(info['version']) == parse(python_version)

@@ -15,7 +15,7 @@ def test_correct_pgsql_version_installed(aiida_exec, pgsql_version, variant):
if variant == 'aiida-core-base':
pytest.skip('PostgreSQL is not installed in the base image')

info = json.loads(aiida_exec('mamba list --json --full-name postgresql').decode())[0]
info = json.loads(aiida_exec('mamba list --json --full-name postgresql', ignore_stderr=True).decode())[0]
assert info['name'] == 'postgresql'
assert parse(info['version']).major == parse(pgsql_version).major

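Passing `ignore_stderr=True` keeps any warnings that mamba emits on stderr out of the captured output, which would otherwise break the `json.loads` call. A minimal sketch of the idea, using a hypothetical stand-in for the `aiida_exec` fixture rather than its real implementation:

```python
import json
import subprocess


def exec_json(command: str) -> list:
    """Run a command and parse its stdout as JSON, discarding stderr.

    Hypothetical helper: by capturing stderr separately, warnings printed by
    mamba cannot corrupt the JSON document on stdout.
    """
    result = subprocess.run(
        command,
        shell=True,
        check=True,
        stdout=subprocess.PIPE,
        stderr=subprocess.DEVNULL,
    )
    return json.loads(result.stdout.decode())


# e.g. info = exec_json('mamba list --json --full-name python')[0]
```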
2 changes: 0 additions & 2 deletions .github/workflows/ci-code.yml
@@ -126,8 +126,6 @@ jobs:
with:
python-version: '3.12'
from-lock: 'true'
# NOTE: The `verdi devel check-undesired-imports` fails if
# the 'tui' extra is installed.
extras: ''

- name: Run verdi tests
4 changes: 4 additions & 0 deletions .github/workflows/docker-build.yml
@@ -41,6 +41,10 @@ jobs:
- name: Set up QEMU
if: ${{ inputs.platforms != 'linux/amd64' }}
uses: docker/setup-qemu-action@v3
with:
# Workaround for https://github.com/tonistiigi/binfmt/issues/215
image: tonistiigi/binfmt:qemu-v7.0.0-28
cache-image: true

- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v3
12 changes: 8 additions & 4 deletions src/aiida/cmdline/commands/cmd_devel.py
@@ -61,12 +61,11 @@ def devel_check_load_time():
def devel_check_undesired_imports():
"""Check that verdi does not import python modules it shouldn't.
Note: The blacklist was taken from the list of packages in the 'atomic_tools' extra but can be extended.
This is to keep the verdi CLI snappy, especially for tab-completion.
"""
loaded_modules = 0

for modulename in [
'asyncio',
unwanted_modules = [
'requests',
'plumpy',
'disk_objectstore',
@@ -78,7 +77,12 @@ def devel_check_undesired_imports():
'spglib',
'pymysql',
'yaml',
]:
]
# trogon powers the optional TUI and uses asyncio.
# Check for asyncio only when the optional 'tui' extra is not installed.
if 'trogon' not in sys.modules:
unwanted_modules += ['asyncio']
for modulename in unwanted_modules:
if modulename in sys.modules:
echo.echo_warning(f'Detected loaded module "{modulename}"')
loaded_modules += 1
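The logic boils down to comparing a blacklist against `sys.modules` after the CLI has started, with `asyncio` only blacklisted when the optional TUI (trogon) is absent. A simplified, self-contained sketch of that check (names are illustrative, not the exact `verdi` implementation):

```python
import sys


def find_undesired_imports(unwanted: list[str]) -> list[str]:
    """Return the blacklisted module names that are already imported."""
    return [name for name in unwanted if name in sys.modules]


blacklist = ['requests', 'plumpy', 'disk_objectstore', 'yaml']
if 'trogon' not in sys.modules:
    # The TUI legitimately needs asyncio, so only flag it when trogon is absent.
    # Note the list literal: `blacklist += 'asyncio'` would append single characters.
    blacklist += ['asyncio']

for name in find_undesired_imports(blacklist):
    print(f'Detected loaded module "{name}"')
```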
14 changes: 0 additions & 14 deletions src/aiida/common/utils.py
@@ -446,20 +446,6 @@ def join_labels(labels, join_symbol='|', threshold=1.0e-6):
return new_labels


def strip_prefix(full_string, prefix):
"""Strip the prefix from the given string and return it. If the prefix is not present
the original string will be returned unaltered
:param full_string: the string from which to remove the prefix
:param prefix: the prefix to remove
:return: the string with prefix removed
"""
if full_string.startswith(prefix):
return full_string.rsplit(prefix)[1]

return full_string


class Capturing:
"""This class captures stdout and returns it
(as a list, split by lines).
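The removed helper is superseded by `str.removeprefix`, available since Python 3.9. Besides being standard, it avoids a subtle pitfall of the `rsplit`-based implementation when the prefix also occurs later in the string. A short comparison, with the old helper reproduced only for illustration:

```python
def strip_prefix(full_string: str, prefix: str) -> str:
    """The removed helper, reproduced here only for the comparison below."""
    if full_string.startswith(prefix):
        return full_string.rsplit(prefix)[1]
    return full_string


# str.removeprefix strips the prefix exactly once, at the start of the string.
assert 'data.core.data.Int.'.removeprefix('data.') == 'core.data.Int.'

# The rsplit-based helper splits on *every* occurrence and keeps one fragment.
assert strip_prefix('data.core.data.Int.', 'data.') == 'core.'
```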
9 changes: 4 additions & 5 deletions src/aiida/orm/utils/node.py
@@ -12,7 +12,6 @@
import warnings

from aiida.common import exceptions
from aiida.common.utils import strip_prefix
from aiida.orm.fields import EntityFieldMeta

__all__ = (
@@ -49,19 +48,19 @@ def load_node_class(type_string):
# This exception needs to be there to make migrations work that rely on the old type string starting with `node.`
# Since now the type strings no longer have that prefix, we simply strip it and continue with the normal logic.
if base_path.startswith('node.'):
base_path = strip_prefix(base_path, 'node.')
base_path = base_path.removeprefix('node.')

# Data nodes are the only ones with sub classes that are still external, so if the plugin is not available
# we fall back on the base node type
if base_path.startswith('data.'):
entry_point_name = strip_prefix(base_path, 'data.')
entry_point_name = base_path.removeprefix('data.')
try:
return load_entry_point('aiida.data', entry_point_name)
except exceptions.MissingEntryPointError:
return Data

if base_path.startswith('process'):
entry_point_name = strip_prefix(base_path, 'nodes.')
entry_point_name = base_path.removeprefix('nodes.')
return load_entry_point('aiida.node', entry_point_name)

# At this point we really have an anomalous type string. At some point, storing nodes with unresolvable type strings
@@ -99,7 +98,7 @@ def get_type_string_from_class(class_module, class_name):

# Sequentially and **in order** strip the prefixes if present
for prefix in prefixes:
type_string = strip_prefix(type_string, prefix)
type_string = type_string.removeprefix(prefix)

# This needs to be here as long as `aiida.orm.nodes.data` does not live in `aiida.orm.nodes.data` because all the
# `Data` instances will have a type string that starts with `data.` instead of `nodes.`, so in order to match any
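Since `str.removeprefix` is a no-op when the prefix is absent, the sequential stripping above needs no `startswith` guard. A small illustration (the prefix values are made up for the example):

```python
prefixes = ('aiida.orm.nodes.', 'aiida.orm.')  # illustrative values only

type_string = 'aiida.orm.nodes.data.int.Int.'
for prefix in prefixes:
    # Applied unconditionally: removeprefix leaves the string untouched
    # when it does not start with the given prefix.
    type_string = type_string.removeprefix(prefix)

assert type_string == 'data.int.Int.'
```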
10 changes: 4 additions & 6 deletions src/aiida/restapi/common/identifiers.py
@@ -135,7 +135,6 @@ def load_entry_point_from_full_type(full_type):
:raises `~aiida.common.exceptions.EntryPointError`: if the corresponding entry point cannot be loaded
"""
from aiida.common import EntryPointError
from aiida.common.utils import strip_prefix
from aiida.plugins.entry_point import is_valid_entry_point_string, load_entry_point, load_entry_point_from_string

data_prefix = 'data.'
@@ -151,7 +150,7 @@
raise EntryPointError(f'could not load entry point `{process_type}`')

elif node_type.startswith(data_prefix):
base_name = strip_prefix(node_type, data_prefix)
base_name = node_type.removeprefix(data_prefix)
entry_point_name = base_name.rsplit('.', 2)[0]

try:
@@ -229,20 +228,19 @@ def __init__(self, namespace, path=None, label=None, full_type=None, counter=Non

def _infer_full_type(self, full_type):
"""Infer the full type based on the current namespace path and the given full type of the leaf."""
from aiida.common.utils import strip_prefix

if full_type or self._path is None:
return full_type

full_type = strip_prefix(self._path, 'node.')
full_type = self._path.removeprefix('node.')

if full_type.startswith('process.'):
for basepath, full_type_template in self.process_full_type_mapping.items():
if full_type.startswith(basepath):
plugin_name = strip_prefix(full_type, basepath)
plugin_name = full_type.removeprefix(basepath)
if plugin_name.startswith(DEFAULT_NAMESPACE_LABEL):
temp_type_template = self.process_full_type_mapping_unplugged[basepath]
plugin_name = strip_prefix(plugin_name, DEFAULT_NAMESPACE_LABEL + '.')
plugin_name = plugin_name.removeprefix(DEFAULT_NAMESPACE_LABEL + '.')
full_type = temp_type_template.format(plugin_name=plugin_name)
else:
full_type = full_type_template.format(plugin_name=plugin_name)
74 changes: 16 additions & 58 deletions src/aiida/storage/psql_dos/orm/groups.py
@@ -167,17 +167,10 @@ def add_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is on, the SQLA ORM is skipped and SQLA is used
to create a direct SQL INSERT statement to the group-node relationship
table (to improve speed).
"""
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.exc import IntegrityError

super().add_nodes(nodes)
skip_orm = kwargs.get('skip_orm', False)

def check_node(given_node):
"""Check if given node is of correct type and stored"""
@@ -188,31 +181,16 @@ def check_node(given_node):
raise ValueError('At least one of the provided nodes is unstored, stopping...')

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes

for node in nodes:
check_node(node)

# Use pattern as suggested here:
# http://docs.sqlalchemy.org/en/latest/orm/session_transaction.html#using-savepoint
try:
with session.begin_nested():
dbnodes.append(node.bare_model)
session.flush()
except IntegrityError:
# Duplicate entry, skip
pass
else:
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
ins_dict = []
for node in nodes:
check_node(node)
ins_dict.append({'dbnode_id': node.id, 'dbgroup_id': self.id})
if len(ins_dict) == 0:
return

table = self.GROUP_NODE_CLASS.__table__
ins = insert(table).values(ins_dict)
session.execute(ins.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))

# Commit everything as up till now we've just flushed
if not session.in_nested_transaction():
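With the ORM branch and the `skip_orm` flag removed, every `add_nodes` call now builds a single PostgreSQL `INSERT ... ON CONFLICT DO NOTHING` against the group-node join table, so re-adding an existing node is silently skipped. A stand-alone sketch of the same SQLAlchemy pattern (table and column names are illustrative, not the real model):

```python
from sqlalchemy import Column, Integer, MetaData, Table
from sqlalchemy.dialects.postgresql import insert

metadata = MetaData()
# Illustrative join table; in AiiDA this is ``self.GROUP_NODE_CLASS.__table__``.
group_nodes = Table(
    'db_dbgroup_dbnodes',
    metadata,
    Column('dbnode_id', Integer, primary_key=True),
    Column('dbgroup_id', Integer, primary_key=True),
)


def add_nodes_bulk(session, group_id: int, node_ids: list[int]) -> None:
    """Insert all group-node links in one statement, skipping duplicates."""
    rows = [{'dbnode_id': pk, 'dbgroup_id': group_id} for pk in node_ids]
    if not rows:
        return
    statement = insert(group_nodes).values(rows)
    session.execute(statement.on_conflict_do_nothing(index_elements=['dbnode_id', 'dbgroup_id']))
```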
Expand All @@ -224,45 +202,25 @@ def remove_nodes(self, nodes, **kwargs):
:note: all the nodes *and* the group itself have to be stored.
:param nodes: a list of `BackendNode` instance to be added to this group
:param kwargs:
skip_orm: When the flag is set to `True`, the SQLA ORM is skipped and SQLA is used to create a direct SQL
DELETE statement to the group-node relationship table in order to improve speed.
"""
from sqlalchemy import and_

super().remove_nodes(nodes)

# Get dbnodes here ONCE, otherwise each call to dbnodes will re-read the current value in the database
dbnodes = self.model.dbnodes
skip_orm = kwargs.get('skip_orm', False)

def check_node(node):
if not isinstance(node, self.NODE_CLASS):
raise TypeError(f'invalid type {type(node)}, has to be {self.NODE_CLASS}')

if node.id is None:
raise ValueError('At least one of the provided nodes is unstored, stopping...')

list_nodes = []

with utils.disable_expire_on_commit(self.backend.get_session()) as session:
if not skip_orm:
for node in nodes:
check_node(node)

# Check first, if SqlA issues a DELETE statement for an unexisting key it will result in an error
if node.bare_model in dbnodes:
list_nodes.append(node.bare_model)

for node in list_nodes:
dbnodes.remove(node)
else:
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)
table = self.GROUP_NODE_CLASS.__table__
for node in nodes:
check_node(node)
clause = and_(table.c.dbnode_id == node.id, table.c.dbgroup_id == self.id)
statement = table.delete().where(clause)
session.execute(statement)

if not session.in_nested_transaction():
session.commit()
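At the ORM level the change is invisible: callers keep using `Group.add_nodes` and `Group.remove_nodes`, which now always go through the bulk code paths shown above. A brief usage sketch (assumes a configured AiiDA profile):

```python
from aiida import load_profile, orm

load_profile()  # assumes a configured profile is available

group = orm.Group(label='example-group').store()
node = orm.Int(5).store()

group.add_nodes([node])     # single INSERT ... ON CONFLICT DO NOTHING
group.add_nodes([node])     # adding it again is a silent no-op
group.remove_nodes([node])  # one DELETE per node against the join table

assert node.pk not in {member.pk for member in group.nodes}
```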
6 changes: 3 additions & 3 deletions src/aiida/transports/plugins/local.py
@@ -866,7 +866,7 @@ def rename(self, oldpath: TransportPath, newpath: TransportPath):
:param str oldpath: existing name of the file or folder
:param str newpath: new name for the file or folder
:raises OSError: if src/dst is not found
:raises OSError: if oldpath is not found or newpath already exists
:raises ValueError: if src/dst is not a valid string
"""
oldpath = str(oldpath)
@@ -877,8 +877,8 @@ def rename(self, oldpath: TransportPath, newpath: TransportPath):
raise ValueError(f'Destination {newpath} is not a valid string')
if not os.path.exists(oldpath):
raise OSError(f'Source {oldpath} does not exist')
if not os.path.exists(newpath):
raise OSError(f'Destination {newpath} does not exist')
if os.path.exists(newpath):
raise OSError(f'Destination {newpath} already exists.')

shutil.move(oldpath, newpath)

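The corrected check makes `rename` refuse to overwrite an existing destination instead of (wrongly) requiring that the destination already exist. A small sketch of the resulting behaviour (assumes the local transport is usable in the current environment):

```python
import tempfile
from pathlib import Path

from aiida.transports.plugins.local import LocalTransport

with tempfile.TemporaryDirectory() as tmp:
    src = Path(tmp) / 'old.txt'
    dst = Path(tmp) / 'new.txt'
    src.write_text('payload')

    with LocalTransport() as transport:
        transport.rename(src, dst)  # succeeds: the destination does not exist yet

        src.write_text('payload again')
        try:
            transport.rename(src, dst)  # fails: the destination already exists
        except OSError as exc:
            print(exc)
```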
63 changes: 0 additions & 63 deletions tests/orm/implementation/test_groups.py
@@ -25,66 +25,3 @@ def test_creation_from_dbgroup(backend):

assert group.pk == gcopy.pk
assert group.uuid == gcopy.uuid


def test_add_nodes_skip_orm():
"""Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag."""
group = orm.Group(label='test_adding_nodes').store().backend_entity

node_01 = orm.Data().store().backend_entity
node_02 = orm.Data().store().backend_entity
node_03 = orm.Data().store().backend_entity
node_04 = orm.Data().store().backend_entity
node_05 = orm.Data().store().backend_entity
nodes = [node_01, node_02, node_03, node_04, node_05]

group.add_nodes([node_01], skip_orm=True)
group.add_nodes([node_02, node_03], skip_orm=True)
group.add_nodes((node_04, node_05), skip_orm=True)

assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Try to add a node that is already present: there should be no problem
group.add_nodes([node_01], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)


def test_add_nodes_skip_orm_batch():
"""Test the `SqlaGroup.add_nodes` method with the `skip_orm=True` flag and batches."""
nodes = [orm.Data().store().backend_entity for _ in range(100)]

# Add nodes to groups using different batch size. Check in the end the correct addition.
batch_sizes = (1, 3, 10, 1000)
for batch_size in batch_sizes:
group = orm.Group(label=f'test_batches_{batch_size!s}').store()
group.backend_entity.add_nodes(nodes, skip_orm=True, batch_size=batch_size)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)


def test_remove_nodes_bulk():
"""Test node removal with `skip_orm=True`."""
group = orm.Group(label='test_removing_nodes').store().backend_entity

node_01 = orm.Data().store().backend_entity
node_02 = orm.Data().store().backend_entity
node_03 = orm.Data().store().backend_entity
node_04 = orm.Data().store().backend_entity
nodes = [node_01, node_02, node_03]

group.add_nodes(nodes)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove a node that is not in the group: nothing should happen
group.remove_nodes([node_04], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove one Node
nodes.remove(node_03)
group.remove_nodes([node_03], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)

# Remove a list of Nodes and check
nodes.remove(node_01)
nodes.remove(node_02)
group.remove_nodes([node_01, node_02], skip_orm=True)
assert set(_.pk for _ in nodes) == set(_.pk for _ in group.nodes)
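The deleted tests only exercised the `skip_orm=True` and `batch_size` code paths, which no longer exist. A hedged sketch of how the surviving guarantees could still be covered, written in the style of the remaining tests (not part of this commit):

```python
def test_add_and_remove_nodes():
    """Adding duplicates is a no-op and removal only affects the given nodes."""
    group = orm.Group(label='test_add_remove').store().backend_entity
    nodes = [orm.Data().store().backend_entity for _ in range(5)]

    group.add_nodes(nodes)
    assert {node.pk for node in nodes} == {node.pk for node in group.nodes}

    # Re-adding an already present node must neither raise nor create duplicates.
    group.add_nodes([nodes[0]])
    assert {node.pk for node in nodes} == {node.pk for node in group.nodes}

    group.remove_nodes([nodes[0]])
    assert {node.pk for node in nodes[1:]} == {node.pk for node in group.nodes}
```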