Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -112,10 +112,19 @@ option(ADD_PKGCONFIG "Add xgboost.pc into system." ON)
if(USE_DEBUG_OUTPUT AND (NOT (CMAKE_BUILD_TYPE MATCHES Debug)))
message(SEND_ERROR "Do not enable `USE_DEBUG_OUTPUT' with release build.")
endif()
if(USE_NCCL AND NOT (USE_CUDA))
if(USE_NVTX AND (NOT USE_CUDA))
message(SEND_ERROR "`USE_NVTX` must be enabled with `USE_CUDA` flag.")
endif()
if(USE_NVTX AND (CMAKE_VERSION VERSION_LESS "3.25.0"))
  # The CUDA::nvtx3 target is added in CMake 3.25. Report this as an error, like the
  # other option checks in this file, instead of printing a plain notice and
  # letting configuration carry on.
  message(SEND_ERROR "cmake >= 3.25 is required for NVTX.")
endif()
if(USE_NCCL AND (NOT USE_CUDA))
message(SEND_ERROR "`USE_NCCL` must be enabled with `USE_CUDA` flag.")
endif()
if(USE_DEVICE_DEBUG AND NOT (USE_CUDA))
if(USE_DEVICE_DEBUG AND (NOT USE_CUDA))
message(SEND_ERROR "`USE_DEVICE_DEBUG` must be enabled with `USE_CUDA` flag.")
endif()
if(BUILD_WITH_SHARED_NCCL AND (NOT USE_NCCL))
Expand Down
1 change: 1 addition & 0 deletions jvm-packages/.gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -2,3 +2,4 @@ build.sh
xgboost4j-tester/pom.xml
xgboost4j-tester/iris.csv
dependency-reduced-pom.xml
.factorypath
9 changes: 9 additions & 0 deletions jvm-packages/create_jni.py
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ def native_build(cli_args: argparse.Namespace) -> None:
os.environ["JAVA_HOME"] = (
subprocess.check_output("/usr/libexec/java_home").strip().decode()
)
if cli_args.use_debug == "ON":
CONFIG["CMAKE_BUILD_TYPE"] = "Debug"
if cli_args.use_nvtx == "ON":
CONFIG["USE_NVTX"] = "ON"
if cli_args.plugin_rmm == "ON":
CONFIG["PLUGIN_RMM"] = "ON"

print("building Java wrapper", flush=True)
with cd(".."):
Expand Down Expand Up @@ -187,5 +193,8 @@ def native_build(cli_args: argparse.Namespace) -> None:
)
parser.add_argument("--use-cuda", type=str, choices=["ON", "OFF"], default="OFF")
parser.add_argument("--use-openmp", type=str, choices=["ON", "OFF"], default="ON")
parser.add_argument("--use-debug", type=str, choices=["ON", "OFF"], default="OFF")
parser.add_argument("--use-nvtx", type=str, choices=["ON", "OFF"], default="OFF")
parser.add_argument("--plugin-rmm", type=str, choices=["ON", "OFF"], default="OFF")
cli_args = parser.parse_args()
native_build(cli_args)
3 changes: 3 additions & 0 deletions jvm-packages/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,9 @@
<log.capi.invocation>OFF</log.capi.invocation>
<use.cuda>OFF</use.cuda>
<use.openmp>ON</use.openmp>
<use.debug>OFF</use.debug>
<use.nvtx>OFF</use.nvtx>
<plugin.rmm>OFF</plugin.rmm>
<cudf.version>24.10.0</cudf.version>
<spark.rapids.version>24.10.0</spark.rapids.version>
<spark.rapids.classifier>cuda12</spark.rapids.classifier>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,6 +86,17 @@ private List<CudfColumn> initializeCudfColumns(Table table) {
.collect(Collectors.toList());
}

/**
 * Returns the table held in {@code featureTable}. Visible for testing.
 *
 * @return the {@code featureTable} field as is
 */
public Table getFeatureTable() {
  return this.featureTable;
}

/**
 * Returns the table held in {@code labelTable}. Visible for testing.
 *
 * @return the {@code labelTable} field as is
 */
public Table getLabelTable() {
  return this.labelTable;
}


/**
 * Returns the list of feature columns held in {@code features}.
 *
 * @return the {@code features} field as is (no copy is made)
 */
public List<CudfColumn> getFeatures() {
  return this.features;
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*
Copyright (c) 2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.java;

import java.util.Iterator;
import java.util.Map;

import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.module.SimpleModule;

/**
 * A {@link QuantileDMatrix} whose native object is created by
 * {@code XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback} from an iterator of
 * {@link ColumnBatch}.
 *
 * <p>{@code on_host} is always set to {@code true} as only GPU is supported at the moment.
 * {@code cache_prefix} is still sent in the configuration but is not used yet since
 * {@code on_host=true}.
 */
public class ExtMemQuantileDMatrix extends QuantileDMatrix {
  /**
   * Creates the matrix from a batch iterator with every tuning option spelled out.
   *
   * @param iter               iterator handed to the native callback API
   * @param missing            value written to the {@code missing} configuration key
   * @param maxBin             value written to the {@code max_bin} configuration key
   * @param ref                optional reference matrix; when non-null its handle is passed
   *                           to the native call, otherwise {@code null} is passed
   * @param nthread            value written to the {@code nthread} configuration key
   * @param maxNumDevicePages  written to {@code max_num_device_pages} only when positive
   * @param maxQuantileBatches written to {@code max_quantile_batches} only when positive
   * @param minCachePageBytes  written to {@code min_cache_page_bytes} only when positive
   * @throws XGBoostError propagated from {@code XGBoostJNI.checkCall} on the native call
   */
  public ExtMemQuantileDMatrix(Iterator<ColumnBatch> iter,
                               float missing,
                               int maxBin,
                               DMatrix ref,
                               int nthread,
                               int maxNumDevicePages,
                               int maxQuantileBatches,
                               int minCachePageBytes) throws XGBoostError {
    long[] out = new long[1];
    // The native side receives either no reference (null) or a one-element handle array.
    long[] refHandle = (ref == null) ? null : new long[] {ref.getHandle()};
    String conf = getConfig(missing, maxBin, nthread, maxNumDevicePages,
        maxQuantileBatches, minCachePageBytes);
    XGBoostJNI.checkCall(XGBoostJNI.XGExtMemQuantileDMatrixCreateFromCallback(
        iter, refHandle, conf, out));
    handle = out[0];
  }

  /**
   * Creates the matrix with {@code nthread = 0} and {@code -1} for the three optional
   * limits, so none of the optional configuration keys is emitted.
   *
   * @param iter    iterator handed to the native callback API
   * @param missing value written to the {@code missing} configuration key
   * @param maxBin  value written to the {@code max_bin} configuration key
   * @param ref     optional reference matrix, may be {@code null}
   * @throws XGBoostError propagated from {@code XGBoostJNI.checkCall} on the native call
   */
  public ExtMemQuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin,
      DMatrix ref) throws XGBoostError {
    this(iter, missing, maxBin, ref, 0, -1, -1, -1);
  }

  /**
   * Creates the matrix without a reference matrix and with the same defaults as the
   * four-argument constructor.
   *
   * @param iter    iterator handed to the native callback API
   * @param missing value written to the {@code missing} configuration key
   * @param maxBin  value written to the {@code max_bin} configuration key
   * @throws XGBoostError propagated from {@code XGBoostJNI.checkCall} on the native call
   */
  public ExtMemQuantileDMatrix(
      Iterator<ColumnBatch> iter,
      float missing,
      int maxBin) throws XGBoostError {
    this(iter, missing, maxBin, null);
  }

  /**
   * Serializes the creation parameters into the JSON string passed to the native call.
   *
   * <p>Static because it reads no instance state; it is invoked while the object is
   * still being constructed.
   *
   * @return the JSON configuration string
   * @throws RuntimeException if Jackson fails to serialize the configuration
   */
  private static String getConfig(float missing, int maxBin, int nthread,
                                  int maxNumDevicePages, int maxQuantileBatches,
                                  int minCachePageBytes) {
    Map<String, Object> conf = new java.util.HashMap<>();
    conf.put("missing", missing);
    conf.put("max_bin", maxBin);
    conf.put("nthread", nthread);

    // Non-positive values leave the key out of the configuration entirely.
    if (maxNumDevicePages > 0) {
      conf.put("max_num_device_pages", maxNumDevicePages);
    }
    if (maxQuantileBatches > 0) {
      conf.put("max_quantile_batches", maxQuantileBatches);
    }
    if (minCachePageBytes > 0) {
      conf.put("min_cache_page_bytes", minCachePageBytes);
    }

    conf.put("on_host", true);
    conf.put("cache_prefix", ".");

    // Handle NaN values. Jackson by default serializes NaN values into strings.
    SimpleModule module = new SimpleModule();
    module.addSerializer(Double.class, new F64NaNSerializer());
    module.addSerializer(Float.class, new F32NaNSerializer());
    ObjectMapper mapper = new ObjectMapper();
    mapper.registerModule(module);

    try {
      return mapper.writeValueAsString(conf);
    } catch (JsonProcessingException e) {
      throw new RuntimeException("Failed to serialize configuration", e);
    }
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,11 @@ public void serialize(Float value, JsonGenerator gen,
* QuantileDMatrix will only be used to train
*/
public class QuantileDMatrix extends DMatrix {
/**
 * No-argument constructor for subclasses, used by the external-memory version of the
 * QuantileDMatrix. It only forwards {@code 0} to the superclass constructor
 * (presumably a placeholder native handle that the subclass replaces afterwards;
 * TODO confirm against DMatrix).
 */
protected QuantileDMatrix() {
super(0);
}

/**
* Create QuantileDMatrix from iterator based on the cuda array interface
*
Expand Down Expand Up @@ -158,8 +163,7 @@ private String getConfig(float missing, int maxBin, int nthread) {
mapper.registerModule(module);

try {
String config = mapper.writeValueAsString(conf);
return config;
return mapper.writeValueAsString(conf);
} catch (JsonProcessingException e) {
throw new RuntimeException("Failed to serialize configuration", e);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
Copyright (c) 2025 by Contributors

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
*/
package ml.dmlc.xgboost4j.scala

import scala.collection.JavaConverters._

import ml.dmlc.xgboost4j.java.{ColumnBatch, ExtMemQuantileDMatrix => jExtMemQuantileDMatrix}

/**
 * Scala wrapper around the Java `ExtMemQuantileDMatrix`.
 *
 * Each auxiliary constructor builds the Java object and passes it to the primary
 * constructor.
 */
class ExtMemQuantileDMatrix private[scala](
    private[scala] override val jDMatrix: jExtMemQuantileDMatrix)
  extends QuantileDMatrix(jDMatrix) {

  /**
   * Create the matrix with every option forwarded to the Java constructor.
   *
   * @param iter               batches of ColumnBatch, converted with `asJava`
   * @param missing            the missing value
   * @param maxBin             the max bin
   * @param ref                optional reference QuantileDMatrix; `None` is forwarded as null
   * @param nthread            forwarded unchanged to the Java constructor
   * @param maxNumDevicePages  forwarded unchanged to the Java constructor
   * @param maxQuantileBatches forwarded unchanged to the Java constructor
   * @param minCachePageBytes  forwarded unchanged to the Java constructor
   */
  def this(iter: Iterator[ColumnBatch],
           missing: Float,
           maxBin: Int,
           ref: Option[QuantileDMatrix],
           nthread: Int,
           maxNumDevicePages: Int,
           maxQuantileBatches: Int,
           minCachePageBytes: Int) =
    this(new jExtMemQuantileDMatrix(
      iter.asJava,
      missing,
      maxBin,
      ref.map(_.jDMatrix).orNull,
      nthread,
      maxNumDevicePages,
      maxQuantileBatches,
      minCachePageBytes))

  /**
   * Create the matrix through the three-argument Java constructor (no reference matrix).
   *
   * @param iter    batches of ColumnBatch, converted with `asJava`
   * @param missing the missing value
   * @param maxBin  the max bin
   */
  def this(iter: Iterator[ColumnBatch], missing: Float, maxBin: Int) =
    this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin))

  /**
   * Create the matrix through the four-argument Java constructor, passing the Java
   * object wrapped by `ref` as the reference.
   *
   * @param iter    batches of ColumnBatch, converted with `asJava`
   * @param ref     reference ExtMemQuantileDMatrix; it is dereferenced directly
   * @param missing the missing value
   * @param maxBin  the max bin
   */
  def this(iter: Iterator[ColumnBatch],
           ref: ExtMemQuantileDMatrix,
           missing: Float,
           maxBin: Int) =
    this(new jExtMemQuantileDMatrix(iter.asJava, missing, maxBin, ref.jDMatrix))
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,22 +39,22 @@ class QuantileDMatrix private[scala](
/**
* Create QuantileDMatrix from iterator based on the array interface
*
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param refDMatrix The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @param iter the XGBoost ColumnBatch batch to provide the corresponding array interface
* @param ref The reference QuantileDMatrix that provides quantile information, needed
* when creating validation/test dataset with QuantileDMatrix. Supplying the
* training DMatrix as a reference means that the same quantisation applied
* to the training data is applied to the validation/test data
* @param missing the missing value
* @param maxBin the max bin
* @param nthread the parallelism
* @throws XGBoostError
*/
def this(iter: Iterator[ColumnBatch],
ref: QuantileDMatrix,
ref: Option[QuantileDMatrix],
missing: Float,
maxBin: Int,
nthread: Int) {
this(new JQuantileDMatrix(iter.asJava, ref.jDMatrix, missing, maxBin, nthread))
this(new JQuantileDMatrix(iter.asJava, ref.map(_.jDMatrix).orNull, missing, maxBin, nthread))
}

/**
Expand Down
Loading
Loading