Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 69 additions & 51 deletions metadata-ingestion/src/datahub/masking/masking_filter.py
Original file line number Diff line number Diff line change
Expand Up @@ -452,73 +452,94 @@ def __getattr__(self, name):
return getattr(self._original, name)


def _update_existing_handlers() -> None:
"""Update all existing logging handlers to use wrapped streams."""
updated_count = 0

# Get all loggers (including root and all named loggers)
all_loggers = [logging.getLogger()] + [
logging.getLogger(name) for name in logging.root.manager.loggerDict
]

for log in all_loggers:
if not isinstance(log, logging.Logger):
# Skip PlaceHolder objects in logger dict
def _iter_all_loggers() -> list[logging.Logger]:
"""Root logger plus every initialized named logger (skipping PlaceHolders)."""
loggers: list[logging.Logger] = [logging.getLogger()]
# .get() per snapshotted key: a logger may be removed between snapshot and access.
for name in list(logging.root.manager.loggerDict.keys()):
obj = logging.root.manager.loggerDict.get(name)
if isinstance(obj, logging.Logger):
loggers.append(obj)
return loggers


def _add_filter_to_existing_handlers(masking_filter: SecretMaskingFilter) -> None:
"""Attach the masking filter to all existing handlers.

Masking lives on handlers, not the logger: Python skips logger-level filters
for records propagated from child loggers, so a root-logger filter would miss
almost everything. A handler filter sees every record reaching that output and
masks it in place, without touching the handler's stream. (Repointing streams
instead loops forever under celery -- see install_masking_filter.)

Skip datahub.masking.* loggers: they log to the original stderr by design and
carry no secrets, so filtering them only risks re-entrancy.
"""
added = 0
for log in _iter_all_loggers():
if log.name.startswith("datahub.masking"):
continue
# Copy: handlers may be added/removed by other threads during iteration.
for handler in list(log.handlers):
if not any(isinstance(f, SecretMaskingFilter) for f in handler.filters):
handler.addFilter(masking_filter)
added += 1
if added:
logger.debug(f"Installed SecretMaskingFilter on {added} handler(s)")

for handler in log.handlers:
if isinstance(handler, logging.StreamHandler):
# Check if handler is using an unwrapped stream
if hasattr(handler, "stream"):
stream = handler.stream

# If handler's stream is the original unwrapped stdout/stderr,
# update it to use our wrapped version
if not isinstance(stream, StreamMaskingWrapper):
# Check if this is stdout or stderr by comparing the underlying file
try:
if hasattr(stream, "name"):
if stream.name == "<stderr>":
handler.setStream(sys.stderr)
updated_count += 1
elif stream.name == "<stdout>":
handler.setStream(sys.stdout)
updated_count += 1
except Exception:
# If we can't determine the stream, skip it
pass

if updated_count > 0:
logger.debug(f"Updated {updated_count} logging handlers to use wrapped streams")

def _remove_filter_from_existing_handlers() -> None:
"""Remove the masking filter from every handler it was attached to."""
for log in _iter_all_loggers():
for handler in list(log.handlers):
handler.filters = [
f for f in handler.filters if not isinstance(f, SecretMaskingFilter)
]


def install_masking_filter(
secret_registry: Optional[SecretRegistry] = None,
max_message_size: int = 5000,
install_stdout_wrapper: bool = True,
) -> SecretMaskingFilter:
"""Install secret masking filter on root logger and optionally wrap stdout/stderr."""
# Create filter
"""Enable secret masking: install the filter on existing handlers (+ root
logger) and, optionally, wrap sys.stdout/stderr for raw writes.

Masking happens at the handler level (see _add_filter_to_existing_handlers).
Coverage is a snapshot of the handlers present now, so call this AFTER logging
is configured; handlers added later are covered only by a re-install or, for
stdout/stderr, by the stream wrapper.

Fail-open limitation: a handler added to a child logger after install can emit
unmasked (a logger's own handlers run before ancestors'). Not a concern in the
executor, where handlers exist before masking is installed per task.
"""
masking_filter = SecretMaskingFilter(
secret_registry=secret_registry, max_message_size=max_message_size
)

# Install on root logger (affects all loggers)
root_logger = logging.getLogger()

# Check if already installed (avoid duplicates)
# The root-logger filter is just the "already installed?" sentinel (and masks
# records logged directly on root). The real masking is the handler filters
# below, since logger-level filters don't see propagated child-logger records.
existing_filters = [
f for f in root_logger.filters if isinstance(f, SecretMaskingFilter)
]

if existing_filters:
logger.debug("SecretMaskingFilter already installed on root logger")
return existing_filters[0]
# Already installed: re-scan to cover handlers added since (fail-open).
masking_filter = existing_filters[0]
_add_filter_to_existing_handlers(masking_filter)
logger.debug("SecretMaskingFilter already installed; refreshed handlers")
return masking_filter

root_logger.addFilter(masking_filter)
logger.info("Installed SecretMaskingFilter on root logger")
_add_filter_to_existing_handlers(masking_filter)
logger.info("Installed SecretMaskingFilter on root logger and existing handlers")

# Optionally install stdout/stderr wrapper as backup
# Wrap stdout/stderr only to mask raw writes (print(), C-extension output).
# We do NOT repoint handler streams here: under celery, sys.stderr re-enters
# logging, so a handler pointed at it would recurse infinitely and drop output.
if install_stdout_wrapper:
if not isinstance(sys.stdout, StreamMaskingWrapper):
sys.stdout = StreamMaskingWrapper(sys.stdout, masking_filter)
Expand All @@ -528,22 +549,19 @@ def install_masking_filter(
sys.stderr = StreamMaskingWrapper(sys.stderr, masking_filter)
logger.debug("Wrapped sys.stderr with StreamMaskingWrapper")

# Update all existing logging handlers to use wrapped streams
# Handlers created before masking was initialized will have cached
# references to the original unwrapped stderr/stdout
_update_existing_handlers()

return masking_filter


def uninstall_masking_filter() -> None:
"""Remove secret masking filter from root logger."""
root_logger = logging.getLogger()

# Remove filters
# Remove filter from the root logger and from every handler it was added to
# (symmetric with install_masking_filter).
root_logger.filters = [
f for f in root_logger.filters if not isinstance(f, SecretMaskingFilter)
]
_remove_filter_from_existing_handlers()

# Unwrap stdout/stderr
if isinstance(sys.stdout, StreamMaskingWrapper):
Expand Down
17 changes: 17 additions & 0 deletions metadata-ingestion/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@

# We need our imports to go below the os.environ updates, since mere act
# of importing some datahub modules will load env variables.
from datahub.masking.bootstrap import shutdown_secret_masking # noqa: E402
from datahub.masking.secret_registry import SecretRegistry # noqa: E402
from datahub.testing.pytest_hooks import ( # noqa: F401,E402
load_golden_flags,
pytest_addoption,
Expand All @@ -44,6 +46,21 @@ def mock_time():
yield


@pytest.fixture(autouse=True)
def _reset_secret_masking():
"""Keep process-global secret masking from leaking across tests.

Masking installs a filter on every logging handler and populates a singleton
SecretRegistry. Tests that trigger it (the ingest CLI, initialize_secret_masking,
or SecretStr config validation) don't always tear it down, which would mask
later tests' captured log output. Reset after each test for isolation
(shutdown also clears bootstrap state so a later init re-installs cleanly).
"""
yield
shutdown_secret_masking()
SecretRegistry.reset_instance()


def pytest_ignore_collect(
collection_path: pathlib.Path, config: pytest.Config
) -> Optional[bool]:
Expand Down
119 changes: 67 additions & 52 deletions metadata-ingestion/tests/unit/test_masking_error_recovery.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,9 @@
from datahub.masking.masking_filter import (
SecretMaskingFilter,
StreamMaskingWrapper,
_update_existing_handlers,
_add_filter_to_existing_handlers,
install_masking_filter,
uninstall_masking_filter,
)
from datahub.masking.secret_registry import SecretRegistry

Expand Down Expand Up @@ -279,8 +280,10 @@ def test_wrapper_getattr(self):
assert callable(wrapper.getvalue)


class TestUpdateExistingHandlers:
"""Test _update_existing_handlers function."""
class TestAddFilterToExistingHandlers:
"""Test _add_filter_to_existing_handlers: masking attaches to handlers
without modifying their streams (the celery-safe replacement for the old
stream-redirecting behavior)."""

def setup_method(self):
shutdown_secret_masking()
Expand All @@ -290,79 +293,91 @@ def teardown_method(self):
shutdown_secret_masking()
SecretRegistry.reset_instance()

def test_update_existing_handlers_with_stdout_handler(self):
"""Test that existing stdout handlers are updated."""
# Create a logger with a stdout handler
test_logger = logging.getLogger("test_stdout_update")
def test_filter_added_without_changing_stream(self):
"""The filter is attached to an existing handler and its stream is left
untouched (repointing the stream is what caused the celery deadlock)."""
test_logger = logging.getLogger("test_add_filter_stream")
test_logger.handlers.clear()

stdout_handler = logging.StreamHandler(sys.stdout)
test_logger.addHandler(stdout_handler)
custom_stream = StringIO()
handler = logging.StreamHandler(custom_stream)
test_logger.addHandler(handler)

# Install masking
install_masking_filter(install_stdout_wrapper=True)

# Update handlers
_update_existing_handlers()
assert handler.stream is custom_stream
assert any(isinstance(f, SecretMaskingFilter) for f in handler.filters)

# Cleanup
test_logger.removeHandler(stdout_handler)
test_logger.removeHandler(handler)
test_logger.handlers.clear()

def test_update_existing_handlers_with_stderr_handler(self):
"""Test that existing stderr handlers are updated."""
# Create a logger with a stderr handler
test_logger = logging.getLogger("test_stderr_update")
def test_filter_not_added_twice(self):
"""Calling the helper again must not add a duplicate filter."""
test_logger = logging.getLogger("test_add_filter_idempotent")
test_logger.handlers.clear()
handler = logging.StreamHandler(StringIO())
test_logger.addHandler(handler)

stderr_handler = logging.StreamHandler(sys.stderr)
test_logger.addHandler(stderr_handler)

# Install masking
install_masking_filter(install_stdout_wrapper=True)
masking_filter = install_masking_filter(install_stdout_wrapper=False)
_add_filter_to_existing_handlers(masking_filter)

# Update handlers
_update_existing_handlers()
count = sum(isinstance(f, SecretMaskingFilter) for f in handler.filters)
assert count == 1

# Cleanup
test_logger.removeHandler(stderr_handler)
test_logger.removeHandler(handler)
test_logger.handlers.clear()

def test_update_existing_handlers_with_custom_stream(self):
"""Test that handlers with custom streams are not updated."""
test_logger = logging.getLogger("test_custom_stream")
test_logger.handlers.clear()
def test_masking_namespace_loggers_are_skipped(self):
"""The masking framework's own loggers bypass masking by design."""
masking_logger = logging.getLogger("datahub.masking.test_skip")
masking_logger.handlers.clear()
handler = logging.StreamHandler(StringIO())
masking_logger.addHandler(handler)

custom_stream = StringIO()
custom_handler = logging.StreamHandler(custom_stream)
test_logger.addHandler(custom_handler)
install_masking_filter(install_stdout_wrapper=False)

# Install masking
install_masking_filter(install_stdout_wrapper=True)
assert not any(isinstance(f, SecretMaskingFilter) for f in handler.filters)

# Update handlers
_update_existing_handlers()
masking_logger.removeHandler(handler)
masking_logger.handlers.clear()

# Custom handler should not be updated
assert custom_handler.stream is custom_stream
def test_repeat_install_attaches_to_newly_added_handler(self):
"""A second install must re-scan and cover handlers added after the
first install (masking is fail-open, so missed handlers leak)."""
test_logger = logging.getLogger("test_repeat_install")
test_logger.handlers.clear()
h1 = logging.StreamHandler(StringIO())
test_logger.addHandler(h1)

# Cleanup
test_logger.removeHandler(custom_handler)
install_masking_filter(install_stdout_wrapper=False)
assert any(isinstance(f, SecretMaskingFilter) for f in h1.filters)

# Handler added AFTER the first install.
h2 = logging.StreamHandler(StringIO())
test_logger.addHandler(h2)

install_masking_filter(install_stdout_wrapper=False)
assert any(isinstance(f, SecretMaskingFilter) for f in h2.filters)

test_logger.removeHandler(h1)
test_logger.removeHandler(h2)
test_logger.handlers.clear()

def test_update_existing_handlers_skips_placeholders(self):
"""Test that PlaceHolder objects in logger dict are skipped."""
# This tests the check for isinstance(log, logging.Logger)
# PlaceHolder objects exist in logging.root.manager.loggerDict
# but are not actual Logger instances
def test_uninstall_removes_filter_from_all_handlers(self):
"""Teardown is symmetric: no SecretMaskingFilter remains on any handler."""
test_logger = logging.getLogger("test_uninstall_handlers")
test_logger.handlers.clear()
handler = logging.StreamHandler(StringIO())
test_logger.addHandler(handler)

# Install masking
install_masking_filter(install_stdout_wrapper=True)
install_masking_filter(install_stdout_wrapper=False)
assert any(isinstance(f, SecretMaskingFilter) for f in handler.filters)

# Call update (should handle placeholders gracefully)
_update_existing_handlers()
uninstall_masking_filter()
assert not any(isinstance(f, SecretMaskingFilter) for f in handler.filters)

# Should not raise any errors
test_logger.removeHandler(handler)
test_logger.handlers.clear()


class TestBootstrapErrorHandling:
Expand Down
Loading
Loading