diff --git a/src/aleph_client/commands/instance/__init__.py b/src/aleph_client/commands/instance/__init__.py index 6ff88142..4e252e05 100644 --- a/src/aleph_client/commands/instance/__init__.py +++ b/src/aleph_client/commands/instance/__init__.py @@ -335,6 +335,16 @@ async def create( selector=True, ), ) + if (tier.compute_units > 4 or confidential) and payment_type == PaymentType.hold: + console.print("VM with more than 4 Compute unit and/or confidential can't run using HOLD.", style="red") + if payment_chain in super_token_chains: + payment_type = PaymentType.superfluid + console.print("Switching payment type to PAY-As-You-Go (superfluid).", style="green") + else: + console.print("The current chain is not compatible with PAYG. Aborting instance creation.", style="red") + console.print(f"Compatible Chain : {super_token_chains}") + raise typer.Exit(code=1) + name = name or validated_prompt("Instance name", lambda x: x and len(x) < 65) vcpus = tier.vcpus memory = tier.memory diff --git a/src/aleph_client/commands/pricing.py b/src/aleph_client/commands/pricing.py index dd491784..f0f521ee 100644 --- a/src/aleph_client/commands/pricing.py +++ b/src/aleph_client/commands/pricing.py @@ -211,9 +211,13 @@ def display_table_for( if "vram" in tier: row.append(f"{tier['vram'] / 1024:.0f}") if "holding" in price_unit: - row.append( - f"{displayable_amount(Decimal(price_unit['holding']) * current_units, decimals=3)} tokens" - ) + # If the pricing entity is confidential, display "Not Available" + if pricing_entity == PricingEntity.INSTANCE_CONFIDENTIAL: + row.append("Not Available") + else: + row.append( + f"{displayable_amount(Decimal(price_unit['holding']) * current_units, decimals=3)} tokens" + ) if "payg" in price_unit and pricing_entity in PAYG_GROUP: payg_hourly = Decimal(price_unit["payg"]) * current_units row.append( diff --git a/src/aleph_client/commands/program.py b/src/aleph_client/commands/program.py index 715d8d6d..b748275b 100644 --- a/src/aleph_client/commands/program.py +++ b/src/aleph_client/commands/program.py @@ -182,6 +182,7 @@ async def upload( verbose=verbose, ), ) + name = name or validated_prompt("Program name", lambda x: x and len(x) < 65) vcpus = tier.vcpus memory = tier.memory diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index 1652eece..28b857db 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -7,6 +7,7 @@ from unittest.mock import AsyncMock, MagicMock, patch import pytest +import typer from aiohttp import InvalidURL from aleph_message.models import Chain, ItemHash from aleph_message.models.execution.base import Payment, PaymentType @@ -379,8 +380,11 @@ def create_mock_vm_coco_client(): "coco_hold_evm", "coco_superfluid_evm", "gpu_superfluid_evm", + "large_vm_hold_evm", + "large_vm_hold_avax", + "large_vm_hold_sol", ], - argnames="args, expected", + argnames="args, expected, should_raise", argvalues=[ ( # regular_hold_evm { @@ -392,6 +396,7 @@ def create_mock_vm_coco_client(): "immutable_volume": [f"mount=/opt/packages,ref={FAKE_STORE_HASH}"], }, (FAKE_VM_HASH, None, "ETH"), + False, ), ( # regular_superfluid_evm { @@ -401,6 +406,7 @@ def create_mock_vm_coco_client(): "crn_url": FAKE_CRN_URL, }, (FAKE_VM_HASH, FAKE_CRN_URL, "AVAX"), + False, ), ( # regular_hold_sol { @@ -409,6 +415,7 @@ def create_mock_vm_coco_client(): "rootfs": "debian12", }, (FAKE_VM_HASH, None, "SOL"), + False, ), ( # coco_hold_sol { @@ -418,7 +425,8 @@ def create_mock_vm_coco_client(): "crn_url": FAKE_CRN_URL, "confidential": True, }, - (FAKE_VM_HASH, FAKE_CRN_URL, "SOL"), + None, + True, ), ( # coco_hold_evm { @@ -428,7 +436,8 @@ def create_mock_vm_coco_client(): "crn_url": FAKE_CRN_URL, "confidential": True, }, - (FAKE_VM_HASH, FAKE_CRN_URL, "ETH"), + None, + True, ), ( # coco_superfluid_evm { @@ -439,6 +448,7 @@ def create_mock_vm_coco_client(): "confidential": True, }, (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), + False, ), ( # gpu_superfluid_evm { @@ -449,11 +459,42 @@ def create_mock_vm_coco_client(): "gpu": True, }, (FAKE_VM_HASH, FAKE_CRN_URL, "BASE"), + False, + ), + ( # large_vm_hold_evm - over 4 CU with HOLD on ETH should fail + { + "payment_type": "hold", + "payment_chain": "ETH", + "rootfs": "debian12", + "compute_units": 5, + }, + None, + True, + ), + ( # large_vm_hold_avax - over 4 CU with HOLD on AVAX should switch to superfluid + { + "payment_type": "hold", + "payment_chain": "AVAX", + "rootfs": "debian12", + "compute_units": 5, + }, + (FAKE_VM_HASH, None, "AVAX"), # The chain remains AVAX but payment_type gets switched + False, + ), + ( # large_vm_hold_sol - over 4 CU with HOLD on SOL (no superfluid) should fail + { + "payment_type": "hold", + "payment_chain": "SOL", + "rootfs": "debian12", + "compute_units": 5, + }, + None, + True, ), ], ) @pytest.mark.asyncio -async def test_create_instance(args, expected): +async def test_create_instance(args, expected, should_raise): mock_validate_ssh_pubkey_file = create_mock_validate_ssh_pubkey_file() mock_load_account = create_mock_load_account() mock_account = mock_load_account.return_value @@ -496,28 +537,36 @@ async def create_instance(instance_spec): all_args.update(instance_spec) return await create(**all_args) - returned = await create_instance(args) - # Basic assertions for all cases - mock_load_account.assert_called_once() - mock_validate_ssh_pubkey_file.return_value.read_text.assert_called_once() - mock_client.get_estimated_price.assert_called_once() - mock_auth_client.create_instance.assert_called_once() - # Payment type specific assertions - if args["payment_type"] == "hold": - mock_get_balance.assert_called_once() - elif args["payment_type"] == "superfluid": - assert mock_account.manage_flow.call_count == 2 - assert mock_wait_for_confirmed_flow.call_count == 2 - # CRN related assertions - if args["payment_type"] == "superfluid" or args.get("confidential") or args.get("gpu"): - mock_fetch_latest_crn_version.assert_called() - if not args.get("gpu"): - mock_fetch_crn_info.assert_called_once() - else: - mock_crn_table.return_value.run_async.assert_called_once() - mock_wait_for_processed_instance.assert_called_once() - mock_vm_client.start_instance.assert_called_once() - assert returned == expected + if should_raise: + with pytest.raises(typer.Exit) as exc_info: + returned = await create_instance(args) + mock_load_account.assert_called_once() + mock_validate_ssh_pubkey_file.return_value.read_text.assert_called_once() + + assert exc_info.value.exit_code == 1 + else: + returned = await create_instance(args) + # Basic assertions for all cases + mock_load_account.assert_called_once() + mock_validate_ssh_pubkey_file.return_value.read_text.assert_called_once() + mock_client.get_estimated_price.assert_called_once() + mock_auth_client.create_instance.assert_called_once() + # Payment type specific assertions + if args["payment_type"] == "hold": + mock_get_balance.assert_called_once() + elif args["payment_type"] == "superfluid": + assert mock_account.manage_flow.call_count == 2 + assert mock_wait_for_confirmed_flow.call_count == 2 + # CRN related assertions + if args["payment_type"] == "superfluid" or args.get("confidential") or args.get("gpu"): + mock_fetch_latest_crn_version.assert_called() + if not args.get("gpu"): + mock_fetch_crn_info.assert_called_once() + else: + mock_crn_table.return_value.run_async.assert_called_once() + mock_wait_for_processed_instance.assert_called_once() + mock_vm_client.start_instance.assert_called_once() + assert returned == expected @pytest.mark.asyncio