File tree Expand file tree Collapse file tree 4 files changed +8
-7
lines changed Expand file tree Collapse file tree 4 files changed +8
-7
lines changed Original file line number Diff line number Diff line change @@ -125,9 +125,10 @@ def _prepare_spark_configs(self):
125
125
126
126
raydp_agent_path = os .path .abspath (os .path .join (os .path .abspath (__file__ ),
127
127
"../../jars/raydp-agent*.jar" ))
128
- print (raydp_agent_path )
129
- raydp_agent_jar = glob .glob (raydp_agent_path )[0 ]
130
- self ._configs [SPARK_JAVAAGENT ] = raydp_agent_jar
128
+ print (os .listdir (raydp_cp ))
129
+ raydp_agent_jars = glob .glob (raydp_agent_path )
130
+ if raydp_agent_jars :
131
+ self ._configs [SPARK_JAVAAGENT ] = raydp_agent_jars [0 ]
131
132
# for JVM running in ray
132
133
self ._configs [SPARK_RAY_LOG4J_FACTORY_CLASS_KEY ] = versions .RAY_LOG4J_VERSION
133
134
Original file line number Diff line number Diff line change @@ -43,7 +43,7 @@ def __init__(self,
43
43
metrics : Union [List [keras .metrics .Metric ], List [str ]] = None ,
44
44
feature_columns : Union [str , List [str ]] = None ,
45
45
label_columns : Union [str , List [str ]] = None ,
46
- merge_feature_columns : bool = False ,
46
+ merge_feature_columns : bool = True ,
47
47
batch_size : int = 128 ,
48
48
drop_last : bool = False ,
49
49
num_epochs : int = 1 ,
@@ -268,4 +268,4 @@ def fit_on_spark(self,
268
268
269
269
def get_model (self ) -> Any :
270
270
assert self ._trainer , "Trainer has not been created"
271
- return TensorflowCheckpoint (self ._results .checkpoint ).get_model ()
271
+ return TensorflowCheckpoint (self ._results .checkpoint . to_directory () ).get_model ()
Original file line number Diff line number Diff line change @@ -378,4 +378,4 @@ def fit_on_spark(self,
378
378
379
379
def get_model (self ):
380
380
assert self ._trainer is not None , "Must call fit first"
381
- return TorchCheckpoint (self ._trained_results .checkpoint .as_directory ()).get_model ()
381
+ return TorchCheckpoint (self ._trained_results .checkpoint .to_directory ()).get_model ()
Original file line number Diff line number Diff line change @@ -109,4 +109,4 @@ def fit_on_spark(self,
109
109
train_ds , evaluate_ds , max_retries )
110
110
111
111
def get_model (self ):
112
- return XGBoostCheckpoint . from_checkpoint (self ._results .checkpoint ).get_model ()
112
+ return XGBoostCheckpoint (self ._results .checkpoint . to_directory () ).get_model ()
You can’t perform that action at this time.
0 commit comments