Skip to content

Commit eff6b64

Browse files
author
Zhi Lin
committed
update to latest ray doc
Signed-off-by: Zhi Lin <[email protected]>
1 parent 0d86b0e commit eff6b64

File tree

4 files changed: 44 additions and 12 deletions.

python/raydp/tests/test_tf.py

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,12 +34,11 @@ def test_tf_estimator(spark_on_ray_small, use_fs_directory):
3434
spark = spark_on_ray_small
3535

3636
# ---------------- data process with Spark ------------
37-
# calculate z = 3 * x + 4 * y + 5
37+
# calculate y = 3 * x + 4
3838
df: pyspark.sql.DataFrame = spark.range(0, 100000)
3939
df = df.withColumn("x", rand() * 100) # add x column
40-
df = df.withColumn("y", rand() * 1000) # ad y column
41-
df = df.withColumn("z", df.x * 3 + df.y * 4 + rand() + 5) # ad z column
42-
df = df.select(df.x, df.y, df.z)
40+
df = df.withColumn("y", df.x * 3 + rand() + 4) # add y column
41+
df = df.select(df.x, df.y)
4342

4443
train_df, test_df = random_split(df, [0.7, 0.3])
4544

@@ -59,8 +58,8 @@ def test_tf_estimator(spark_on_ray_small, use_fs_directory):
5958
optimizer=optimizer,
6059
loss=loss,
6160
metrics=["accuracy", "mse"],
62-
feature_columns=["x", "y"],
63-
label_columns="z",
61+
feature_columns=["x"],
62+
label_columns="y",
6463
batch_size=1000,
6564
num_epochs=2,
6665
use_gpu=False)

python/raydp/tf/estimator.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,17 @@
1515
# limitations under the License.
1616
#
1717

18+
import json
19+
import os
20+
import tempfile
1821
from typing import Any, List, NoReturn, Optional, Union, Dict
1922

2023
import tensorflow as tf
2124
import tensorflow.keras as keras
2225
from tensorflow import DType, TensorShape
2326
from tensorflow.keras.callbacks import Callback
2427

28+
from ray.train import Checkpoint
2529
from ray.train.tensorflow import TensorflowTrainer, TensorflowCheckpoint, prepare_dataset_shard
2630
from ray.air import session
2731
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
@@ -43,7 +47,7 @@ def __init__(self,
4347
metrics: Union[List[keras.metrics.Metric], List[str]] = None,
4448
feature_columns: Union[str, List[str]] = None,
4549
label_columns: Union[str, List[str]] = None,
46-
merge_feature_columns: bool = True,
50+
merge_feature_columns: bool = False,
4751
batch_size: int = 128,
4852
drop_last: bool = False,
4953
num_epochs: int = 1,
@@ -184,7 +188,14 @@ def train_func(config):
184188
if config["evaluate"]:
185189
test_history = multi_worker_model.evaluate(eval_tf_dataset, callbacks=callbacks)
186190
results.append(test_history)
187-
session.report({}, checkpoint=TensorflowCheckpoint.from_model(multi_worker_model))
191+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
192+
multi_worker_model.save(os.path.join(temp_checkpoint_dir, "model.keras"))
193+
checkpoint_dict = os.path.join(temp_checkpoint_dir, "checkpoint.json")
194+
with open(checkpoint_dict, "w") as f:
195+
json.dump({"epoch": config["num_epochs"]}, f)
196+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
197+
198+
session.report({}, checkpoint=checkpoint)
188199

189200
def fit(self,
190201
train_ds: Dataset,

python/raydp/torch/estimator.py

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616
#
1717

18+
import os
19+
import tempfile
1820
import inspect
1921
from typing import Any, Callable, List, NoReturn, Optional, Union, Dict
2022

@@ -30,6 +32,7 @@
3032

3133
import ray
3234
from ray import train
35+
from ray.train import Checkpoint
3336
from ray.train.torch import TorchTrainer, TorchCheckpoint
3437
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
3538
from ray.air import session
@@ -254,7 +257,18 @@ def train_func(config):
254257
else:
255258
# if num_workers = 1, model is not wrapped
256259
states = model.state_dict()
257-
session.report({}, checkpoint=TorchCheckpoint.from_state_dict(states))
260+
with tempfile.TemporaryDirectory() as temp_checkpoint_dir:
261+
checkpoint = None
262+
# In standard DDP training, where the model is the same across all ranks,
263+
# only the global rank 0 worker needs to save and report the checkpoint
264+
if train.get_context().get_world_rank() == 0:
265+
torch.save(
266+
states,
267+
os.path.join(temp_checkpoint_dir, "model.pt"),
268+
)
269+
checkpoint = Checkpoint.from_directory(temp_checkpoint_dir)
270+
271+
session.report({}, checkpoint=checkpoint)
258272

259273
@staticmethod
260274
def train_epoch(dataset, model, criterion, optimizer, metrics, scheduler=None):

python/raydp/xgboost/estimator.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
from raydp.spark import spark_dataframe_to_ray_dataset, get_raydp_master_owner
2424

2525
import ray
26-
from ray.air.config import ScalingConfig, RunConfig, FailureConfig
26+
from ray.air.config import ScalingConfig, RunConfig, FailureConfig, CheckpointConfig
2727
from ray.data.dataset import Dataset
2828
from ray.train.xgboost import XGBoostTrainer, XGBoostCheckpoint
2929

@@ -58,7 +58,15 @@ def fit(self,
5858
max_retries=3) -> NoReturn:
5959
scaling_config = ScalingConfig(num_workers=self._num_workers,
6060
resources_per_worker=self._resources_per_worker)
61-
run_config = RunConfig(failure_config=FailureConfig(max_failures=max_retries))
61+
run_config = RunConfig(
62+
checkpoint_config=CheckpointConfig(
63+
# Checkpoint every iteration.
64+
checkpoint_frequency=1,
65+
# Only keep the latest checkpoint and delete the others.
66+
num_to_keep=1,
67+
),
68+
failure_config=FailureConfig(max_failures=max_retries)
69+
)
6270
if self._shuffle:
6371
train_ds = train_ds.random_shuffle()
6472
if evaluate_ds:
@@ -109,4 +117,4 @@ def fit_on_spark(self,
109117
train_ds, evaluate_ds, max_retries)
110118

111119
def get_model(self):
112-
return XGBoostCheckpoint(self._results.checkpoint.to_directory()).get_model()
120+
return XGBoostTrainer.get_model(self._results.checkpoint)

0 commit comments

Comments
 (0)