diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 45b3831..911bc50 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -10,21 +10,21 @@ repos: - id: isort - repo: https://github.com/psf/black - rev: 24.4.2 + rev: 24.8.0 hooks: - id: black name: black - repo: https://github.com/pycqa/flake8 - rev: 7.0.0 + rev: 7.1.1 hooks: - id: flake8 - repo: https://github.com/pre-commit/mirrors-mypy - rev: v1.10.0 + rev: v1.11.1 hooks: - id: mypy - additional_dependencies: [types-requests, types-setuptools, pydantic, types-pkg-resources] + additional_dependencies: [types-requests, types-setuptools, pydantic] - repo: https://github.com/executablebooks/mdformat rev: 0.7.17 diff --git a/ape_solidity/_models.py b/ape_solidity/_models.py new file mode 100644 index 0000000..74a780c --- /dev/null +++ b/ape_solidity/_models.py @@ -0,0 +1,421 @@ +import os +from collections.abc import Iterable +from functools import singledispatchmethod +from pathlib import Path +from typing import TYPE_CHECKING, Optional + +from ape.exceptions import CompilerError, ProjectError +from ape.managers import ProjectManager +from ape.utils.basemodel import BaseModel, ManagerAccessMixin, classproperty +from ape.utils.os import get_relative_path +from pydantic import field_serializer + +from ape_solidity._utils import get_single_import_lines + +if TYPE_CHECKING: + from ape_solidity.compiler import SolidityCompiler + + +class ApeSolidityMixin(ManagerAccessMixin): + @classproperty + def solidity(cls) -> "SolidityCompiler": + return cls.compiler_manager.solidity + + +class ApeSolidityModel(BaseModel, ApeSolidityMixin): + pass + + +def _create_import_remapping(project: ProjectManager) -> dict[str, str]: + prefix = f"{get_relative_path(project.contracts_folder, project.path)}" + specified = project.dependencies.install() + + # Ensure .cache folder is ready-to-go. 
+ cache_folder = project.contracts_folder / ".cache" + cache_folder.mkdir(exist_ok=True, parents=True) + + # Start with explicitly configured remappings. + cfg_remappings: dict[str, str] = { + m.key: m.value for m in project.config.solidity.import_remapping + } + key_map: dict[str, str] = {} + + def get_cache_id(dep) -> str: + return os.path.sep.join((prefix, ".cache", dep.name, dep.version)) + + def unpack(dep): + # Ensure the dependency is installed. + try: + dep.project + except ProjectError: + # Try to compile anyway. + # Let the compiler fail on its own. + return + + for unpacked_dep in dep.unpack(project.contracts_folder / ".cache"): + main_key = key_map.get(unpacked_dep.name) + keys = (main_key,) if main_key else (f"@{unpacked_dep.name}", unpacked_dep.name) + for _key in keys: + if _key not in remapping: + remapping[_key] = get_cache_id(unpacked_dep) + # else, was specified or configured more appropriately. + + remapping: dict[str, str] = {} + for key, value in cfg_remappings.items(): + # Check if legacy-style and still accept it. + parts = value.split(os.path.sep) + name = parts[0] + _version = None + if len(parts) > 2: + # Clearly, not pointing at a dependency. + remapping[key] = value + continue + + elif len(parts) == 2: + _version = parts[1] + + if _version is None: + matching_deps = [d for d in project.dependencies.installed if d.name == name] + if len(matching_deps) == 1: + _version = matching_deps[0].version + else: + # Not obvious if it is pointing at one of these dependencies. + remapping[key] = value + continue + + # Dependency found. Map to it using the provider key. + dependency = project.dependencies.get_dependency(name, _version) + key_map[dependency.name] = key + unpack(dependency) + + # Add auto-remapped dependencies. + # (Meaning, the dependencies are specified but their remappings + # are not, so we auto-generate default ones). 
+ for dependency in specified: + unpack(dependency) + + return remapping + + +class ImportRemappingCache(ApeSolidityMixin): + def __init__(self): + # Cache project paths to import remapping. + self._cache: dict[str, dict[str, str]] = {} + + def __getitem__(self, project: ProjectManager) -> dict[str, str]: + if remapping := self._cache.get(f"{project.path}"): + return remapping + + return self.add_project(project) + + def add_project(self, project: ProjectManager) -> dict[str, str]: + remapping = _create_import_remapping(project) + return self.add(project, remapping) + + def add(self, project: ProjectManager, remapping: dict[str, str]): + self._cache[f"{project.path}"] = remapping + return remapping + + @classmethod + def get_import_remapping(cls, project: ProjectManager): + return _create_import_remapping(project) + + +class ImportStatementMetadata(ApeSolidityModel): + quote_char: str + sep_char: str + raw_value: str + + # Only set when remappings are involved. + import_remap_key: Optional[str] = None + import_remap_value: Optional[str] = None + + # Only set when import-remapping resolves to a dependency. + dependency_name: Optional[str] = None + dependency_version: Optional[str] = None + + # Set once a source-file is located. This happens _after_ + # dependency related properties. 
+ source_id: Optional[str] = None + path: Optional[Path] = None + + @property + def value(self) -> str: + if self.import_remap_key and self.import_remap_value: + return self.raw_value.replace(self.import_remap_key, self.import_remap_value) + + return self.raw_value + + @property + def dependency(self) -> Optional[ProjectManager]: + if name := self.dependency_name: + if version := self.dependency_version: + return self.local_project.dependencies[name][version] + + return None + + @classmethod + def parse_line( + cls, + value: str, + reference: Path, + project: ProjectManager, + dependency: Optional[ProjectManager] = None, + ) -> "ImportStatementMetadata": + quote = '"' if '"' in value else "'" + sep = "\\" if "\\" in value else "/" + + try: + end_index = value.index(quote) + 1 + except ValueError as err: + raise CompilerError( + f"Error parsing import statement '{value}' in '{reference.name}'." + ) from err + + import_str_prefix = value[end_index:] + value = import_str_prefix[: import_str_prefix.index(quote)] + result = cls(quote_char=quote, sep_char=sep, raw_value=value) + result._resolve_source(reference, project, dependency=dependency) + return result + + def __repr__(self) -> str: + return self.raw_value + + def __hash__(self) -> int: + path = self.path or Path(self.raw_value) + return hash(path) + + def _resolve_source( + self, reference: Path, project: ProjectManager, dependency: Optional[ProjectManager] = None + ): + if not self._resolve_dependency(project, dependency=dependency): + # Handle non-dependencies. + self._resolve_import_remapping(project) + self._resolve_path(reference, project) + + def _resolve_import_remapping(self, project: ProjectManager): + import_remapping = self.solidity._import_remapping_cache[project] + + # Get all matches. 
+ valid_matches: list[tuple[str, str]] = [] + for check_remap_key, check_remap_value in import_remapping.items(): + if check_remap_key not in self.value: + continue + + valid_matches.append((check_remap_key, check_remap_value)) + + if valid_matches: + self.import_remap_key, self.import_remap_value = max( + valid_matches, key=lambda x: len(x[0]) + ) + + def _resolve_path(self, reference: Path, project: ProjectManager): + base_path = None + if self.value.startswith("."): + base_path = reference.parent + elif (project.path / self.value).is_file(): + base_path = project.path + elif (project.contracts_folder / self.value).is_file(): + base_path = project.contracts_folder + elif self.import_remap_key is not None and self.import_remap_key.startswith("@"): + nm = self.import_remap_key[1:] + for cfg_dep in project.config.dependencies: + if ( + cfg_dep.get("name") == nm + and "project" in cfg_dep + and (Path(cfg_dep["project"]) / self.value).is_file() + ): + base_path = Path(cfg_dep["project"]) + + if base := base_path: + self.path = (base / self.value).resolve().absolute() + self.source_id = f"{get_relative_path(self.path, project.path)}" + + def _resolve_dependency( + self, project: ProjectManager, dependency: Optional[ProjectManager] = None + ) -> bool: + config_project = dependency or project + # NOTE: Dependency is set if we are getting dependencies of dependencies. + # It is tricky because we still need the base (local) project along + # with project defining this dependency, for separate pieces of data. + # (need base project for relative .cache folder location and need dependency + # for configuration). 
+ import_remapping = self.solidity._import_remapping_cache[config_project] + parts = self.value.split(self.sep_char) + pot_dep_names = {parts[0], parts[0].lstrip("@"), f"@{parts[0].lstrip('@')}"} + matches = [] + for nm in pot_dep_names: + if nm not in import_remapping or nm not in self.value: + continue + + matches.append(nm) + + if not matches: + return False + + name = max(matches, key=lambda x: len(x)) + resolved_import = import_remapping[name] + resolved_path_parts = resolved_import.split(self.sep_char) + if ".cache" not in resolved_path_parts: + # Not a dependency + return False + + cache_index = resolved_path_parts.index(".cache") + nm_index = cache_index + 1 + version_index = nm_index + 1 + + if version_index >= len(resolved_path_parts): + # Not sure. + return False + + cache_folder_name = resolved_path_parts[nm_index] + cache_folder_version = resolved_path_parts[version_index] + dependency_project = config_project.dependencies[cache_folder_name][cache_folder_version] + if not dependency_project: + return False + + self.import_remap_key = name + self.import_remap_value = resolved_import + self.dependency_name = dependency_project.name + self.dependency_version = dependency_project.version + path = project.path / self.value + if path.is_file(): + self.source_id = self.value + self.path = project.path / self.source_id + else: + contracts_dir = dependency_project.contracts_folder + dep_path = dependency_project.path + contracts_folder_name = f"{get_relative_path(contracts_dir, dep_path)}" + prefix_pth = dep_path / contracts_folder_name + start_idx = version_index + 1 + suffix = self.sep_char.join(self.value.split(self.sep_char)[start_idx:]) + new_path = prefix_pth / suffix + if not new_path.is_file(): + # No further resolution required (but still is a resolved dependency). 
return True + + adjusted_base_path = ( + f"{self.sep_char.join(resolved_path_parts[:4])}" + f"{self.sep_char}{contracts_folder_name}" + ) + adjusted_src_id = f"{adjusted_base_path}{self.sep_char}{suffix}" + + # Also, correct import remappings now, since it didn't work. + if key := self.import_remap_key: + # Base path will now include the missing contracts name. + self.solidity._import_remapping_cache[project][key] = adjusted_base_path + self.import_remap_value = adjusted_base_path + + self.path = project.path / adjusted_src_id + self.source_id = adjusted_src_id + + return True + + +class SourceTree(ApeSolidityModel): + """ + A model representing a source-tree, meaning given a sequence + of base-sources, this is the tree of each of them with their imports. + """ + + import_statements: dict[tuple[Path, str], set[ImportStatementMetadata]] = {} + """ + Mapping of each file to its import-statements. + """ + + @field_serializer("import_statements") + def _serialize_import_statements(self, statements, info): + imports_by_source_id = {k[1]: v for k, v in statements.items()} + keys = sorted(imports_by_source_id.keys()) + return { + k: sorted(list({i.source_id for i in imports_by_source_id[k] if i.source_id})) + for k in keys + } + + @classmethod + def from_source_files( + cls, + source_files: Iterable[Path], + project: ProjectManager, + statements: Optional[dict[tuple[Path, str], set[ImportStatementMetadata]]] = None, + dependency: Optional[ProjectManager] = None, + ) -> "SourceTree": + statements = statements or {} + for path in source_files: + key = (path, f"{get_relative_path(path.absolute(), project.path)}") + if key in statements: + # We have already captured all of the imports from the file. 
+ continue + + statements[key] = set() + for line in get_single_import_lines(path): + node_data = ImportStatementMetadata.parse_line( + line, path, project, dependency=dependency + ) + statements[key].add(node_data) + if sub_path := node_data.path: + sub_source_id = f"{get_relative_path(sub_path.absolute(), project.path)}" + sub_key = (sub_path, sub_source_id) + + if sub_key in statements: + sub_statements = statements[sub_key] + else: + sub_tree = SourceTree.from_source_files( + (sub_path,), + project, + statements=statements, + dependency=node_data.dependency, + ) + statements = {**statements, **sub_tree.import_statements} + sub_statements = statements[sub_key] + + for sub_stmt in sub_statements: + statements[key].add(sub_stmt) + + return cls(import_statements=statements) + + @singledispatchmethod + def __getitem__(self, key) -> set[ImportStatementMetadata]: + return set() + + @__getitem__.register + def __getitem_path(self, path: Path) -> set[ImportStatementMetadata]: + return next((v for k, v in self.import_statements.items() if k[0] == path), set()) + + @__getitem__.register + def __getitem_str(self, source_id: str) -> set[ImportStatementMetadata]: + return next((v for k, v in self.import_statements.items() if k[1] == source_id), set()) + + @singledispatchmethod + def __contains__(self, value) -> bool: + return False + + @__contains__.register + def __contains_path(self, path: Path) -> bool: + return any(x[0] == path for x in self.import_statements) + + @__contains__.register + def __contains_str(self, source_id: str) -> bool: + return any(x[1] == source_id for x in self.import_statements) + + @__contains__.register + def __contains_tuple(self, key: tuple) -> bool: + return key in self.import_statements + + def __repr__(self) -> str: + key_str = ", ".join([f"{k[1]}={v}" for k, v in self.import_statements.items() if v]) + return f"" + + def get_imported_paths(self, path: Path) -> set[Path]: + return {x.path for x in self[path] if x.path} + + def 
get_remappings_used(self, paths: Iterable[Path]) -> dict[str, str]: + remappings = {} + for path in paths: + for metadata in self[path]: + if not metadata.import_remap_key or not metadata.import_remap_value: + continue + + remappings[metadata.import_remap_key] = metadata.import_remap_value + + return remappings diff --git a/ape_solidity/_utils.py b/ape_solidity/_utils.py index fe3fdc3..d9ca03d 100644 --- a/ape_solidity/_utils.py +++ b/ape_solidity/_utils.py @@ -30,32 +30,36 @@ class Extension(Enum): def get_import_lines(source_paths: Iterable[Path]) -> dict[Path, list[str]]: imports_dict: dict[Path, list[str]] = {} for filepath in source_paths: - import_set = set() - if not filepath or not filepath.is_file(): - continue + imports_dict[filepath] = get_single_import_lines(filepath) + + return imports_dict - source_lines = filepath.read_text().splitlines() - num_lines = len(source_lines) - for line_number, ln in enumerate(source_lines): - if not ln.startswith("import"): - continue - import_str = ln - second_line_number = line_number - while ";" not in import_str: - second_line_number += 1 - if second_line_number >= num_lines: - raise CompilerError("Import statement missing semicolon.") +def get_single_import_lines(source_path: Path) -> list[str]: + import_set = set() + if not source_path.is_file(): + return [] - next_line = source_lines[second_line_number] - import_str += f" {next_line.strip()}" + source_lines = source_path.read_text(encoding="utf8").splitlines() + num_lines = len(source_lines) + for line_number, ln in enumerate(source_lines): + if not ln.startswith("import"): + continue - import_set.add(import_str) - line_number += 1 + import_str = ln + second_line_number = line_number + while ";" not in import_str: + second_line_number += 1 + if second_line_number >= num_lines: + raise CompilerError("Import statement missing semicolon.") - imports_dict[filepath] = list(import_set) + next_line = source_lines[second_line_number] + import_str += f" {next_line.strip()}" 
- return imports_dict + import_set.add(import_str) + line_number += 1 + + return list(import_set) def get_pragma_spec_from_path(source_file_path: Union[Path, str]) -> Optional[SpecifierSet]: @@ -72,7 +76,7 @@ def get_pragma_spec_from_path(source_file_path: Union[Path, str]) -> Optional[Sp if not path.is_file(): return None - source_str = path.read_text() + source_str = path.read_text(encoding="utf8") return get_pragma_spec_from_str(source_str) diff --git a/ape_solidity/compiler.py b/ape_solidity/compiler.py index da01c7b..97fee09 100644 --- a/ape_solidity/compiler.py +++ b/ape_solidity/compiler.py @@ -1,13 +1,12 @@ -import os import re from collections import defaultdict -from collections.abc import Iterable, Iterator +from collections.abc import Iterable, Iterator, Sequence from pathlib import Path from typing import Any, Optional, Union from ape.api import CompilerAPI, PluginConfig from ape.contracts import ContractInstance -from ape.exceptions import CompilerError, ConfigError, ContractLogicError, ProjectError +from ape.exceptions import CompilerError, ConfigError, ContractLogicError from ape.logging import logger from ape.managers.project import LocalProject, ProjectManager from ape.types import AddressType, ContractType @@ -30,11 +29,11 @@ from solcx.exceptions import SolcError from solcx.install import get_executable +from ape_solidity._models import ImportRemappingCache, SourceTree from ape_solidity._utils import ( OUTPUT_SELECTION, Extension, add_commit_hash, - get_import_lines, get_pragma_spec_from_path, get_pragma_spec_from_str, get_versions_can_use, @@ -143,7 +142,7 @@ class SolidityConfig(PluginConfig): def _get_flattened_source(path: Path, name: Optional[str] = None) -> str: name = name or path.name result = f"// File: {name}\n" - result += f"{path.read_text().rstrip()}\n" + result += f"{path.read_text(encoding='utf8').rstrip()}\n" return result @@ -200,6 +199,10 @@ def latest_installed_version(self) -> Optional[Version]: """ return 
_try_max(self.installed_versions) + @cached_property + def _import_remapping_cache(self) -> ImportRemappingCache: + return ImportRemappingCache() + def _get_configured_version( self, project: Optional[ProjectManager] = None ) -> Optional[Version]: @@ -263,7 +266,7 @@ def add_library(self, *contracts: ContractInstance, project: Optional[ProjectMan pm.update_manifest(contract_types=all_types) def get_versions(self, all_paths: Iterable[Path]) -> set[str]: - _validate_can_compile(all_paths) + all_paths = _validate_can_compile(all_paths) versions = set() for path in all_paths: # Make sure we have the compiler available to compile this @@ -284,105 +287,42 @@ def get_import_remapping(self, project: Optional[ProjectManager] = None) -> dict e.g. `".cache/openzeppelin/4.4.2". """ pm = project or self.local_project - prefix = f"{get_relative_path(pm.contracts_folder, pm.path)}" - - specified = pm.dependencies.install() - - # Ensure .cache folder is ready-to-go. - cache_folder = pm.contracts_folder / ".cache" - cache_folder.mkdir(exist_ok=True, parents=True) - - # Start with explicitly configured remappings. - cfg_remappings: dict[str, str] = { - m.key: m.value for m in pm.config.solidity.import_remapping - } - key_map: dict[str, str] = {} - - def get_cache_id(dep) -> str: - return os.path.sep.join((prefix, ".cache", dep.name, dep.version)) - - def unpack(dep): - # Ensure the dependency is installed. - try: - dep.project - except ProjectError: - # Try to compile anyway. - # Let the compiler fail on its own. - return - - for unpacked_dep in dep.unpack(pm.contracts_folder / ".cache"): - main_key = key_map.get(unpacked_dep.name) - keys = (main_key,) if main_key else (f"@{unpacked_dep.name}", unpacked_dep.name) - for _key in keys: - if _key not in remapping: - remapping[_key] = get_cache_id(unpacked_dep) - # else, was specified or configured more appropriately. 
- - remapping: dict[str, str] = {} - for key, value in cfg_remappings.items(): - # Check if legacy-style and still accept it. - parts = value.split(os.path.sep) - name = parts[0] - _version = None - if len(parts) > 2: - # Clearly, not pointing at a dependency. - remapping[key] = value - continue - - elif len(parts) == 2: - _version = parts[1] - - if _version is None: - matching_deps = [d for d in pm.dependencies.installed if d.name == name] - if len(matching_deps) == 1: - _version = matching_deps[0].version - else: - # Not obvious if it is pointing at one of these dependencies. - remapping[key] = value - continue - - # Dependency found. Map to it using the provider key. - dependency = pm.dependencies.get_dependency(name, _version) - key_map[dependency.name] = key - unpack(dependency) - - # Add auto-remapped dependencies. - # (Meaning, the dependencies are specified but their remappings - # are not, so we auto-generate default ones). - for dependency in specified: - unpack(dependency) - + # Always get a fresh remapping when calling the top-level method. + remapping = self._import_remapping_cache.get_import_remapping(pm) + # Cache, so all lower-level methods don't have to recalculate. 
+ self._import_remapping_cache.add(pm, remapping) return remapping def get_compiler_settings( self, contract_filepaths: Iterable[Path], project: Optional[ProjectManager] = None, **kwargs ) -> dict[Version, dict]: pm = project or self.local_project - _validate_can_compile(contract_filepaths) - remapping = self.get_import_remapping(project=pm) - imports = self.get_imports_from_remapping(contract_filepaths, remapping, project=pm) - return self._get_settings_from_imports(contract_filepaths, imports, remapping, project=pm) + paths = _validate_can_compile(contract_filepaths) + imports = SourceTree.from_source_files(paths, pm) + return self._get_settings_from_imports(paths, imports, project=pm, **kwargs) def _get_settings_from_imports( self, contract_filepaths: Iterable[Path], - import_map: dict[str, list[str]], - remappings: dict[str, str], + import_tree: SourceTree, project: Optional[ProjectManager] = None, + **kwargs, ): pm = project or self.local_project files_by_solc_version = self.get_version_map_from_imports( - contract_filepaths, import_map, project=pm + contract_filepaths, import_tree, project=pm ) return self._get_settings_from_version_map( - files_by_solc_version, remappings, import_map=import_map, project=pm + files_by_solc_version, + import_tree=import_tree, + project=pm, + **kwargs, ) def _get_settings_from_version_map( self, - version_map: dict, - import_remappings: dict[str, str], - import_map: Optional[dict[str, list[str]]] = None, + version_map: dict[Version, set[Path]], + import_tree: SourceTree, project: Optional[ProjectManager] = None, **kwargs, ) -> dict[Version, dict]: @@ -401,9 +341,7 @@ def _get_settings_from_version_map( }, **kwargs, } - if remappings_used := self._get_used_remappings( - sources, import_remappings, import_map=import_map, project=pm - ): + if remappings_used := import_tree.get_remappings_used(sources): remappings_str = [f"{k}={v}" for k, v in remappings_used.items()] # Standard JSON input requires remappings to be sorted. 
@@ -423,40 +361,6 @@ def _get_settings_from_version_map( return settings - def _get_used_remappings( - self, - sources: Iterable[Path], - remappings: dict[str, str], - import_map: Optional[dict[str, list[str]]] = None, - project: Optional[ProjectManager] = None, - ) -> dict[str, str]: - pm = project or self.local_project - if not remappings: - # No remappings used at all. - return {} - - cache_path = ( - f"{get_relative_path(pm.contracts_folder.absolute(), pm.path)}{os.path.sep}.cache" - ) - - # Filter out unused import remapping. - result = {} - sources = list(sources) - import_map = import_map or self.get_imports(sources, project=pm) - imports = import_map.values() - - for source_list in imports: - for src in source_list: - if not src.startswith(cache_path): - continue - - parent_key = os.path.sep.join(src.split(os.path.sep)[:3]) - for k, v in remappings.items(): - if parent_key in v: - result[k] = v - - return result - def get_standard_input_json( self, contract_filepaths: Iterable[Path], @@ -465,24 +369,22 @@ def get_standard_input_json( ) -> dict[Version, dict]: pm = project or self.local_project paths = list(contract_filepaths) # Handle if given generator= - remapping = self.get_import_remapping(project=pm) - import_map = self.get_imports_from_remapping(paths, remapping, project=pm) - version_map = self.get_version_map_from_imports(paths, import_map, project=pm) + import_tree = SourceTree.from_source_files(paths, pm) + version_map = self.get_version_map_from_imports(paths, import_tree, project=pm) return self.get_standard_input_json_from_version_map( - version_map, remapping, project=pm, import_map=import_map, **overrides + version_map, project=pm, import_tree=import_tree, **overrides ) def get_standard_input_json_from_version_map( self, version_map: dict[Version, set[Path]], - import_remapping: dict[str, str], - import_map: Optional[dict[str, list[str]]] = None, + import_tree: SourceTree, project: Optional[ProjectManager] = None, **overrides, ): pm = project 
or self.local_project settings = self._get_settings_from_version_map( - version_map, import_remapping, import_map=import_map, project=pm, **overrides + version_map, import_tree, project=pm, **overrides ) return self.get_standard_input_json_from_settings(settings, version_map, project=pm) @@ -506,6 +408,10 @@ def get_standard_input_json_from_settings( if solc_version >= Version("0.6.9"): arguments["base_path"] = pm.path + vers_settings["outputSelection"] = { + k: v for k, v in vers_settings["outputSelection"].items() if (pm.path / k).is_file() + } + if missing_sources := [ x for x in vers_settings["outputSelection"] if not (pm.path / x).is_file() ]: @@ -521,14 +427,14 @@ def get_standard_input_json_from_settings( "in an ape-config.yaml or using the `ape pm install` command." ) - # Otherwise, we are missing project-level source files for some reason. - # This would only happen if the user passes in unexpected files outside - # of core. - missing_src_str = ", ".join(missing_sources) - raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.") + # Otherwise, we are missing project-level source files for some reason. + # This would only happen if the user passes in unexpected files outside + # of core. 
+ missing_src_str = ", ".join(missing_sources) + raise CompilerError(f"Sources '{missing_src_str}' not found in '{pm.name}'.") sources = { - x: {"content": (pm.path / x).read_text()} + x: {"content": (pm.path / x).read_text(encoding="utf8")} for x in sorted(vers_settings["outputSelection"]) } @@ -567,15 +473,13 @@ def _compile( settings: Optional[dict] = None, ): pm = project or self.local_project - remapping = self.get_import_remapping(project=pm) paths = list(contract_filepaths) # Handle if given generator= - import_map = self.get_imports_from_remapping(paths, remapping, project=pm) - version_map = self.get_version_map_from_imports(paths, import_map, project=pm) + import_tree = SourceTree.from_source_files(paths, pm) + version_map = self.get_version_map_from_imports(paths, import_tree, project=pm) input_jsons = self.get_standard_input_json_from_version_map( version_map, - remapping, + import_tree, project=pm, - import_map=import_map, **(settings or {}), ) contract_versions: dict[str, Version] = {} @@ -752,87 +656,9 @@ def get_imports( project: Optional[ProjectManager] = None, ) -> dict[str, list[str]]: pm = project or self.local_project - remapping = self.get_import_remapping(project=pm) - _validate_can_compile(contract_filepaths) - paths = [x for x in contract_filepaths] # Handle if given generator. 
- return self.get_imports_from_remapping(paths, remapping, project=pm) - - def get_imports_from_remapping( - self, - paths: Iterable[Path], - remapping: dict[str, str], - project: Optional[ProjectManager] = None, - ) -> dict[str, list[str]]: - pm = project or self.local_project - return self._get_imports(paths, remapping, pm, tracked=set()) # type: ignore - - def _get_imports( - self, - paths: Iterable[Path], - remapping: dict[str, str], - pm: "ProjectManager", - tracked: set[str], - include_raw: bool = False, - ) -> dict[str, Union[dict[str, str], list[str]]]: - result: dict = {} - - for src_path, import_strs in get_import_lines(paths).items(): - source_id = str(get_relative_path(src_path, pm.path)) - if source_id in tracked: - # We have already accumulated imports from this source. - continue - - tracked.add(source_id) - - # Init with all top-level imports. - import_map = { - x: self._import_str_to_source_id(x, src_path, remapping, project=pm) - for x in import_strs - } - import_source_ids = list(set(list(import_map.values()))) - - # NOTE: Add entry even if empty here. - result[source_id] = import_map if include_raw else import_source_ids - - # Add imports of imports. - if not result[source_id]: - # Nothing else imported. - continue - - # Add known imports. - known_imports = {p: result[p] for p in import_source_ids if p in result} - imp_paths = [pm.path / p for p in import_source_ids if p not in result] - unknown_imports = self._get_imports( - imp_paths, - remapping, - pm, - tracked=tracked, - include_raw=include_raw, - ) - sub_imports = {**known_imports, **unknown_imports} - - # All imported sources from imported sources are imported sources. - for sub_set in sub_imports.values(): - if isinstance(sub_set, dict): - for import_str, sub_import in sub_set.items(): - result[source_id][import_str] = sub_import - - else: - for sub_import in sub_set: - if sub_import not in result[source_id]: - result[source_id].append(sub_import) - - # Keep sorted. 
- if include_raw: - result[source_id] = sorted((result[source_id]), key=lambda x: x[1]) - else: - result[source_id] = sorted((result[source_id])) - - # Combine results. This ends up like a tree-structure. - result = {**result, **sub_imports} - - # Sort final keys and import lists for more predictable compiler behavior. - return {k: result[k] for k in sorted(result.keys())} + paths = _validate_can_compile(contract_filepaths) + tree = SourceTree.from_source_files(paths, pm) + return tree.model_dump(mode="json")["import_statements"] def get_version_map( self, @@ -846,15 +672,15 @@ def get_version_map( else [p for p in contract_filepaths] ) _validate_can_compile(paths) - imports = self.get_imports(paths, project=pm) - return self.get_version_map_from_imports(paths, imports, project=pm) + import_tree = SourceTree.from_source_files(paths, pm) + return self.get_version_map_from_imports(paths, import_tree, project=pm) def get_version_map_from_imports( self, contract_filepaths: Union[Path, Iterable[Path]], - import_map: dict[str, list[str]], + import_tree: SourceTree, project: Optional[ProjectManager] = None, - ): + ) -> dict[Version, set[Path]]: pm = project or self.local_project paths = ( [contract_filepaths] @@ -865,11 +691,10 @@ def get_version_map_from_imports( # Add imported source files to list of contracts to compile. 
for source_path in paths: - source_id = f"{get_relative_path(source_path, pm.path)}" - if source_id not in import_map or len(import_map[source_id]) == 0: + if source_path not in import_tree or len(import_tree[source_path]) == 0: continue - import_set = {pm.path / src_id for src_id in import_map[source_id]} + import_set = import_tree.get_imported_paths(source_path) path_set = path_set.union(import_set) # Use specified version if given one @@ -893,9 +718,7 @@ def get_version_map_from_imports( files_by_solc_version: dict[Version, set[Path]] = {} for source_file_path in path_set: solc_version = self._get_best_version(source_file_path, pragma_map) - imported_source_paths = self._get_imported_source_paths( - source_file_path, pm.path, import_map - ) + imported_source_paths = import_tree.get_imported_paths(source_file_path) for imported_source_path in imported_source_paths: if imported_source_path not in pragma_map: @@ -961,31 +784,6 @@ def get_version_map_from_imports( # is more predictable. Also, remove any lingering empties. 
return {k: result[k] for k in sorted(result) if result[k]} - def _get_imported_source_paths( - self, - path: Path, - base_path: Path, - imports: dict, - source_ids_checked: Optional[list[str]] = None, - ) -> set[Path]: - source_ids_checked = source_ids_checked or [] - source_identifier = str(get_relative_path(path, base_path)) - if source_identifier in source_ids_checked: - # Already got this source's imports - return set() - - source_ids_checked.append(source_identifier) - import_file_paths = [base_path / i for i in imports.get(source_identifier, []) if i] - return_set = {i for i in import_file_paths} - for import_path in import_file_paths: - indirect_imports = self._get_imported_source_paths( - import_path, base_path, imports, source_ids_checked=source_ids_checked - ) - for indirect_import in indirect_imports: - return_set.add(indirect_import) - - return return_set - def _get_pramga_spec_from_str(self, source_str: str) -> Optional[SpecifierSet]: if not (pragma_spec := get_pragma_spec_from_str(source_str)): return None @@ -1104,37 +902,31 @@ def enrich_error(self, err: ContractLogicError) -> ContractLogicError: def _flatten_source( self, path: Union[Path, str], + import_tree: SourceTree, project: Optional[ProjectManager] = None, raw_import_name: Optional[str] = None, handled: Optional[set[str]] = None, ) -> str: pm = project or self.local_project handled = handled or set() - path = Path(path) source_id = f"{get_relative_path(path, pm.path)}" if path.is_absolute() else f"{path}" - handled.add(source_id) - remapping = self.get_import_remapping(project=project) - imports = self._get_imports((path,), remapping, pm, tracked=set(), include_raw=True) - relevant_imports = imports.get(source_id, {}) + relevant_imports = sorted(list(import_tree[path]), key=lambda x: x.raw_value) final_source = "" - - # type-ignore note: we know it is a dict because of `include_raw=True`. 
- import_items = relevant_imports.items() # type: ignore - - import_iter = sorted(import_items, key=lambda x: f"{x[1]}{x[0]}") - for import_str, source_id in import_iter: - if source_id in handled: + for import_metadata in relevant_imports: + if import_metadata.source_id in handled: + continue + elif not (sub_path := import_metadata.path): continue - sub_import_name = import_str.replace("import ", "").strip(" \n\t;\"'") sub_source = self._flatten_source( - pm.path / source_id, + sub_path, + import_tree, project=pm, - raw_import_name=sub_import_name, handled=handled, + raw_import_name=import_metadata.raw_value, ) final_source += sub_source @@ -1149,11 +941,13 @@ def _flatten_source( def flatten_contract( self, path: Path, project: Optional[ProjectManager] = None, **kwargs ) -> Content: - res = self._flatten_source(path, project=project) + pm = project or self.local_project + tree = SourceTree.from_source_files((path,), pm) + res = self._flatten_source(path, tree, project=pm) res = remove_imports(res) res = process_licenses(res) res = remove_version_pragmas(res) - pragma = get_first_version_pragma(path.read_text()) + pragma = get_first_version_pragma(path.read_text(encoding="utf8")) res = "\n".join([pragma, res]) # Simple auto-format. 
@@ -1168,7 +962,6 @@ def _import_str_to_source_id( self, _import_str: str, source_path: Path, - import_remapping: dict[str, str], project: Optional[ProjectManager] = None, ) -> str: pm = project or self.local_project @@ -1189,6 +982,7 @@ def _import_str_to_source_id( valid_matches: list[tuple[str, str]] = [] import_remap_key = None base_path = None + import_remapping = self._import_remapping_cache[pm] for check_remap_key, check_remap_value in import_remapping.items(): if check_remap_key not in import_str_value: continue @@ -1395,9 +1189,15 @@ def _try_max(ls: list[Any]): return max(ls) if ls else None -def _validate_can_compile(paths: Iterable[Path]): - extensions = {get_full_extension(p): p for p in paths if p} +def _validate_can_compile(paths: Iterable[Path]) -> Sequence[Path]: + path_ls = [] + valid_extensions = [e.value for e in Extension] + + for path in paths: + ext = get_full_extension(path) + if ext not in valid_extensions: + raise CompilerError(f"Unable to compile '{path.name}' using Solidity compiler.") + + path_ls.append(path) - for ext, file in extensions.items(): - if ext not in [e.value for e in Extension]: - raise CompilerError(f"Unable to compile '{file.name}' using Solidity compiler.") + return path_ls diff --git a/setup.py b/setup.py index 9e6540b..7d2b832 100644 --- a/setup.py +++ b/setup.py @@ -12,12 +12,11 @@ "pytest-benchmark", # For performance tests ], "lint": [ - "black>=24.4.2,<25", # Auto-formatter and linter - "mypy>=1.10.0,<2", # Static type analyzer + "black>=24.8.0,<25", # Auto-formatter and linter + "mypy>=1.11.1,<2", # Static type analyzer "types-requests", # Needed for mypy type shed "types-setuptools", # Needed for mypy type shed - "types-pkg-resources", # Needed for type checking tests - "flake8>=7.0.0,<8", # Style linter + "flake8>=7.1.1,<8", # Style linter "isort>=5.13.2,<6", # Import sorting linter "mdformat>=0.7.17", # Auto-formatter for markdown "mdformat-gfm>=0.3.5", # Needed for formatting GitHub-flavored markdown diff 
--git a/tests/test_cli.py b/tests/test_cli.py index 1e719ad..495cbe9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -8,6 +8,67 @@ pragma solidity ^0.8.4; // SPDX-License-Identifier: MIT +// File: ./././././././././././././././././././././././././././././././././././MissingPragma.sol + +contract MissingPragma { + function foo() pure public returns(bool) { + return true; + } +} +// File: ./NumerousDefinitions.sol + +struct Struct0 { + string name; + uint value; +} + +struct Struct1 { + string name; + uint value; +} + +struct Struct2 { + string name; + uint value; +} + +struct Struct3 { + string name; + uint value; +} + +struct Struct4 { + string name; + uint value; +} + +struct Struct5 { + string name; + uint value; +} + +contract NumerousDefinitions { + function foo() pure public returns(bool) { + return true; + } +} +// File: ./Source.extra.ext.sol + +// Showing sources with extra extensions are by default excluded, +// unless used as an import somewhere in a non-excluded source. 
+contract SourceExtraExt { + function foo() pure public returns(bool) { + return true; + } +} +// File: ./subfolder/Relativecontract.sol + +contract Relativecontract { + + function foo() pure public returns(bool) { + return true; + } +} // File: @browniedependency/contracts/BrownieContract.sol contract CompilingContract { @@ -23,7 +84,7 @@ } } -// File: @dependency/contracts/Dependency.sol" as Depend2 +// File: @dependency/contracts/Dependency.sol struct DependencyStruct { string name; @@ -56,7 +117,7 @@ contract Enum { enum Operation {Call, DelegateCall} } -// File: { MyStruct } from "contracts/CompilesOnce.sol +// File: contracts/CompilesOnce.sol struct MyStruct { string name; @@ -72,67 +133,6 @@ return true; } } -// File: ./././././././././././././././././././././././././././././././././././MissingPragma.sol - -contract MissingPragma { - function foo() pure public returns(bool) { - return true; - } -} -// File: { Struct0, Struct1, Struct2, Struct3, Struct4, Struct5 } from "./NumerousDefinitions.sol - -struct Struct0 { - string name; - uint value; -} - -struct Struct1 { - string name; - uint value; -} - -struct Struct2 { - string name; - uint value; -} - -struct Struct3 { - string name; - uint value; -} - -struct Struct4 { - string name; - uint value; -} - -struct Struct5 { - string name; - uint value; -} - -contract NumerousDefinitions { - function foo() pure public returns(bool) { - return true; - } -} -// File: ./Source.extra.ext.sol - -// Showing sources with extra extensions are by default excluded, -// unless used as an import somewhere in a non-excluded source. 
-contract SourceExtraExt { - function foo() pure public returns(bool) { - return true; - } -} -// File: ./subfolder/Relativecontract.sol - -contract Relativecontract { - - function foo() pure public returns(bool) { - return true; - } -} // File: Imports.sol @@ -152,11 +152,12 @@ def test_cli_flatten(project, cli_runner): - path = project.contracts_folder / "Imports.sol" + filename = "Imports.sol" + path = project.contracts_folder / filename arguments = ["flatten", str(path)] end = ("--project", str(project.path)) with create_tempdir() as tmpdir: - file = tmpdir / "Imports.sol" + file = tmpdir / filename arguments.extend([str(file), *end]) result = cli_runner.invoke(cli, arguments, catch_exceptions=False) assert result.exit_code == 0, result.stderr_bytes diff --git a/tests/test_compiler.py b/tests/test_compiler.py index bf79a76..7ae4713 100644 --- a/tests/test_compiler.py +++ b/tests/test_compiler.py @@ -67,6 +67,9 @@ def test_get_import_remapping_handles_config(project, compiler): # Show other dependency-deduced remappings still work. assert actual["@browniedependency"] == "contracts/.cache/browniedependency/local" + # Clear remapping, to return to regular config values. 
+ compiler.__dict__.pop("_import_remapping_cache", None) + def test_get_imports(project, compiler): source_id = "contracts/ImportSourceWithEqualSignVersion.sol" @@ -101,7 +104,7 @@ def test_get_imports_indirect(project, compiler): assert source_id in actual actual_str = ", ".join(list(actual[source_id])) for ex in expected: - assert ex in actual[source_id], f"{ex} not in {actual_str}" + assert ex in actual[source_id], f"{ex} WAS NOT found in {actual_str}" def test_get_imports_complex(project, compiler): @@ -146,14 +149,21 @@ def test_get_imports_dependencies(project, compiler): path = project.sources.lookup(source_id) import_ls = compiler.get_imports((path,), project=project) actual = import_ls[source_id] - token_path = "contracts/.cache/openzeppelin/4.5.0/contracts/token" + + # NOTE: Both Yearn-vaults master branch and yearn-vaults 0.4.5 + # use OpenZeppelin 4.7.1. However, the root project for these + # tests uses OpenZeppelin 4.5.0. This proves we are handling + # dependencies-of-dependencies correctly. 
+ + token_path = "contracts/.cache/openzeppelin/4.7.1/contracts/token" expected = [ f"{token_path}/ERC20/ERC20.sol", f"{token_path}/ERC20/IERC20.sol", f"{token_path}/ERC20/extensions/IERC20Metadata.sol", + f"{token_path}/ERC20/extensions/draft-IERC20Permit.sol", f"{token_path}/ERC20/utils/SafeERC20.sol", - "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Address.sol", - "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Context.sol", + "contracts/.cache/openzeppelin/4.7.1/contracts/utils/Address.sol", + "contracts/.cache/openzeppelin/4.7.1/contracts/utils/Context.sol", "contracts/.cache/vault/v0.4.5/contracts/BaseStrategy.sol", "contracts/.cache/vaultmain/master/contracts/BaseStrategy.sol", ] @@ -269,14 +279,15 @@ def test_get_version_map_dependencies(project, compiler): older = versions[0] # Via ImportOlderDependency latest = versions[1] # via UseYearn - oz_token = "contracts/.cache/openzeppelin/4.5.0/contracts/token" + oz_token = "contracts/.cache/openzeppelin/4.7.1/contracts/token" expected_latest_source_ids = [ f"{oz_token}/ERC20/ERC20.sol", f"{oz_token}/ERC20/IERC20.sol", f"{oz_token}/ERC20/extensions/IERC20Metadata.sol", + f"{oz_token}/ERC20/extensions/draft-IERC20Permit.sol", f"{oz_token}/ERC20/utils/SafeERC20.sol", - "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Address.sol", - "contracts/.cache/openzeppelin/4.5.0/contracts/utils/Context.sol", + "contracts/.cache/openzeppelin/4.7.1/contracts/utils/Address.sol", + "contracts/.cache/openzeppelin/4.7.1/contracts/utils/Context.sol", "contracts/.cache/vault/v0.4.5/contracts/BaseStrategy.sol", "contracts/.cache/vaultmain/master/contracts/BaseStrategy.sol", source_id, @@ -375,18 +386,15 @@ def test_get_compiler_settings(project, compiler): assert settings["optimizer"] == {"enabled": True, "runs": 190} # NOTE: These should be sorted! 
-    assert settings["remappings"] == [
+    expected_remapping = [
         "@browniedependency=contracts/.cache/browniedependency/local",
         "@dependency=contracts/.cache/dependency/local",
         "@dependencyofdependency=contracts/.cache/dependencyofdependency/local",
         # This remapping below was auto-corrected because imports were excluding contracts/ suffix.
         "@noncompilingdependency=contracts/.cache/noncompilingdependency/local/contracts",
         "@safe=contracts/.cache/safe/1.3.0",
-        "browniedependency=contracts/.cache/browniedependency/local",
-        "dependency=contracts/.cache/dependency/local",
-        "dependencyofdependency=contracts/.cache/dependencyofdependency/local",
-        "safe=contracts/.cache/safe/1.3.0",
     ]
+    assert settings["remappings"] == expected_remapping
 
     # Set in config.
     assert settings["evmVersion"] == "constantinople"
@@ -482,9 +490,16 @@
         args=((path,),),
         kwargs={"project": project},
         rounds=1,
+        warmup_rounds=1,
     )
     assert len(result) > 0
 
+    # Currently seeing '~0.08' on macOS locally.
+    # Was seeing '~0.68' before https://github.com/ApeWorX/ape-solidity/pull/151
+    threshold = 0.2
+
+    assert benchmark.stats["median"] < threshold
+
 
 def test_compile_multiple_definitions_in_source(project, compiler):
     """
@@ -711,7 +726,7 @@ def test_flatten(mocker, project, compiler):
     flattened_source_path = base_expected / "ImportingLessConstrainedVersionFlat.sol"
     actual = str(flattened_source)
-    expected = str(flattened_source_path.read_text())
+    expected = str(flattened_source_path.read_text(encoding="utf8"))
     assert actual == expected