Skip to content

Commit 6602572

Browse files
author
Zhi Lin
committed
fix
Signed-off-by: Zhi Lin <[email protected]>
1 parent 5363036 commit 6602572

File tree

4 files changed

+8
-7
lines changed

4 files changed

+8
-7
lines changed

python/raydp/spark/ray_cluster.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -125,9 +125,10 @@ def _prepare_spark_configs(self):
125125

126126
raydp_agent_path = os.path.abspath(os.path.join(os.path.abspath(__file__),
127127
"../../jars/raydp-agent*.jar"))
128-
print(raydp_agent_path)
129-
raydp_agent_jar = glob.glob(raydp_agent_path)[0]
130-
self._configs[SPARK_JAVAAGENT] = raydp_agent_jar
128+
print(os.listdir(raydp_cp))
129+
raydp_agent_jars = glob.glob(raydp_agent_path)
130+
if raydp_agent_jars:
131+
self._configs[SPARK_JAVAAGENT] = raydp_agent_jars[0]
131132
# for JVM running in ray
132133
self._configs[SPARK_RAY_LOG4J_FACTORY_CLASS_KEY] = versions.RAY_LOG4J_VERSION
133134

python/raydp/tf/estimator.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self,
4343
metrics: Union[List[keras.metrics.Metric], List[str]] = None,
4444
feature_columns: Union[str, List[str]] = None,
4545
label_columns: Union[str, List[str]] = None,
46-
merge_feature_columns: bool = False,
46+
merge_feature_columns: bool = True,
4747
batch_size: int = 128,
4848
drop_last: bool = False,
4949
num_epochs: int = 1,
@@ -268,4 +268,4 @@ def fit_on_spark(self,
268268

269269
def get_model(self) -> Any:
270270
assert self._trainer, "Trainer has not been created"
271-
return TensorflowCheckpoint(self._results.checkpoint).get_model()
271+
return TensorflowCheckpoint(self._results.checkpoint.to_directory()).get_model()

python/raydp/torch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,4 @@ def fit_on_spark(self,
378378

379379
def get_model(self):
380380
assert self._trainer is not None, "Must call fit first"
381-
return TorchCheckpoint(self._trained_results.checkpoint.as_directory()).get_model()
381+
return TorchCheckpoint(self._trained_results.checkpoint.to_directory()).get_model()

python/raydp/xgboost/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -109,4 +109,4 @@ def fit_on_spark(self,
109109
train_ds, evaluate_ds, max_retries)
110110

111111
def get_model(self):
112-
return XGBoostCheckpoint.from_checkpoint(self._results.checkpoint).get_model()
112+
return XGBoostCheckpoint(self._results.checkpoint.to_directory()).get_model()

0 commit comments

Comments
 (0)