Skip to content

Commit d8fb848

Browse files
feat!: Make ug dependency optional (#3268)
* feat!: Make `ug` dep optional
* fix(example/mnist-training): Run all epochs
* doc(`candle-ug`): Crate documentation
* fix: feature-gate the `ComputePipeline` import

---------

Co-authored-by: ivarflakstad <[email protected]>
1 parent 5de3d0f commit d8fb848

File tree

11 files changed

+69
-29
lines changed

11 files changed

+69
-29
lines changed

Cargo.toml

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ members = [
66
"candle-nn",
77
"candle-pyo3",
88
"candle-transformers",
9+
"candle-ug",
910
"candle-wasm-examples/*",
1011
"candle-wasm-tests",
1112
"tensor-tools",
@@ -43,6 +44,7 @@ candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.2-alpha
4344
candle-nn = { path = "./candle-nn", version = "0.9.2-alpha.2" }
4445
candle-onnx = { path = "./candle-onnx", version = "0.9.2-alpha.2" }
4546
candle-transformers = { path = "./candle-transformers", version = "0.9.2-alpha.2" }
47+
candle-ug = { path = "./candle-ug", version = "0.9.2-alpha.2" }
4648
clap = { version = "4.2.4", features = ["derive"] }
4749
criterion = { version = "0.7.0", default-features = false }
4850
cudarc = { version = "0.18.1", features = [
@@ -65,10 +67,7 @@ half = { version = "2.5.0", features = [
6567
"use-intrinsics",
6668
"rand_distr",
6769
] }
68-
float8 = { version = "0.5.0", features = [
69-
"num-traits",
70-
"rand_distr",
71-
] }
70+
float8 = { version = "0.5.0", features = ["num-traits", "rand_distr"] }
7271
hound = "3.5.1"
7372
image = { version = "0.25.2", default-features = false, features = [
7473
"jpeg",

candle-core/Cargo.toml

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,9 +35,7 @@ yoke = { workspace = true }
3535
zip = { workspace = true }
3636

3737
[target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))'.dependencies]
38-
ug = { workspace = true }
39-
ug-cuda = { workspace = true, optional = true }
40-
ug-metal = { workspace = true, optional = true }
38+
candle-ug = { workspace = true, optional = true }
4139

4240
[dev-dependencies]
4341
anyhow = { workspace = true }
@@ -46,7 +44,7 @@ criterion = { workspace = true }
4644

4745
[features]
4846
default = []
49-
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
47+
cuda = ["cudarc", "dep:candle-kernels", "candle-ug?/cuda", "float8/cuda"]
5048
cudnn = ["cuda", "cudarc/cudnn"]
5149
nccl = ["cuda", "cudarc/nccl"]
5250
mkl = ["dep:libc", "dep:intel-mkl-src"]
@@ -55,8 +53,9 @@ metal = [
5553
"dep:objc2-metal",
5654
"dep:objc2-foundation",
5755
"dep:candle-metal-kernels",
58-
"dep:ug-metal",
56+
"candle-ug?/metal",
5957
]
58+
ug = ["dep:candle-ug"]
6059

6160
[[bench]]
6261
name = "bench_main"

candle-core/src/cuda_backend/device.rs

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -173,14 +173,14 @@ impl CudaDevice {
173173
self.context.is_event_tracking()
174174
}
175175

176-
#[cfg(not(target_arch = "wasm32"))]
176+
#[cfg(all(feature = "ug", not(target_arch = "wasm32")))]
177177
pub fn compile(
178178
&self,
179179
func_name: &'static str,
180-
kernel: ug::lang::ssa::Kernel,
180+
kernel: candle_ug::lang::ssa::Kernel,
181181
) -> Result<CudaFunc> {
182182
let mut buf = vec![];
183-
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
183+
candle_ug::cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
184184
let cuda_code = String::from_utf8(buf)?;
185185
let opts = cudarc::nvrtc::CompileOptions {
186186
use_fast_math: Some(true),

candle-core/src/custom_op.rs

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,7 @@ impl Tensor {
376376
}
377377
}
378378

379+
#[cfg(feature = "ug")]
379380
pub struct UgIOp1 {
380381
name: &'static str,
381382
#[cfg(feature = "cuda")]
@@ -384,12 +385,13 @@ pub struct UgIOp1 {
384385
func: candle_metal_kernels::metal::ComputePipeline,
385386
}
386387

388+
#[cfg(feature = "ug")]
387389
impl UgIOp1 {
388390
#[allow(unused)]
389391
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
390392
pub fn new(
391393
name: &'static str,
392-
kernel: ug::lang::ssa::Kernel,
394+
kernel: candle_ug::lang::ssa::Kernel,
393395
device: &crate::Device,
394396
) -> Result<Self> {
395397
#[cfg(feature = "cuda")]
@@ -414,6 +416,7 @@ impl UgIOp1 {
414416
}
415417
}
416418

419+
#[cfg(feature = "ug")]
417420
impl InplaceOp1 for UgIOp1 {
418421
fn name(&self) -> &'static str {
419422
self.name

candle-core/src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -174,9 +174,9 @@ pub enum Error {
174174
#[error("Metal error {0}")]
175175
Metal(#[from] MetalError),
176176

177-
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
177+
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios"), feature = "ug"))]
178178
#[error(transparent)]
179-
Ug(#[from] ug::Error),
179+
Ug(#[from] candle_ug::Error),
180180

181181
#[error(transparent)]
182182
TryFromIntError(#[from] core::num::TryFromIntError),

candle-core/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -92,7 +92,9 @@ mod variable;
9292
pub use cuda_backend::cudnn;
9393

9494
pub use cpu_backend::{CpuStorage, CpuStorageRef};
95-
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
95+
#[cfg(feature = "ug")]
96+
pub use custom_op::UgIOp1;
97+
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
9698
pub use device::{Device, DeviceLocation, NdArray};
9799
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
98100
pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0};

candle-core/src/metal_backend/device.rs

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,17 @@
11
use crate::{DType, Result};
2+
3+
#[cfg(feature = "ug")]
4+
use candle_metal_kernels::metal::ComputePipeline;
25
use candle_metal_kernels::{
36
metal::{
4-
BlitCommandEncoder, Buffer, BufferMap, Commands, ComputeCommandEncoder, ComputePipeline,
5-
Device, MTLResourceOptions,
7+
BlitCommandEncoder, Buffer, BufferMap, Commands, ComputeCommandEncoder, Device,
8+
MTLResourceOptions,
69
},
710
Kernels,
811
};
912
use objc2_foundation::NSURL;
1013
use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager};
14+
1115
use std::path::Path;
1216
use std::sync::{Arc, Mutex, RwLock};
1317

@@ -88,14 +92,14 @@ impl std::ops::Deref for MetalDevice {
8892
}
8993

9094
impl MetalDevice {
91-
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
95+
#[cfg(all(feature = "ug", not(target_arch = "wasm32"), not(target_os = "ios")))]
9296
pub fn compile(
9397
&self,
9498
func_name: &'static str,
95-
kernel: ug::lang::ssa::Kernel,
99+
kernel: candle_ug::lang::ssa::Kernel,
96100
) -> Result<ComputePipeline> {
97101
let mut buf = vec![];
98-
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
102+
candle_ug::metal::code_gen::gen(&mut buf, func_name, &kernel)?;
99103
let metal_code = String::from_utf8(buf)?;
100104
let lib = self
101105
.device

candle-core/tests/custom_op_tests.rs

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -145,20 +145,20 @@ fn inplace_op1() -> Result<()> {
145145
Ok(())
146146
}
147147

148-
#[cfg(any(feature = "cuda", feature = "metal"))]
148+
#[cfg(all(feature = "ug", any(feature = "cuda", feature = "metal")))]
149149
#[allow(clippy::approx_constant)]
150150
#[test]
151151
fn ug_op() -> Result<()> {
152152
let kernel = {
153-
use ug::lang::op;
153+
use candle_ug::lang::op;
154154

155-
let layout = ug::Layout::from_shape(&[12]);
156-
let ptr = op::Arg::ptr(ug::DType::F32);
157-
let src = op::load(ptr.id(), layout.clone(), ug::DType::F32)?;
155+
let layout = candle_ug::Layout::from_shape(&[12]);
156+
let ptr = op::Arg::ptr(candle_ug::DType::F32);
157+
let src = op::load(ptr.id(), layout.clone(), candle_ug::DType::F32)?;
158158
let src = op::unary(op::UnaryOp::Exp, src)?;
159159
let st = op::store(ptr.id(), layout, src)?;
160160
let kernel = op::Kernel::new("exp".to_string(), vec![ptr], vec![st]);
161-
let opts: ug::lower_op::Opts = Default::default();
161+
let opts: candle_ug::lower_op::Opts = Default::default();
162162
kernel.lower(&opts)?
163163
};
164164
let device = if candle_core::utils::cuda_is_available() {

candle-examples/examples/mnist-training/main.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -137,7 +137,7 @@ fn training_loop_cnn(
137137
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
138138
let n_batches = train_images.dim(0)? / BSIZE;
139139
let mut batch_idxs = (0..n_batches).collect::<Vec<usize>>();
140-
for epoch in 1..args.epochs {
140+
for epoch in 1..=args.epochs {
141141
let mut sum_loss = 0f32;
142142
batch_idxs.shuffle(&mut rng());
143143
for batch_idx in batch_idxs.iter() {
@@ -194,7 +194,7 @@ fn training_loop<M: Model>(
194194
let mut sgd = candle_nn::SGD::new(varmap.all_vars(), args.learning_rate)?;
195195
let test_images = m.test_images.to_device(&dev)?;
196196
let test_labels = m.test_labels.to_dtype(DType::U32)?.to_device(&dev)?;
197-
for epoch in 1..args.epochs {
197+
for epoch in 1..=args.epochs {
198198
let logits = model.forward(&train_images)?;
199199
let log_sm = ops::log_softmax(&logits, D::Minus1)?;
200200
let loss = loss::nll(&log_sm, &train_labels)?;

candle-ug/Cargo.toml

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
[package]
2+
name = "candle-ug"
3+
version.workspace = true
4+
edition.workspace = true
5+
description.workspace = true
6+
repository.workspace = true
7+
keywords.workspace = true
8+
categories.workspace = true
9+
license.workspace = true
10+
11+
[dependencies]
12+
ug = { workspace = true }
13+
ug-cuda = { workspace = true, optional = true }
14+
ug-metal = { workspace = true, optional = true }
15+
16+
[features]
17+
default = []
18+
cuda = ["dep:ug-cuda"]
19+
metal = ["dep:ug-metal"]

0 commit comments

Comments
 (0)