Skip to content

Commit bdeaf44

Browse files
committed
updated execute_or_node_remote() to only set score or predict results at an output node where applicable
1 parent 5e73919 commit bdeaf44

File tree

3 files changed

+56
-31
lines changed

3 files changed

+56
-31
lines changed

Diff for: codeflare/pipelines/Runtime.py

+26-8
Original file line numberDiff line numberDiff line change
@@ -62,7 +62,7 @@ class ExecutionType(Enum):
6262

6363

6464
@ray.remote
65-
def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref: dm.XYRef):
65+
def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref: dm.XYRef, is_outputNode: bool):
6666
"""
6767
Helper remote function that executes an OR node. As such, this is a remote task that runs the estimator
6868
in the provided mode with the data pointed to by XYRef. The key aspect to note here is the choice of input
@@ -107,9 +107,16 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
107107
elif mode == ExecutionType.SCORE:
108108
if base.is_classifier(estimator) or base.is_regressor(estimator):
109109
estimator = node.get_estimator()
110-
score_ref = ray.put(estimator.score(X, y))
111-
result = dm.XYRef(score_ref, score_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
112-
return result
110+
if is_outputNode:
111+
score_ref = ray.put(estimator.score(X, y))
112+
result = dm.XYRef(score_ref, score_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
113+
return result
114+
else:
115+
res_xy = estimator.score(xy_list)
116+
res_xref = ray.put(res_xy.get_x())
117+
res_yref = ray.put(res_xy.get_y())
118+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
119+
return result
113120
else:
114121
res_Xref = ray.put(estimator.transform(X))
115122
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
@@ -118,9 +125,17 @@ def execute_or_node_remote(node: dm.EstimatorNode, mode: ExecutionType, xy_ref:
118125
elif mode == ExecutionType.PREDICT:
119126
# Test mode does not clone as it is a simple predict or transform
120127
if base.is_classifier(estimator) or base.is_regressor(estimator):
121-
predict_ref = ray.put(estimator.predict(X))
122-
result = dm.XYRef(predict_ref, predict_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
123-
return result
128+
if is_outputNode:
129+
predict_ref = ray.put(estimator.predict(X))
130+
result = dm.XYRef(predict_ref, predict_ref, prev_node_ptr, prev_node_ptr, [xy_ref])
131+
return result
132+
else:
133+
res_xy = estimator.predict(xy_list)
134+
res_xref = ray.put(res_xy.get_x())
135+
res_yref = ray.put(res_xy.get_y())
136+
137+
result = dm.XYRef(res_xref, res_yref, prev_node_ptr, prev_node_ptr, Xyref_list)
138+
return result
124139
else:
125140
res_Xref = ray.put(estimator.transform(X))
126141
result = dm.XYRef(res_Xref, xy_ref.get_yref(), prev_node_ptr, prev_node_ptr, [xy_ref])
@@ -147,7 +162,10 @@ def execute_or_node(node, pre_edges, edge_args, post_edges, mode: ExecutionType)
147162
exec_xyrefs = []
148163
for xy_ref_ptr in Xyref_ptrs:
149164
xy_ref = ray.get(xy_ref_ptr)
150-
inner_result = execute_or_node_remote.remote(node, mode, xy_ref)
165+
if post_edges:
166+
inner_result = execute_or_node_remote.remote(node, mode, xy_ref, True)
167+
else:
168+
inner_result = execute_or_node_remote.remote(node, mode, xy_ref, False)
151169
exec_xyrefs.append(inner_result)
152170

153171
for post_edge in post_edges:

Diff for: notebooks/plot_nca_classification.ipynb

+9-2
Original file line numberDiff line numberDiff line change
@@ -150,14 +150,14 @@
150150
},
151151
{
152152
"cell_type": "code",
153-
"execution_count": 4,
153+
"execution_count": 3,
154154
"metadata": {},
155155
"outputs": [
156156
{
157157
"name": "stderr",
158158
"output_type": "stream",
159159
"text": [
160-
"2021-07-19 10:50:40,647\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
160+
"2021-07-22 17:14:51,530\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
161161
]
162162
},
163163
{
@@ -295,6 +295,13 @@
295295
"\n",
296296
"ray.shutdown()"
297297
]
298+
},
299+
{
300+
"cell_type": "code",
301+
"execution_count": null,
302+
"metadata": {},
303+
"outputs": [],
304+
"source": []
298305
}
299306
],
300307
"metadata": {

Diff for: notebooks/plot_rbm_logistic_classification.ipynb

+21-21
Original file line numberDiff line numberDiff line change
@@ -27,16 +27,16 @@
2727
"output_type": "stream",
2828
"text": [
2929
"Automatically created module for IPython interactive environment\n",
30-
"[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.10s\n",
31-
"[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.14s\n",
32-
"[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
33-
"[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.14s\n",
34-
"[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
35-
"[BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.10s\n",
36-
"[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.13s\n",
37-
"[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.13s\n",
38-
"[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.13s\n",
39-
"[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.13s\n",
30+
"[BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.13s\n",
31+
"[BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.18s\n",
32+
"[BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.18s\n",
33+
"[BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.17s\n",
34+
"[BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.17s\n",
35+
"[BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.18s\n",
36+
"[BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.17s\n",
37+
"[BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.17s\n",
38+
"[BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.17s\n",
39+
"[BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.16s\n",
4040
"Logistic regression using RBM features:\n",
4141
" precision recall f1-score support\n",
4242
"\n",
@@ -214,23 +214,23 @@
214214
"name": "stderr",
215215
"output_type": "stream",
216216
"text": [
217-
"2021-07-19 11:01:26,551\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8266\u001b[39m\u001b[22m\n"
217+
"2021-07-22 17:16:19,742\tINFO services.py:1267 -- View the Ray dashboard at \u001b[1m\u001b[32mhttp://127.0.0.1:8267\u001b[39m\u001b[22m\n"
218218
]
219219
},
220220
{
221221
"name": "stdout",
222222
"output_type": "stream",
223223
"text": [
224-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.11s\n",
225-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.15s\n",
226-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.15s\n",
227-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.16s\n",
228-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.15s\n",
229-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.14s\n",
230-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.13s\n",
231-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.15s\n",
232-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.15s\n",
233-
"\u001b[2m\u001b[36m(pid=8995)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.15s\n",
224+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 1, pseudo-likelihood = -25.57, time = 0.16s\n",
225+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 2, pseudo-likelihood = -23.68, time = 0.22s\n",
226+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 3, pseudo-likelihood = -22.74, time = 0.22s\n",
227+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 4, pseudo-likelihood = -21.83, time = 0.22s\n",
228+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 5, pseudo-likelihood = -21.62, time = 0.22s\n",
229+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 6, pseudo-likelihood = -21.11, time = 0.21s\n",
230+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 7, pseudo-likelihood = -20.88, time = 0.21s\n",
231+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 8, pseudo-likelihood = -20.58, time = 0.21s\n",
232+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 9, pseudo-likelihood = -20.32, time = 0.21s\n",
233+
"\u001b[2m\u001b[36m(pid=12180)\u001b[0m [BernoulliRBM] Iteration 10, pseudo-likelihood = -20.13, time = 0.47s\n",
234234
"Logistic regression using RBM features:\n",
235235
" precision recall f1-score support\n",
236236
"\n",

0 commit comments

Comments
 (0)