-
Notifications
You must be signed in to change notification settings - Fork 3
Adding warning logs for potentially unsafe masking/unmasking calls #447
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Open
knassre-bodo
wants to merge
4
commits into
kian/mask_server_rewrite
Choose a base branch
from
kian/unmask_warning_logs
base: kian/mask_server_rewrite
Could not load branches
Branch not found: {{ refName }}
Loading
Could not load tags
Nothing to show
Loading
Are you sure you want to change the base?
Some commits from the old base branch may be removed from the timeline,
and old review comments may become outdated.
Open
Changes from all commits
Commits
Show all changes
4 commits
Select commit
Hold shift + click to select a range
74b8824
Initial implementation added, as well as early testing
knassre-bodo 7de68bd
Added bubbleprop tests and more warning log tests [RUN CI]
knassre-bodo d4d2b29
Added docstrings/comments
knassre-bodo 10f40da
Merge branch 'kian/mask_server_rewrite' into kian/unmask_warning_logs
knassre-bodo File filter
Filter by extension
Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
There are no files selected for viewing
283 changes: 283 additions & 0 deletions
283
pydough/conversion/masking_critical_detection_visitor.py
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,283 @@ | ||
| """ | ||
| Logic for detecting mask/unmask calls within the final relational plan that will | ||
| cause a critical logical error if the user does not have permission to make the | ||
| call, and logging warnings for those calls. | ||
| """ | ||
|
|
||
| __all__ = ["MaskingCriticalDetectionVisitor"] | ||
|
|
||
| import pydough.pydough_operators as pydop | ||
| from pydough.logger import get_logger | ||
| from pydough.relational import ( | ||
| Aggregate, | ||
| CallExpression, | ||
| ColumnReference, | ||
| CorrelatedReference, | ||
| EmptySingleton, | ||
| Filter, | ||
| Join, | ||
| Limit, | ||
| LiteralExpression, | ||
| Project, | ||
| RelationalExpression, | ||
| RelationalExpressionVisitor, | ||
| RelationalNode, | ||
| RelationalRoot, | ||
| RelationalVisitor, | ||
| Scan, | ||
| WindowCallExpression, | ||
| ) | ||
| from pydough.relational.rel_util import add_input_name | ||
|
|
||
|
|
||
| class MaskingCriticalDetectionExpressionVisitor(RelationalExpressionVisitor): | ||
| """ | ||
| A visitor to detect mask/unmask calls within expressions based on which | ||
| columns from the input relational node depend on mask/unmask calls. After | ||
| calling accept with an expression, the stack will contain a singleton list | ||
| with a set of (column_name, is_unmask) tuples representing the table | ||
| columns that the expression depends on via mask/unmask calls, and whether | ||
| they depend on masking or unmasking. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self.input_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {} | ||
| self.stack: list[set[tuple[str, bool]]] = [] | ||
|
|
||
| def reset(self) -> None: | ||
| self.stack = [] | ||
|
|
||
| def visit_call_expression(self, expr: CallExpression) -> None: | ||
| # Aggregate the dependencies from all input expressions | ||
| dependencies: set[tuple[str, bool]] = set() | ||
| for input_expr in expr.inputs: | ||
| input_expr.accept(self) | ||
| dependencies.update(self.stack.pop()) | ||
| # If this call expression is a mask/unmask operation, add the relevant | ||
| # column dependency. | ||
| if isinstance(expr.op, pydop.MaskedExpressionFunctionOperator): | ||
| dependencies.add( | ||
| ( | ||
| f"{expr.op.table_path}.{expr.op.masking_metadata.column_name}", | ||
| expr.op.is_unmask, | ||
| ) | ||
| ) | ||
| self.stack.append(dependencies) | ||
|
|
||
| def visit_window_expression(self, expr: WindowCallExpression) -> None: | ||
| # Aggregate the dependencies from all input, partition, and order | ||
| # expressions. | ||
| dependencies: set[tuple[str, bool]] = set() | ||
| for input_expr in expr.inputs: | ||
| input_expr.accept(self) | ||
| dependencies.update(self.stack.pop()) | ||
| for partition_expr in expr.partition_inputs: | ||
| partition_expr.accept(self) | ||
| dependencies.update(self.stack.pop()) | ||
| for order_expr in expr.order_inputs: | ||
| order_expr.expr.accept(self) | ||
| dependencies.update(self.stack.pop()) | ||
| self.stack.append(dependencies) | ||
|
|
||
| def visit_column_reference(self, column_reference: ColumnReference) -> None: | ||
| # Retrieve the dependencies for this column from the input dependencies. | ||
| self.stack.append(self.input_dependencies.get(column_reference, set())) | ||
|
|
||
| def visit_correlated_reference( | ||
| self, correlated_reference: CorrelatedReference | ||
| ) -> None: | ||
| # Correlated references have no dependencies on masking/unmasking. | ||
| self.stack.append(set()) | ||
|
|
||
| def visit_literal_expression(self, literal_expression: LiteralExpression) -> None: | ||
| # Literal expressions have no dependencies on masking/unmasking. | ||
| self.stack.append(set()) | ||
|
|
||
|
|
||
| class MaskingCriticalDetectionVisitor(RelationalVisitor): | ||
| """ | ||
| The main visitor which traverses the relational tree, inferring which | ||
| columns depending on mask/unmask calls, propagating them upward through | ||
| the plan, and logging warnings for any mask/unmask calls that are | ||
| critical to the output of the query. | ||
| """ | ||
|
|
||
| def __init__(self) -> None: | ||
| self.critical_mask_columns: set[str] = set() | ||
| """ | ||
| The set of fully qualified column names where a MASK operation on the | ||
| column is critical to the output of the query. | ||
| """ | ||
|
|
||
| self.critical_unmask_columns: set[str] = set() | ||
|
|
||
| """ | ||
| The set of fully qualified column names where an UNMASK operation on the | ||
| column is critical to the output of the query. | ||
| """ | ||
|
|
||
| self.expression_visitor = MaskingCriticalDetectionExpressionVisitor() | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need type hint for this? |
||
| """ | ||
| The expression visitor used to detect mask/unmask dependencies within | ||
| expressions. | ||
| """ | ||
|
|
||
| self.stack: list[dict[RelationalExpression, set[tuple[str, bool]]]] = [] | ||
| """ | ||
| The stack of input dependency mappings for each relational node visited. | ||
| Each mapping corresponds to a relational node from one of the inputs to | ||
| the current node, and maps each output expression of the node to the set | ||
| of (column_name, is_unmask) tuples that the expression depends on. | ||
| """ | ||
|
|
||
| def reset(self) -> None: | ||
| self.critical_mask_columns.clear() | ||
| self.critical_unmask_columns.clear() | ||
| self.stack.clear() | ||
| self.expression_visitor.reset() | ||
|
|
||
| def visit_inputs(self, node: RelationalNode) -> None: | ||
| """ | ||
| Generic logic to visit all input nodes of a relational node, and build | ||
| the input dependency mapping for the current node that will be given to | ||
| the expression visitor so it knows which column references have | ||
| dependencies on mask/unmask calls. | ||
| """ | ||
| input_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {} | ||
|
|
||
| # Loop over all of the input nodes and recursively visit them, | ||
| # extracting their dependencies from the stack. | ||
| for idx, input_node in enumerate(node.inputs): | ||
| input_node.accept(self) | ||
| dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = ( | ||
| self.stack.pop() | ||
| ) | ||
| # If the node has 1 input, then its dependencies are what should | ||
| # be used. | ||
| if len(node.inputs) == 1: | ||
| input_dependencies = dependencies | ||
|
|
||
| # Otherwise, we need to map the dependencies to the appropriate | ||
| # input alias for the node. | ||
| else: | ||
| alias: str | None = node.default_input_aliases[idx] | ||
| for expr, deps in dependencies.items(): | ||
| input_dependencies[add_input_name(expr, alias)] = deps | ||
|
|
||
| # Register the unmask/mask call dependencies from all inputs to this | ||
| # node with the expression visitor. | ||
| self.expression_visitor.input_dependencies = input_dependencies | ||
|
|
||
| def find_critical_dependencies(self, expr: RelationalExpression) -> None: | ||
| """ | ||
| Takes in an expression used in a critical manner (join condition, | ||
| filter condition, aggregate key, or ordering key for a root/limit), and | ||
| feeds it to the expression visitor to determine if it has any mask/unmask | ||
| call dependencies. If it does, the relevant columns are added to the | ||
| critical mask/unmask column sets. | ||
|
|
||
| Args: | ||
| `expr`: The expression to analyze for any mask/unmask dependencies. | ||
| """ | ||
| expr.accept(self.expression_visitor) | ||
| expr_dependencies: set[tuple[str, bool]] = self.expression_visitor.stack.pop() | ||
| for col_name, is_unmask in expr_dependencies: | ||
| if is_unmask: | ||
| self.critical_unmask_columns.add(col_name) | ||
| else: | ||
| self.critical_mask_columns.add(col_name) | ||
|
|
||
| def add_output_dependencies(self, node: RelationalNode) -> None: | ||
| """ | ||
| Uses the expression visitor to determine the mask/unmask dependencies | ||
| for each output expression of the given relational node, and pushes the | ||
| resulting mapping onto the stack. | ||
|
|
||
| Args: | ||
| `node`: The relational node whose output columns are having their | ||
| dependencies determined. | ||
| """ | ||
| out_dependencies: dict[RelationalExpression, set[tuple[str, bool]]] = {} | ||
| for name, expr in node.columns.items(): | ||
| expr.accept(self.expression_visitor) | ||
| out_dependencies[ColumnReference(name, expr.data_type)] = ( | ||
| self.expression_visitor.stack.pop() | ||
| ) | ||
| self.stack.append(out_dependencies) | ||
|
|
||
| def log_critical_calls(self) -> None: | ||
| """ | ||
| Logs warnings for all critical mask/unmask calls detected during the | ||
| traversal of the tree. | ||
|
|
||
| This should be called once after the visitor has traversed the entire | ||
| relational plan. | ||
| """ | ||
| logger = get_logger() | ||
| for column in self.critical_mask_columns: | ||
| logger.warning( | ||
| f"Query will not produce a valid output unless user has permission to mask column `{column}`" | ||
| ) | ||
| for column in self.critical_unmask_columns: | ||
| logger.warning( | ||
| f"Query will not produce a valid output unless user has permission to unmask column `{column}`" | ||
| ) | ||
|
|
||
| # Clean up the visitor afterwards, to avoid accidentally logging a | ||
| # duplicate. | ||
| self.reset() | ||
|
|
||
| def visit_project(self, project: Project) -> None: | ||
| # Projects simply propagate dependencies from their inputs. | ||
| self.visit_inputs(project) | ||
| self.add_output_dependencies(project) | ||
|
|
||
| def visit_filter(self, filter: Filter) -> None: | ||
| # Filter nodes propagate dependencies from their inputs, but also | ||
| # analyze their condition for critical dependencies. | ||
| self.visit_inputs(filter) | ||
| self.find_critical_dependencies(filter.condition) | ||
| self.add_output_dependencies(filter) | ||
|
|
||
| def visit_join(self, join: Join) -> None: | ||
| # Filter nodes propagate dependencies from their inputs, but also | ||
| # analyze their condition for critical dependencies. | ||
| self.visit_inputs(join) | ||
| self.find_critical_dependencies(join.condition) | ||
| self.add_output_dependencies(join) | ||
|
|
||
| def visit_aggregate(self, aggregate: Aggregate) -> None: | ||
| # Aggregate nodes propagate dependencies from their inputs, but also | ||
| # analyze their aggregation keys for critical dependencies. | ||
| self.visit_inputs(aggregate) | ||
| for agg_key in aggregate.keys.values(): | ||
| self.find_critical_dependencies(agg_key) | ||
| self.add_output_dependencies(aggregate) | ||
|
|
||
| def visit_limit(self, limit: Limit) -> None: | ||
| # Limit nodes propagate dependencies from their inputs, but also | ||
| # analyze their ordering keys for critical dependencies. | ||
| self.visit_inputs(limit) | ||
| for order_expr in limit.orderings: | ||
| self.find_critical_dependencies(order_expr.expr) | ||
| self.add_output_dependencies(limit) | ||
|
|
||
| def visit_root(self, root: RelationalRoot) -> None: | ||
| # Root nodes propagate dependencies from their inputs, but also | ||
| # analyze their ordering keys for critical dependencies. | ||
| self.visit_inputs(root) | ||
| for order_expr in root.orderings: | ||
| self.find_critical_dependencies(order_expr.expr) | ||
| self.add_output_dependencies(root) | ||
|
|
||
| def visit_scan(self, scan: Scan) -> None: | ||
| # Scan nodes have no inputs, so they propagate dependencies based on | ||
| # their columns relative to an empty input. | ||
| self.expression_visitor.input_dependencies = {} | ||
| self.add_output_dependencies(scan) | ||
|
|
||
| def visit_empty_singleton(self, empty_singleton: EmptySingleton) -> None: | ||
| # Empty singletons have no inputs, so they propagate dependencies based | ||
| # on their columns relative to an empty input. | ||
| self.expression_visitor.input_dependencies = {} | ||
| self.add_output_dependencies(empty_singleton) | ||
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.