Skip to content

Commit 5363036

Browse files
author
Zhi Lin
committed
fix tests
Signed-off-by: Zhi Lin <[email protected]>
1 parent cb1fcad commit 5363036

File tree

4 files changed

+8
-13
lines changed

4 files changed

+8
-13
lines changed

python/raydp/spark/ray_cluster.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,7 @@ def _prepare_spark_configs(self):
125125

126126
raydp_agent_path = os.path.abspath(os.path.join(os.path.abspath(__file__),
127127
"../../jars/raydp-agent*.jar"))
128+
print(raydp_agent_path)
128129
raydp_agent_jar = glob.glob(raydp_agent_path)[0]
129130
self._configs[SPARK_JAVAAGENT] = raydp_agent_jar
130131
# for JVM running in ray

python/raydp/tf/estimator.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def __init__(self,
4343
metrics: Union[List[keras.metrics.Metric], List[str]] = None,
4444
feature_columns: Union[str, List[str]] = None,
4545
label_columns: Union[str, List[str]] = None,
46-
merge_feature_columns: bool = True,
46+
merge_feature_columns: bool = False,
4747
batch_size: int = 128,
4848
drop_last: bool = False,
4949
num_epochs: int = 1,
@@ -211,10 +211,6 @@ def fit(self,
211211
train_ds = train_ds.random_shuffle()
212212
if evaluate_ds:
213213
evaluate_ds = evaluate_ds.random_shuffle()
214-
datasets = {"train": train_ds}
215-
if evaluate_ds is not None:
216-
train_loop_config["evaluate"] = True
217-
datasets["evaluate"] = evaluate_ds
218214
preprocessor = None
219215
if self._merge_feature_columns:
220216
if isinstance(self._feature_columns, list) and len(self._feature_columns) > 1:
@@ -224,6 +220,11 @@ def fit(self,
224220
preprocessor = Concatenator(output_column_name="features",
225221
exclude=label_cols)
226222
train_loop_config["feature_columns"] = "features"
223+
train_ds = preprocessor.fit_transform(train_ds)
224+
datasets = {"train": train_ds}
225+
if evaluate_ds is not None:
226+
train_loop_config["evaluate"] = True
227+
datasets["evaluate"] = evaluate_ds
227228
self._trainer = TensorflowTrainer(TFEstimator.train_func,
228229
train_loop_config=train_loop_config,
229230
scaling_config=scaling_config,

python/raydp/torch/estimator.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -378,4 +378,4 @@ def fit_on_spark(self,
378378

379379
def get_model(self):
380380
assert self._trainer is not None, "Must call fit first"
381-
return TorchCheckpoint(self._trained_results.checkpoint).get_model()
381+
return TorchCheckpoint(self._trained_results.checkpoint.as_directory()).get_model()

python/raydp/xgboost/estimator.py

Lines changed: 0 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ class XGBoostEstimator(EstimatorInterface, SparkEstimatorInterface):
3131
def __init__(self,
3232
xgboost_params: Dict,
3333
label_column: str,
34-
dmatrix_params: Dict = None,
3534
num_workers: int = 1,
3635
resources_per_worker: Optional[Dict[str, float]] = None,
3736
shuffle: bool = True):
@@ -41,10 +40,6 @@ def __init__(self,
4140
for a list of possible parameters.
4241
:param label_column: Name of the label column. A column with this name
4342
must be present in the training dataset passed to fit() later.
44-
:param dmatrix_params: Dict of ``dataset name:dict of kwargs`` passed to respective
45-
:class:`xgboost_ray.RayDMatrix` initializations, which in turn are passed
46-
to ``xgboost.DMatrix`` objects created on each worker. For example, this can
47-
be used to add sample weights with the ``weights`` parameter.
4843
:param num_workers: the number of workers to do the distributed training.
4944
:param resources_per_worker: the resources defined in this Dict will be reserved for
5045
each worker. The ``CPU`` and ``GPU`` keys (case-sensitive) can be defined to
@@ -53,7 +48,6 @@ def __init__(self,
5348
"""
5449
self._xgboost_params = xgboost_params
5550
self._label_column = label_column
56-
self._dmatrix_params = dmatrix_params
5751
self._num_workers = num_workers
5852
self._resources_per_worker = resources_per_worker
5953
self._shuffle = shuffle
@@ -76,7 +70,6 @@ def fit(self,
7670
datasets=datasets,
7771
label_column=self._label_column,
7872
params=self._xgboost_params,
79-
dmatrix_params=self._dmatrix_params,
8073
run_config=run_config)
8174
self._results = trainer.fit()
8275

0 commit comments

Comments
 (0)