Commit
[mypy] nncf/common/pruning (#3198)
### Changes

Enable mypy for nncf/common/pruning
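
The diff below is mostly mechanical: explicit `-> None` returns on procedures, `Optional[...]` for parameters that default to `None`, annotated containers, and targeted `# type: ignore[...]` comments where annotation alone is not enough. A minimal sketch of that style, using a hypothetical `Registry` class rather than NNCF code:

```python
from typing import Callable, Dict, Generic, Hashable, List, Optional, TypeVar

T = TypeVar("T")


class Registry(Generic[T]):
    """Hypothetical container showing the annotation style applied in this commit."""

    def __init__(self, id_fn: Optional[Callable[[T], Hashable]] = None) -> None:
        # A parameter that defaults to None must be typed Optional[...] for mypy.
        # The builtin id() is just a stand-in default key function here.
        self._id_fn: Callable[[T], Hashable] = id_fn if id_fn is not None else id
        self._index: Dict[Hashable, T] = {}

    def add(self, elements: List[T]) -> None:
        # Procedures get an explicit "-> None" so mypy checks their bodies
        # instead of treating them as untyped.
        for elt in elements:
            self._index[self._id_fn(elt)] = elt
```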
AlexanderDokuchaev authored Jan 21, 2025
1 parent 0931072 commit 190006d
Showing 13 changed files with 181 additions and 149 deletions.
28 changes: 14 additions & 14 deletions nncf/common/pruning/clusterization.py
@@ -9,26 +9,26 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Callable, Dict, Generic, Hashable, List, TypeVar
from typing import Callable, Dict, Generic, Hashable, List, Optional, TypeVar

T = TypeVar("T")


class Cluster(Generic[T]):
"""
Represents element of Сlusterization. Groups together elements.
Represents element of Clusterization. Groups together elements.
"""

def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]):
def __init__(self, cluster_id: int, elements: List[T], nodes_orders: List[int]) -> None:
self.id = cluster_id
self.elements = list(elements)
self.importance = max(nodes_orders)

def clean_cluster(self):
def clean_cluster(self) -> None:
self.elements = []
self.importance = 0

def add_elements(self, elements: List[T], importance: int):
def add_elements(self, elements: List[T], importance: int) -> None:
self.elements.extend(elements)
self.importance = max(self.importance, importance)

@@ -39,7 +39,7 @@ class Clusterization(Generic[T]):
delete existing one or merge existing clusters.
"""

def __init__(self, id_fn: Callable[[T], Hashable] = None):
def __init__(self, id_fn: Optional[Callable[[T], Hashable]] = None) -> None:
self.clusters: Dict[int, Cluster[T]] = {}
self._element_to_cluster: Dict[Hashable, int] = {}
if id_fn is None:
@@ -78,7 +78,7 @@ def is_node_in_clusterization(self, node_id: int) -> bool:
"""
return node_id in self._element_to_cluster

def add_cluster(self, cluster: Cluster[T]):
def add_cluster(self, cluster: Cluster[T]) -> None:
"""
Adds provided cluster to clusterization.
@@ -89,9 +89,9 @@ def add_cluster(self, cluster: Cluster[T]):
raise IndexError("Cluster with index = {} already exist".format(cluster_id))
self.clusters[cluster_id] = cluster
for elt in cluster.elements:
self._element_to_cluster[self._id_fn(elt)] = cluster_id
self._element_to_cluster[self._id_fn(elt)] = cluster_id # type: ignore[no-untyped-call]

def delete_cluster(self, cluster_id: int):
def delete_cluster(self, cluster_id: int) -> None:
"""
Removes cluster with `cluster_id` from clusterization.
@@ -100,7 +100,7 @@ def delete_cluster(self, cluster_id: int):
if cluster_id not in self.clusters:
raise IndexError("No cluster with index = {} to delete".format(cluster_id))
for elt in self.clusters[cluster_id].elements:
node_id = self._id_fn(elt)
node_id = self._id_fn(elt) # type: ignore[no-untyped-call]
self._element_to_cluster.pop(node_id)
self.clusters.pop(cluster_id)

@@ -123,7 +123,7 @@ def get_all_nodes(self) -> List[T]:
all_elements.extend(cluster.elements)
return all_elements

def merge_clusters(self, first_id: int, second_id: int):
def merge_clusters(self, first_id: int, second_id: int) -> None:
"""
Merges two clusters with provided ids.
@@ -135,15 +135,15 @@ def merge_clusters(self, first_id: int, second_id: int):
if cluster_1.importance > cluster_2.importance:
cluster_1.add_elements(cluster_2.elements, cluster_2.importance)
for elt in cluster_2.elements:
self._element_to_cluster[self._id_fn(elt)] = first_id
self._element_to_cluster[self._id_fn(elt)] = first_id # type: ignore[no-untyped-call]
self.clusters.pop(second_id)
else:
cluster_2.add_elements(cluster_1.elements, cluster_1.importance)
for elt in cluster_1.elements:
self._element_to_cluster[self._id_fn(elt)] = second_id
self._element_to_cluster[self._id_fn(elt)] = second_id # type: ignore[no-untyped-call]
self.clusters.pop(first_id)

def merge_list_of_clusters(self, clusters: List[int]):
def merge_list_of_clusters(self, clusters: List[int]) -> None:
"""
Merges provided clusters.
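
The `# type: ignore[no-untyped-call]` comments added to clusterization.py above silence calls through `self._id_fn`, which mypy may treat as an untyped callable (for example when it defaults to an unannotated lambda). A minimal reproduction of that error, with assumed names:

```python
def _default_id(x):  # no annotations, so mypy treats calls to it as untyped
    return x


def register(elt: int) -> None:
    # With disallow_untyped_calls enabled, the next line reports
    # "Call to untyped function ... in typed context  [no-untyped-call]"
    # unless the callee is annotated or the call carries a targeted ignore.
    key = _default_id(elt)
    print(key)
```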
23 changes: 12 additions & 11 deletions nncf/common/pruning/mask_propagation.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List, Optional, Type
from typing import Dict, List, Optional, Set, Type

from nncf.common.graph import NNCFGraph
from nncf.common.pruning.operations import BasePruningOp
@@ -38,7 +38,7 @@ def __init__(
graph: NNCFGraph,
pruning_operator_metatypes: PruningOperationsMetatypeRegistry,
tensor_processor: Optional[Type[NNCFPruningBaseTensorProcessor]] = None,
):
) -> None:
"""
Initializes MaskPropagationAlgorithm.
@@ -51,7 +51,7 @@ def __init__(
self._pruning_operator_metatypes = pruning_operator_metatypes
self._tensor_processor = tensor_processor

def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
"""
Returns class of metaop that corresponds to `type_name` type.
@@ -63,14 +63,14 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
cls = self._pruning_operator_metatypes.registry_dict["stop_propagation_ops"]
return cls

def mask_propagation(self):
def mask_propagation(self) -> None:
"""
Mask propagation in graph:
to propagate masks run method mask_propagation (of metaop of current node) on all nodes in topological order.
"""
for node in self._graph.topological_sort():
cls = self.get_meta_operation_by_type_name(node.node_type)
cls.mask_propagation(node, self._graph, self._tensor_processor)
cls.mask_propagation(node, self._graph, self._tensor_processor) # type: ignore

def symbolic_mask_propagation(
self, prunable_layers_types: List[str], can_prune_after_analysis: Dict[int, PruningAnalysisDecision]
@@ -96,7 +96,7 @@ def symbolic_mask_propagation(
"""

can_be_closing_convs = self._get_can_closing_convs(prunable_layers_types)
can_prune_by_dim = {k: None for k in can_be_closing_convs}
can_prune_by_dim: Dict[int, PruningAnalysisDecision] = {k: None for k in can_be_closing_convs} # type: ignore
for node in self._graph.topological_sort():
if node.node_id in can_be_closing_convs and can_prune_after_analysis[node.node_id]:
# Set output mask
@@ -109,15 +109,16 @@ def symbolic_mask_propagation(
input_masks = get_input_masks(node, self._graph)
if any(input_masks):
assert len(input_masks) == 1
input_mask: SymbolicMask = input_masks[0]
input_mask = input_masks[0]
assert isinstance(input_mask, SymbolicMask)

for producer in input_mask.mask_producers:
previously_dims_equal = (
True if can_prune_by_dim[producer.id] is None else can_prune_by_dim[producer.id]
)

is_dims_equal = get_input_channels(node) == input_mask.shape[0]
decision = previously_dims_equal and is_dims_equal
decision = bool(previously_dims_equal and is_dims_equal)
can_prune_by_dim[producer.id] = PruningAnalysisDecision(
decision, PruningAnalysisReason.DIMENSION_MISMATCH
)
@@ -130,7 +130,7 @@ def symbolic_mask_propagation(
can_prune_by_dim[producer.id] = PruningAnalysisDecision(False, PruningAnalysisReason.LAST_CONV)
# Update decision for nodes which
# have no closing convolution
convs_without_closing_conv = {}
convs_without_closing_conv: Dict[int, PruningAnalysisDecision] = {}
for k, v in can_prune_by_dim.items():
if v is None:
convs_without_closing_conv[k] = PruningAnalysisDecision(
@@ -144,8 +144,8 @@ def symbolic_mask_propagation(

return can_prune_by_dim

def _get_can_closing_convs(self, prunable_layers_types) -> Dict:
retval = set()
def _get_can_closing_convs(self, prunable_layers_types: List[str]) -> Set[int]:
retval: Set[int] = set()
for node in self._graph.get_all_nodes():
if node.node_type in prunable_layers_types and not (
is_grouped_conv(node) or is_batched_linear(node, self._graph)
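
In mask_propagation.py, `can_prune_by_dim` is declared as `Dict[int, PruningAnalysisDecision]` but first filled with `None` placeholders, hence the `# type: ignore`. One alternative, sketched here with assumed names and not what the commit does, is to type the temporary dict with `Optional` values and narrow it once every entry is filled:

```python
from typing import Dict, List, Optional


def analyze(ids: List[int]) -> Dict[int, bool]:
    # Declare the temporary dict with Optional values, so the None
    # placeholders type-check without any ignore comment.
    pending: Dict[int, Optional[bool]] = {i: None for i in ids}
    for i in ids:
        pending[i] = i % 2 == 0
    # Narrow back to the non-Optional value type at the end.
    return {k: bool(v) for k, v in pending.items()}
```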
22 changes: 13 additions & 9 deletions nncf/common/pruning/model_analysis.py
@@ -9,7 +9,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import Dict, List
from typing import Dict, List, Optional, Type, cast

from nncf.common.graph import NNCFGraph
from nncf.common.graph import NNCFNode
@@ -23,14 +23,16 @@
from nncf.common.pruning.utils import is_prunable_depthwise_conv


def get_position(nodes_list: List[NNCFNode], idx: int):
def get_position(nodes_list: List[NNCFNode], idx: int) -> Optional[int]:
for i, node in enumerate(nodes_list):
if node.node_id == idx:
return i
return None


def merge_clusters_for_nodes(nodes_to_merge: List[NNCFNode], clusterization: Clusterization):
def merge_clusters_for_nodes(
nodes_to_merge: List[NNCFNode], clusterization: Clusterization # type:ignore[type-arg]
) -> None:
"""
Merges clusters to which nodes from nodes_to_merge belongs.
@@ -75,7 +77,7 @@ def cluster_special_ops(
# 0. Initially all nodes is a separate clusters
clusterization = Clusterization[NNCFNode](lambda x: x.node_id)
for i, node in enumerate(all_special_nodes):
cluster = Cluster[NNCFNode](i, [node], [get_position(topologically_sorted_nodes, node.node_id)])
cluster = Cluster[NNCFNode](i, [node], [get_position(topologically_sorted_nodes, node.node_id)]) # type: ignore
clusterization.add_cluster(cluster)

for node in topologically_sorted_nodes:
@@ -125,7 +127,9 @@ def __init__(
self._pruning_operator_metatypes = pruning_operator_metatypes
self._prune_operations_types = prune_operations_types
pruning_op_metatypes_dict = self._pruning_operator_metatypes.registry_dict
self._stop_propagation_op_metatype = pruning_op_metatypes_dict["stop_propagation_ops"]
self._stop_propagation_op_metatype = cast(
Type[BasePruningOp], pruning_op_metatypes_dict["stop_propagation_ops"]
)
self._concat_op_metatype = pruning_op_metatypes_dict["concat"]

self.can_prune = {idx: True for idx in self.graph.get_all_node_ids()}
@@ -151,7 +155,7 @@ def node_accept_different_inputs(self, nncf_node: NNCFNode) -> bool:
"""
return nncf_node.node_type in self._concat_op_metatype.get_all_op_aliases()

def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
def get_meta_operation_by_type_name(self, type_name: str) -> Type[BasePruningOp]:
"""
Returns class of metaop that corresponds to `type_name` type.
@@ -162,7 +166,7 @@ def get_meta_operation_by_type_name(self, type_name: str) -> BasePruningOp:
cls = self._stop_propagation_op_metatype
return cls

def propagate_can_prune_attr_up(self):
def propagate_can_prune_attr_up(self) -> None:
"""
Propagating can_prune attribute in reversed topological order.
This attribute depends on accept_pruned_input and can_prune attributes of output nodes.
@@ -181,7 +185,7 @@ def propagate_can_prune_attr_up(self):
)
self.can_prune[node.node_id] = outputs_accept_pruned_input and outputs_will_be_pruned

def propagate_can_prune_attr_down(self):
def propagate_can_prune_attr_down(self) -> None:
"""
Propagating can_prune attribute down to fix all branching cases with one pruned and one not pruned
branches.
@@ -199,7 +203,7 @@ def propagate_can_prune_attr_down(self):
):
self.can_prune[node.node_id] = can_prune

def set_accept_pruned_input_attr(self):
def set_accept_pruned_input_attr(self) -> None:
for nncf_node in self.graph.get_all_nodes():
cls = self.get_meta_operation_by_type_name(nncf_node.node_type)
self.accept_pruned_input[nncf_node.node_id] = cls.accept_pruned_input(nncf_node)
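
The `cast(Type[BasePruningOp], ...)` introduced in model_analysis.py narrows a value read from the loosely typed registry dict; `cast` is a purely static assertion with no runtime effect. A small illustration with stand-in classes (not the actual NNCF registry):

```python
from typing import Dict, Type, cast


class BaseOp:
    """Stand-in for a pruning metaop base class."""


class StopPropagationOps(BaseOp):
    """Stand-in for the registered 'stop_propagation_ops' class."""


# Registry dicts are often typed loosely (e.g. Dict[str, object]), so cast()
# tells mypy which concrete class type comes back; nothing happens at runtime.
registry: Dict[str, object] = {"stop_propagation_ops": StopPropagationOps}
stop_op_cls = cast(Type[BaseOp], registry["stop_propagation_ops"])
assert issubclass(stop_op_cls, BaseOp)
```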
14 changes: 7 additions & 7 deletions nncf/common/pruning/node_selector.py
@@ -102,8 +102,8 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:

# 2. Clusters for nodes that should be pruned together (taking into account clusters for special ops)
for i, cluster in enumerate(special_ops_clusterization.get_all_clusters()):
all_pruned_inputs = {}
clusters_to_merge = []
all_pruned_inputs: Dict[int, NNCFNode] = {}
clusters_to_merge: List[int] = []

for node in cluster.elements:
sources = get_sources_of_node(node, graph, self._prune_operations_types)
Expand All @@ -116,7 +116,7 @@ def create_pruning_groups(self, graph: NNCFGraph) -> Clusterization[NNCFNode]:
all_pruned_inputs[source_node.node_id] = source_node

if all_pruned_inputs:
cluster = Cluster[NNCFNode](i, all_pruned_inputs.values(), all_pruned_inputs.keys())
cluster = Cluster[NNCFNode](i, list(all_pruned_inputs.values()), list(all_pruned_inputs.keys()))
clusters_to_merge.append(cluster.id)
pruned_nodes_clusterization.add_cluster(cluster)

@@ -202,7 +202,7 @@ def _get_multiforward_nodes(self, graph: NNCFGraph) -> List[List[NNCFNode]]:
def _pruning_dimensions_analysis(
self,
graph: NNCFGraph,
pruned_nodes_clusterization: Clusterization,
pruned_nodes_clusterization: Clusterization, # type: ignore[type-arg]
can_prune_after_check: Dict[int, PruningAnalysisDecision],
) -> Dict[int, PruningAnalysisDecision]:
"""
@@ -251,7 +251,7 @@ def _check_all_closing_nodes_are_feasible(
return can_prune_updated

def _check_internal_groups_dim(
self, pruned_nodes_clusterization: Clusterization
self, pruned_nodes_clusterization: Clusterization # type: ignore[type-arg]
) -> Dict[int, PruningAnalysisDecision]:
"""
Checks pruning dimensions of all nodes in each cluster group are equal and
@@ -278,7 +278,7 @@ def _should_prune_groups_analysis(
def _should_prune_groups_analysis(
self,
graph: NNCFGraph,
pruned_nodes_clusterization: Clusterization,
pruned_nodes_clusterization: Clusterization, # type: ignore[type-arg]
can_prune: Dict[int, PruningAnalysisDecision],
) -> Dict[int, PruningAnalysisDecision]:
"""
@@ -312,7 +312,7 @@ def _filter_groups(
return can_prune_updated

def _filter_groups(
self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision]
self, pruned_nodes_clusterization: Clusterization, can_prune: Dict[int, PruningAnalysisDecision] # type: ignore[type-arg]
) -> None:
"""
Check whether all nodes in group can be pruned based on user-defined constraints and
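
In node_selector.py, the new `list(...)` wrappers around `all_pruned_inputs.values()` and `.keys()` exist because dict views are not `List`s, so a constructor annotated with `List[...]` parameters rejects them under mypy even though iteration works at runtime. A minimal sketch with hypothetical names:

```python
from typing import Dict, List


def make_group(members: List[str]) -> List[str]:
    # Annotated to take a List, as the Cluster constructor is in the diff above.
    return sorted(members)


nodes: Dict[int, str] = {0: "conv_1", 1: "conv_2"}

# dict.values() returns a view object, not a List[str]; mypy flags passing it
# directly, so it is materialized with list() first.
group = make_group(list(nodes.values()))
```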
(Diffs for the remaining 9 changed files are not shown above.)
