24
24
from nncf .experimental .common .tensor_statistics .collectors import AbsMaxReducer
25
25
from nncf .experimental .common .tensor_statistics .collectors import MaxAggregator
26
26
from nncf .experimental .common .tensor_statistics .collectors import TensorCollector
27
- from nncf .openvino .graph .transformations .commands import OVMultiplyInsertionCommand
28
- from nncf .openvino .graph .transformations .commands import OVWeightUpdateCommand
29
27
from nncf .quantization .algorithms .smooth_quant .backend import SmoothQuantAlgoBackend
30
28
from nncf .tensor import Tensor
31
29
from nncf .torch .graph .transformations .command_creation import create_command_to_update_weight
32
30
from nncf .torch .graph .transformations .commands import PTSharedFnInsertionCommand
33
31
from nncf .torch .graph .transformations .commands import PTTargetPoint
32
+ from nncf .torch .graph .transformations .commands import PTWeightUpdateCommand
34
33
from nncf .torch .layer_utils import COMPRESSION_MODULES
35
34
from nncf .torch .layer_utils import CompressionParameter
36
35
from nncf .torch .layer_utils import StatefullModuleInterface
@@ -127,7 +126,7 @@ def get_weight_value(node_with_weight: NNCFNode, model: NNCFNetwork, nncf_graph:
127
126
return Tensor (weight_data )
128
127
129
128
@staticmethod
130
- def weight_update_command (node_with_weight : NNCFNode , weight_value : np .ndarray ) -> OVWeightUpdateCommand :
129
+ def weight_update_command (node_with_weight : NNCFNode , weight_value : np .ndarray ) -> PTWeightUpdateCommand :
131
130
return create_command_to_update_weight (node_with_weight , weight_value )
132
131
133
132
@staticmethod
@@ -137,7 +136,7 @@ def scale_insertion_command(
137
136
source_output_port_id : int ,
138
137
nodes : List [NNCFNode ],
139
138
scale_node_name : str ,
140
- ) -> OVMultiplyInsertionCommand :
139
+ ) -> PTSharedFnInsertionCommand :
141
140
input_port_id = 0
142
141
target_points = []
143
142
for node in nodes :
0 commit comments