Skip to content

Commit f3cf6ed

Browse files
Weiyi Zhengfacebook-github-bot
authored andcommitted
add fbgemm fp16 (fbfcpacked) support, add global_init_net in predictor_export_meta (pytorch#18257)
Summary: Pull Request resolved: pytorch#18257 support adding op in global_init_net. because pred_init_net is per thread, and just doesn't cut it. Reviewed By: jspark1105 Differential Revision: D14552695 fbshipit-source-id: 53dd44c84ad019019ab9f35fc04d076b7f941ddc
1 parent afc7574 commit f3cf6ed

File tree

2 files changed

+29
-2
lines changed

2 files changed

+29
-2
lines changed

caffe2/python/predictor/predictor_exporter.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ def get_predictor_exporter_helper(submodelNetName):
3636
class PredictorExportMeta(collections.namedtuple(
3737
'PredictorExportMeta',
3838
'predict_net, parameters, inputs, outputs, shapes, name, \
39-
extra_init_net, net_type, num_workers, trainer_prefix')):
39+
extra_init_net, global_init_net, net_type, num_workers, trainer_prefix')):
4040
"""
4141
Metadata to be used for serializaing a net.
4242
@@ -52,6 +52,13 @@ class PredictorExportMeta(collections.namedtuple(
5252
num_workers specifies for net type 'dag' how many threads should run ops
5353
5454
trainer_prefix specifies the type of trainer.
55+
56+
extra_init_net gets appended to pred_init_net, useful for thread local init
57+
58+
global_init_net gets appended to global_init_net, useful for global init
59+
on a shared across threads parameter workspace
60+
(in a case of multi-threaded inference)
61+
5562
"""
5663
def __new__(
5764
cls,
@@ -62,6 +69,7 @@ def __new__(
6269
shapes=None,
6370
name="",
6471
extra_init_net=None,
72+
global_init_net=None,
6573
net_type=None,
6674
num_workers=None,
6775
trainer_prefix=None,
@@ -85,7 +93,7 @@ def __new__(
8593
assert isinstance(predict_net, (caffe2_pb2.NetDef, caffe2_pb2.PlanDef))
8694
return super(PredictorExportMeta, cls).__new__(
8795
cls, predict_net, parameters, inputs, outputs, shapes, name,
88-
extra_init_net, net_type, num_workers, trainer_prefix)
96+
extra_init_net, global_init_net, net_type, num_workers, trainer_prefix)
8997

9098
def inputs_name(self):
9199
return utils.get_comp_name(predictor_constants.INPUTS_BLOB_TYPE,
@@ -154,6 +162,9 @@ def _global_init_net(predictor_export_meta):
154162
net.Proto().external_input.extend([predictor_constants.PREDICTOR_DBREADER])
155163
net.Proto().external_output.extend(predictor_export_meta.parameters)
156164

165+
if predictor_export_meta.global_init_net:
166+
net.AppendNet(predictor_export_meta.global_init_net)
167+
157168
# Add the model_id in the predict_net to the global_init_net
158169
utils.AddModelIdArg(predictor_export_meta, net.Proto())
159170
return net.Proto()

caffe2/python/predictor/predictor_exporter_test.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,13 +100,23 @@ def test_meta_net_def_net_runs(self):
100100

101101
extra_init_net = core.Net('extra_init')
102102
extra_init_net.ConstantFill('data', 'data', value=1.0)
103+
104+
global_init_net = core.Net('global_init')
105+
global_init_net.ConstantFill(
106+
[],
107+
'global_init_blob',
108+
value=1.0,
109+
shape=[1, 5],
110+
dtype=core.DataType.FLOAT
111+
)
103112
pem = pe.PredictorExportMeta(
104113
predict_net=self.predictor_export_meta.predict_net,
105114
parameters=self.predictor_export_meta.parameters,
106115
inputs=self.predictor_export_meta.inputs,
107116
outputs=self.predictor_export_meta.outputs,
108117
shapes=self.predictor_export_meta.shapes,
109118
extra_init_net=extra_init_net,
119+
global_init_net=global_init_net,
110120
net_type='dag',
111121
)
112122

@@ -142,11 +152,17 @@ def test_meta_net_def_net_runs(self):
142152
np.testing.assert_array_equal(
143153
workspace.FetchBlob("y"), np.zeros(shape=(1, 10)))
144154

155+
self.assertTrue("global_init_blob" not in workspace.Blobs())
145156
# Load parameters from DB
146157
global_init_net = pred_utils.GetNet(meta_net_def,
147158
pc.GLOBAL_INIT_NET_TYPE)
148159
workspace.RunNetOnce(global_init_net)
149160

161+
# make sure the extra global_init_net is running
162+
self.assertTrue(workspace.HasBlob('global_init_blob'))
163+
np.testing.assert_array_equal(
164+
workspace.FetchBlob("global_init_blob"), np.ones(shape=(1, 5)))
165+
150166
# Run the net with a reshaped input and verify we are
151167
# producing good numbers (with our custom implementation)
152168
workspace.FeedBlob("data", np.random.randn(2, 5).astype(np.float32))

0 commit comments

Comments
 (0)