Skip to content

Commit

Permalink
fix: use more generic types in Target
Browse files Browse the repository at this point in the history
  • Loading branch information
tysmith committed Feb 20, 2025
1 parent c3703e4 commit ffc9284
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 8 deletions.
4 changes: 2 additions & 2 deletions src/grizzly/target/puppet_target.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
from .target_monitor import TargetMonitor

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping

from sapphire import CertificateBundle

Expand Down Expand Up @@ -337,7 +337,7 @@ def log_size(self) -> int:
total += length
return total

def merge_environment(self, extra: dict[str, str]) -> None:
def merge_environment(self, extra: Mapping[str, str]) -> None:
output = dict(extra)
if self.environ:
# prioritize existing environment variables
Expand Down
12 changes: 6 additions & 6 deletions src/grizzly/target/target.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
from .assets import AssetManager

if TYPE_CHECKING:
from collections.abc import Iterable
from collections.abc import Iterable, Mapping
from pathlib import Path

from ..common.report import Report
Expand Down Expand Up @@ -88,7 +88,7 @@ def __init__(
self._lock = Lock()
self.binary = binary
self.certs = certs
self.environ = self.scan_environment(dict(environ), self.TRACKED_ENVVARS)
self.environ = self.scan_environment(environ, self.TRACKED_ENVVARS)
self.launch_timeout = launch_timeout
self.log_limit = log_limit
self.memory_limit = memory_limit
Expand Down Expand Up @@ -268,7 +268,7 @@ def log_size(self) -> int:
"""

@abstractmethod
def merge_environment(self, extra: dict[str, str]) -> None:
def merge_environment(self, extra: Mapping[str, str]) -> None:
"""Add to existing environment.
Args:
Expand Down Expand Up @@ -314,8 +314,8 @@ def reverse(self, remote: int, local: int) -> None:

@staticmethod
def scan_environment(
env: dict[str, str],
include: tuple[str, ...] | None,
env: Mapping[str, str],
include: Iterable[str],
) -> dict[str, str]:
"""Scan environment for tracked environment variables.
Expand All @@ -326,7 +326,7 @@ def scan_environment(
Returns:
Tracked variables found in scanned environment.
"""
return {var: env[var] for var in include if var in env} if include else {}
return {var: env[var] for var in include if var in env}

@abstractmethod
def save_logs(self, dst: Path) -> None:
Expand Down

0 comments on commit ffc9284

Please sign in to comment.