From f37839040cd653e2ba18fb428a6f4dc8f18ff0be Mon Sep 17 00:00:00 2001 From: antazoey Date: Thu, 6 Jun 2024 15:41:07 -0500 Subject: [PATCH] fix: handle missing contracts folder ID in remappings (#147) --- ape_solidity/compiler.py | 142 ++++++++++++------ .../contracts/subdir/SubCompilingContract.sol | 9 ++ tests/ape-config.yaml | 4 + tests/contracts/Imports.sol | 7 +- tests/test_compiler.py | 33 +++- 5 files changed, 149 insertions(+), 46 deletions(-) create mode 100644 tests/NonCompilingDependency/contracts/subdir/SubCompilingContract.sol diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index 550ed21..3886ac3 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -1,5 +1,6 @@ import os import re +from collections import defaultdict from collections.abc import Iterable, Iterator from pathlib import Path from typing import Any, Optional, Union @@ -8,7 +9,7 @@ from ape.contracts import ContractInstance from ape.exceptions import CompilerError, ConfigError, ContractLogicError, ProjectError from ape.logging import logger -from ape.managers.project import ProjectManager +from ape.managers.project import LocalProject, ProjectManager from ape.types import AddressType, ContractType from ape.utils import cached_property, get_full_extension, get_relative_path from ape.version import version @@ -309,10 +310,12 @@ def unpack(dep): return for unpacked_dep in dep.unpack(pm.contracts_folder / ".cache"): - _key = key_map.get(unpacked_dep.name, f"@{unpacked_dep.name}") - if _key not in remapping: - remapping[_key] = get_cache_id(unpacked_dep) - # else, was specified or configured more appropriately. + main_key = key_map.get(unpacked_dep.name) + keys = (main_key,) if main_key else (f"@{unpacked_dep.name}", unpacked_dep.name) + for _key in keys: + if _key not in remapping: + remapping[_key] = get_cache_id(unpacked_dep) + # else, was specified or configured more appropriately. remapping: dict[str, str] = {} for key, value in cfg_remappings.items(): @@ -514,7 +517,7 @@ def get_standard_input_json_from_settings( # and cater the error message accordingly. if dependencies_needed := [x for x in missing_sources if str(x).startswith("@")]: # Missing dependencies. Should only get here if dependencies are found - # in import-strs but are not installed anywhere (not in project or globally). + # in import-strs but are not installed (not in project or globally). missing_str = ", ".join(dependencies_needed) raise CompilerError( f"Missing required dependencies '{missing_str}'. " @@ -522,11 +525,11 @@ def get_standard_input_json_from_settings( "in an ape-config.yaml or using the `ape pm install` command." ) - # Otherwise, we are missing project-level source files for some reason. - # This would only happen if the user passes in unexpected files outside - # of core. - missing_src_str = ", ".join(missing_sources) - raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.") + # Otherwise, we are missing project-level source files for some reason. + # This would only happen if the user passes in unexpected files outside + # of core. + missing_src_str = ", ".join(missing_sources) + raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.") sources = { x: {"content": (pm.path / x).read_text()} @@ -914,33 +917,42 @@ def get_version_map_from_imports( # If being used in another version AND no imports in this version require it, # remove it from this version. - for solc_version, files in files_by_solc_version.copy().items(): - for file in files.copy(): - used_in_other_version = any( - [file in ls for v, ls in files_by_solc_version.items() if v != solc_version] - ) - if not used_in_other_version: + cleaned_mapped: dict[Version, set[Path]] = defaultdict(set) + for solc_version, files in files_by_solc_version.items(): + other_versions = {v: ls for v, ls in files_by_solc_version.items() if v != solc_version} + for file in files: + other_versions_used_in = {v for v in other_versions if file in other_versions[v]} + if not other_versions_used_in: + # This file is only in 1 version, which is perfect. + cleaned_mapped[solc_version].add(file) continue - other_files = [f for f in files_by_solc_version[solc_version] if f != file] - used_in_imports = False - for other_file in other_files: - source_id = str(get_relative_path(other_file, pm.path)) - import_paths = [pm.path / i for i in import_map.get(source_id, []) if i] - if file in import_paths: - used_in_imports = True - break + # This file is in multiple versions. Attempt to clean. + for other_version in other_versions_used_in: + # Other files that may need this file are any file that is not this file as well + # any file that is not also found the other version. We want to make sure + # before removing this file that it won't be needed. + other_files_that_may_need_this_file = [ + f for f in files if f != file and f not in other_versions[other_version] + ] + if other_files_that_may_need_this_file: + # This file is used by other files in this version, so we must keep it. + cleaned_mapped[solc_version].add(file) + continue - if not used_in_imports: - files_by_solc_version[solc_version].remove(file) - if not files_by_solc_version[solc_version]: - del files_by_solc_version[solc_version] + # Remove other the rest of files. + other_files_can_remove = [ + f for f in files if f != file and f in other_versions[other_version] + ] + for other_file in other_files_can_remove: + if other_file in cleaned_mapped[solc_version]: + cleaned_mapped[solc_version].remove(other_file) - result = {add_commit_hash(v): ls for v, ls in files_by_solc_version.items()} + result = {add_commit_hash(v): ls for v, ls in cleaned_mapped.items()} # Sort, so it is a nicer version map and the rest of the compilation flow - # is more predictable. - return {k: result[k] for k in sorted(result)} + # is more predictable. Also, remove any lingering empties. + return {k: result[k] for k in sorted(result) if result[k]} def _get_imported_source_paths( self, @@ -1137,6 +1149,7 @@ def _import_str_to_source_id( ) -> str: pm = project or self.local_project quote = '"' if '"' in _import_str else "'" + sep = "\\" if "\\" in _import_str else "/" try: end_index = _import_str.index(quote) + 1 @@ -1150,17 +1163,17 @@ def _import_str_to_source_id( # Get all matches. valid_matches: list[tuple[str, str]] = [] - key = None + import_remap_key = None base_path = None - for key, value in import_remapping.items(): - if key not in import_str_value: + for check_remap_key, check_remap_value in import_remapping.items(): + if check_remap_key not in import_str_value: continue - valid_matches.append((key, value)) + valid_matches.append((check_remap_key, check_remap_value)) if valid_matches: - key, value = max(valid_matches, key=lambda x: len(x[0])) - import_str_value = import_str_value.replace(key, value) + import_remap_key, import_remap_value = max(valid_matches, key=lambda x: len(x[0])) + import_str_value = import_str_value.replace(import_remap_key, import_remap_value) if import_str_value.startswith("."): base_path = source_path.parent @@ -1168,8 +1181,8 @@ def _import_str_to_source_id( base_path = pm.path elif (pm.contracts_folder / import_str_value).is_file(): base_path = pm.contracts_folder - elif key is not None and key.startswith("@"): - nm = key[1:] + elif import_remap_key is not None and import_remap_key.startswith("@"): + nm = import_remap_key[1:] for cfg_dep in pm.config.dependencies: if ( cfg_dep.get("name") == nm @@ -1178,8 +1191,53 @@ def _import_str_to_source_id( ): base_path = Path(cfg_dep["project"]) - if base_path is None: - # No base_path, do as-is. + import_str_parts = import_str_value.split(sep) + if base_path is None and ".cache" in import_str_parts: + # No base_path. First, check if the `contracts/` folder is missing, + # which is the case when compiling older Ape projects and some Foundry + # projects as well. + cache_index = import_str_parts.index(".cache") + nm_index = cache_index + 1 + version_index = nm_index + 1 + if version_index >= len(import_str_parts): + # Not sure. + return import_str_value + + cache_folder_name = import_str_parts[nm_index] + cache_folder_version = import_str_parts[version_index] + dm = pm.dependencies + dependency = dm.get_dependency(cache_folder_name, cache_folder_version) + dep_project = dependency.project + + if not isinstance(dep_project, LocalProject): + # TODO: Handle manifest-based projects as well. + # to work with old compiled manifests. + return import_str_value + + contracts_dir = dep_project.contracts_folder + dep_path = dep_project.path + contracts_folder_name = f"{get_relative_path(contracts_dir, dep_path)}" + prefix_pth = dep_path / contracts_folder_name + start_idx = version_index + 1 + suffix = sep.join(import_str_parts[start_idx:]) + new_path = prefix_pth / suffix + + if not new_path.is_file(): + # Maybe this source is actually missing... + return import_str_value + + adjusted_base_path = f"{sep.join(import_str_parts[:4])}{sep}{contracts_folder_name}" + adjusted_src_id = f"{adjusted_base_path}{sep}{suffix}" + + # Also, correct import remappings now, since it didn't work. + if key := import_remap_key: + # Base path will now included the missing contracts name. + import_remapping[key] = adjusted_base_path + + return adjusted_src_id + + elif base_path is None: + # No base_path, return as-is. return import_str_value path = (base_path / import_str_value).resolve() diff --git a/tests/NonCompilingDependency/contracts/subdir/SubCompilingContract.sol b/tests/NonCompilingDependency/contracts/subdir/SubCompilingContract.sol new file mode 100644 index 0000000..07127e1 --- /dev/null +++ b/tests/NonCompilingDependency/contracts/subdir/SubCompilingContract.sol @@ -0,0 +1,9 @@ +// SPDX-License-Identifier: MIT + +pragma solidity ^0.8.4; + +contract SubCompilingContract { + function foo() pure public returns(bool) { + return true; + } +} diff --git a/tests/ape-config.yaml b/tests/ape-config.yaml index 5c22a66..c6dab25 100644 --- a/tests/ape-config.yaml +++ b/tests/ape-config.yaml @@ -28,3 +28,7 @@ dependencies: solidity: # Using evm_version compatible with older and newer solidity versions. evm_version: constantinople + + import_remapping: + # Legacy support test (missing contracts key in import test) + - "@noncompilingdependency=noncompilingdependency" diff --git a/tests/contracts/Imports.sol b/tests/contracts/Imports.sol index 5556066..8687fd1 100644 --- a/tests/contracts/Imports.sol +++ b/tests/contracts/Imports.sol @@ -17,12 +17,15 @@ import { Struct4, Struct5 } from "./NumerousDefinitions.sol"; -import "@noncompilingdependency/contracts/CompilingContract.sol"; +import "@noncompilingdependency/CompilingContract.sol"; // Purposely repeat an import to test how the plugin handles that. -import "@noncompilingdependency/contracts/CompilingContract.sol"; +import "@noncompilingdependency/CompilingContract.sol"; import "@safe/contracts/common/Enum.sol"; +// Purposely exclude the contracts folder to test older Ape-style project imports. +import "@noncompilingdependency/subdir/SubCompilingContract.sol"; + contract Imports { function foo() pure public returns(bool) { return true; diff --git a/tests/test_compiler.py b/tests/test_compiler.py index ae8ac1d..65f8a40 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -121,6 +121,7 @@ def test_get_imports_complex(project, compiler): "contracts/.cache/dependency/local/contracts/Dependency.sol", "contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol", "contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol", + "contracts/.cache/noncompilingdependency/local/contracts/subdir/SubCompilingContract.sol", # noqa: E501 "contracts/.cache/safe/1.3.0/contracts/common/Enum.sol", "contracts/CompilesOnce.sol", "contracts/MissingPragma.sol", @@ -239,7 +240,29 @@ def test_get_version_map_dependencies(project, compiler): actual = compiler.get_version_map(paths, project=project) fail_msg = f"versions: {', '.join([str(x) for x in actual])}" - assert len(actual) == 2, fail_msg + actual_len = len(actual) + + # Expecting one old version for ImportOlderDependency and one version for Yearn stuff. + expected_len = 2 + + if actual_len > expected_len: + # Weird anomaly in CI/CD tests sometimes (at least at the time of write). + # Including additional debug information. + alt_map: dict = {} + for version, src_ids in actual.items(): + for src_id in src_ids: + if src_id in alt_map: + other_version = alt_map[src_id] + versions_str = ", ".join([str(other_version), str(version)]) + pytest.fail(f"{src_id} in multiple version '{versions_str}'") + else: + alt_map[src_id] = version + + # No duplicated versions found but still have unexpected extras. + pytest.fail(f"Unexpected number of versions. {fail_msg}") + + elif actual_len < expected_len: + pytest.fail(fail_msg) versions = sorted(list(actual.keys())) older = versions[0] # Via ImportOlderDependency @@ -355,8 +378,13 @@ def test_get_compiler_settings(project, compiler): "@browniedependency=contracts/.cache/browniedependency/local", "@dependency=contracts/.cache/dependency/local", "@dependencyofdependency=contracts/.cache/dependencyofdependency/local", - "@noncompilingdependency=contracts/.cache/noncompilingdependency/local", + # This remapping below was auto-corrected because imports were excluding contracts/ suffix. + "@noncompilingdependency=contracts/.cache/noncompilingdependency/local/contracts", "@safe=contracts/.cache/safe/1.3.0", + "browniedependency=contracts/.cache/browniedependency/local", + "dependency=contracts/.cache/dependency/local", + "dependencyofdependency=contracts/.cache/dependencyofdependency/local", + "safe=contracts/.cache/safe/1.3.0", ] # Set in config. @@ -369,6 +397,7 @@ def test_get_compiler_settings(project, compiler): "contracts/.cache/dependency/local/contracts/Dependency.sol", "contracts/.cache/dependencyofdependency/local/contracts/DependencyOfDependency.sol", "contracts/.cache/noncompilingdependency/local/contracts/CompilingContract.sol", + "contracts/.cache/noncompilingdependency/local/contracts/subdir/SubCompilingContract.sol", "contracts/.cache/safe/1.3.0/contracts/common/Enum.sol", "contracts/CompilesOnce.sol", "contracts/Imports.sol",