Skip to content

Commit 5293865

Browse files
authored
Merge branch 'main' into mf/cpu-varlen-flash-attention
2 parents 962ba6a + d8fb848 commit 5293865

File tree

19 files changed, with 249 additions and 97 deletions.

Cargo.toml

Lines changed: 3 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -6,6 +6,7 @@ members = [
66
"candle-nn",
77
"candle-pyo3",
88
"candle-transformers",
9+
"candle-ug",
910
"candle-wasm-examples/*",
1011
"candle-wasm-tests",
1112
"tensor-tools",
@@ -43,6 +44,7 @@ candle-metal-kernels = { path = "./candle-metal-kernels", version = "0.9.2-alpha
4344
candle-nn = { path = "./candle-nn", version = "0.9.2-alpha.2" }
4445
candle-onnx = { path = "./candle-onnx", version = "0.9.2-alpha.2" }
4546
candle-transformers = { path = "./candle-transformers", version = "0.9.2-alpha.2" }
47+
candle-ug = { path = "./candle-ug", version = "0.9.2-alpha.2" }
4648
clap = { version = "4.2.4", features = ["derive"] }
4749
criterion = { version = "0.7.0", default-features = false }
4850
cudarc = { version = "0.18.1", features = [
@@ -65,10 +67,7 @@ half = { version = "2.5.0", features = [
6567
"use-intrinsics",
6668
"rand_distr",
6769
] }
68-
float8 = { version = "0.5.0", features = [
69-
"num-traits",
70-
"rand_distr",
71-
] }
70+
float8 = { version = "0.5.0", features = ["num-traits", "rand_distr"] }
7271
hound = "3.5.1"
7372
image = { version = "0.25.2", default-features = false, features = [
7473
"jpeg",

candle-core/Cargo.toml

Lines changed: 4 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -35,9 +35,7 @@ yoke = { workspace = true }
3535
zip = { workspace = true }
3636

3737
[target.'cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))'.dependencies]
38-
ug = { workspace = true }
39-
ug-cuda = { workspace = true, optional = true }
40-
ug-metal = { workspace = true, optional = true }
38+
candle-ug = { workspace = true, optional = true }
4139

4240
[dev-dependencies]
4341
anyhow = { workspace = true }
@@ -46,7 +44,7 @@ criterion = { workspace = true }
4644

4745
[features]
4846
default = []
49-
cuda = ["cudarc", "dep:candle-kernels", "dep:ug-cuda", "float8/cuda"]
47+
cuda = ["cudarc", "dep:candle-kernels", "candle-ug?/cuda", "float8/cuda"]
5048
cudnn = ["cuda", "cudarc/cudnn"]
5149
nccl = ["cuda", "cudarc/nccl"]
5250
mkl = ["dep:libc", "dep:intel-mkl-src"]
@@ -55,8 +53,9 @@ metal = [
5553
"dep:objc2-metal",
5654
"dep:objc2-foundation",
5755
"dep:candle-metal-kernels",
58-
"dep:ug-metal",
56+
"candle-ug?/metal",
5957
]
58+
ug = ["dep:candle-ug"]
6059

6160
[[bench]]
6261
name = "bench_main"

candle-core/src/cuda_backend/device.rs

Lines changed: 3 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -173,14 +173,14 @@ impl CudaDevice {
173173
self.context.is_event_tracking()
174174
}
175175

176-
#[cfg(not(target_arch = "wasm32"))]
176+
#[cfg(all(feature = "ug", not(target_arch = "wasm32")))]
177177
pub fn compile(
178178
&self,
179179
func_name: &'static str,
180-
kernel: ug::lang::ssa::Kernel,
180+
kernel: candle_ug::lang::ssa::Kernel,
181181
) -> Result<CudaFunc> {
182182
let mut buf = vec![];
183-
ug_cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
183+
candle_ug::cuda::code_gen::gen(&mut buf, func_name, &kernel)?;
184184
let cuda_code = String::from_utf8(buf)?;
185185
let opts = cudarc::nvrtc::CompileOptions {
186186
use_fast_math: Some(true),

candle-core/src/cuda_backend/mod.rs

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1012,7 +1012,7 @@ impl Map1 for UpsampleBilinear2D {
10121012

10131013
// SAFETY: Set later by running the kernel.
10141014
let out = unsafe { dev.alloc::<T>(dst_el)? };
1015-
let ds = dev.memcpy_stod(&ds)?;
1015+
let ds = dev.clone_htod(&ds)?;
10161016

10171017
let mut builder = func.builder();
10181018
barg!(builder, out_w);

candle-core/src/custom_op.rs

Lines changed: 4 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -376,6 +376,7 @@ impl Tensor {
376376
}
377377
}
378378

379+
#[cfg(feature = "ug")]
379380
pub struct UgIOp1 {
380381
name: &'static str,
381382
#[cfg(feature = "cuda")]
@@ -384,12 +385,13 @@ pub struct UgIOp1 {
384385
func: candle_metal_kernels::metal::ComputePipeline,
385386
}
386387

388+
#[cfg(feature = "ug")]
387389
impl UgIOp1 {
388390
#[allow(unused)]
389391
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
390392
pub fn new(
391393
name: &'static str,
392-
kernel: ug::lang::ssa::Kernel,
394+
kernel: candle_ug::lang::ssa::Kernel,
393395
device: &crate::Device,
394396
) -> Result<Self> {
395397
#[cfg(feature = "cuda")]
@@ -414,6 +416,7 @@ impl UgIOp1 {
414416
}
415417
}
416418

419+
#[cfg(feature = "ug")]
417420
impl InplaceOp1 for UgIOp1 {
418421
fn name(&self) -> &'static str {
419422
self.name

candle-core/src/error.rs

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -174,9 +174,9 @@ pub enum Error {
174174
#[error("Metal error {0}")]
175175
Metal(#[from] MetalError),
176176

177-
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
177+
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios"), feature = "ug"))]
178178
#[error(transparent)]
179-
Ug(#[from] ug::Error),
179+
Ug(#[from] candle_ug::Error),
180180

181181
#[error(transparent)]
182182
TryFromIntError(#[from] core::num::TryFromIntError),

candle-core/src/lib.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -92,7 +92,9 @@ mod variable;
9292
pub use cuda_backend::cudnn;
9393

9494
pub use cpu_backend::{CpuStorage, CpuStorageRef};
95-
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3, UgIOp1};
95+
#[cfg(feature = "ug")]
96+
pub use custom_op::UgIOp1;
97+
pub use custom_op::{CustomOp1, CustomOp2, CustomOp3, InplaceOp1, InplaceOp2, InplaceOp3};
9698
pub use device::{Device, DeviceLocation, NdArray};
9799
pub use dtype::{DType, DTypeParseError, FloatDType, IntDType, WithDType};
98100
pub use dummy_dtype::{F4, F6E2M3, F6E3M2, F8E8M0};

candle-core/src/metal_backend/device.rs

Lines changed: 9 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -1,13 +1,17 @@
11
use crate::{DType, Result};
2+
3+
#[cfg(feature = "ug")]
4+
use candle_metal_kernels::metal::ComputePipeline;
25
use candle_metal_kernels::{
36
metal::{
4-
BlitCommandEncoder, Buffer, BufferMap, Commands, ComputeCommandEncoder, ComputePipeline,
5-
Device, MTLResourceOptions,
7+
BlitCommandEncoder, Buffer, BufferMap, Commands, ComputeCommandEncoder, Device,
8+
MTLResourceOptions,
69
},
710
Kernels,
811
};
912
use objc2_foundation::NSURL;
1013
use objc2_metal::{MTLCaptureDescriptor, MTLCaptureDestination, MTLCaptureManager};
14+
1115
use std::path::Path;
1216
use std::sync::{Arc, Mutex, RwLock};
1317

@@ -88,14 +92,14 @@ impl std::ops::Deref for MetalDevice {
8892
}
8993

9094
impl MetalDevice {
91-
#[cfg(all(not(target_arch = "wasm32"), not(target_os = "ios")))]
95+
#[cfg(all(feature = "ug", not(target_arch = "wasm32"), not(target_os = "ios")))]
9296
pub fn compile(
9397
&self,
9498
func_name: &'static str,
95-
kernel: ug::lang::ssa::Kernel,
99+
kernel: candle_ug::lang::ssa::Kernel,
96100
) -> Result<ComputePipeline> {
97101
let mut buf = vec![];
98-
ug_metal::code_gen::gen(&mut buf, func_name, &kernel)?;
102+
candle_ug::metal::code_gen::gen(&mut buf, func_name, &kernel)?;
99103
let metal_code = String::from_utf8(buf)?;
100104
let lib = self
101105
.device

candle-core/src/op.rs

Lines changed: 25 additions & 25 deletions
Original file line number | Diff line number | Diff line change
@@ -247,31 +247,31 @@ pub trait BinaryOpT {
247247
fn i64_vec(_xs1: &[i64], _xs2: &[i64], _ys: &mut [i64]) {}
248248
}
249249

250-
pub(crate) struct Add;
251-
pub(crate) struct Div;
252-
pub(crate) struct Mul;
253-
pub(crate) struct Sub;
254-
pub(crate) struct Maximum;
255-
pub(crate) struct Minimum;
256-
pub(crate) struct Exp;
257-
pub(crate) struct Log;
258-
pub(crate) struct Sin;
259-
pub(crate) struct Cos;
260-
pub(crate) struct Abs;
261-
pub(crate) struct Neg;
262-
pub(crate) struct Recip;
263-
pub(crate) struct Sqr;
264-
pub(crate) struct Sqrt;
265-
pub(crate) struct Gelu;
266-
pub(crate) struct GeluErf;
267-
pub(crate) struct Erf;
268-
pub(crate) struct Relu;
269-
pub(crate) struct Silu;
270-
pub(crate) struct Tanh;
271-
pub(crate) struct Floor;
272-
pub(crate) struct Ceil;
273-
pub(crate) struct Round;
274-
pub(crate) struct Sign;
250+
pub struct Add;
251+
pub struct Div;
252+
pub struct Mul;
253+
pub struct Sub;
254+
pub struct Maximum;
255+
pub struct Minimum;
256+
pub struct Exp;
257+
pub struct Log;
258+
pub struct Sin;
259+
pub struct Cos;
260+
pub struct Abs;
261+
pub struct Neg;
262+
pub struct Recip;
263+
pub struct Sqr;
264+
pub struct Sqrt;
265+
pub struct Gelu;
266+
pub struct GeluErf;
267+
pub struct Erf;
268+
pub struct Relu;
269+
pub struct Silu;
270+
pub struct Tanh;
271+
pub struct Floor;
272+
pub struct Ceil;
273+
pub struct Round;
274+
pub struct Sign;
275275

276276
macro_rules! bin_op {
277277
($op:ident, $name: literal, $e: expr, $f32_vec: ident, $f64_vec: ident) => {

candle-core/src/sort.rs

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -85,9 +85,11 @@ mod cuda {
8585
let ncols = self.last_dim;
8686
let nrows = elem_count / ncols;
8787
let ncols_pad = next_power_of_2(ncols);
88+
// Limit block dim to 1024 threads, which is the maximum on modern CUDA gpus.
89+
let block_dim = ncols_pad.min(1024);
8890
let cfg = LaunchConfig {
8991
grid_dim: (nrows as u32, 1, 1),
90-
block_dim: (ncols_pad as u32, 1, 1),
92+
block_dim: (block_dim as u32, 1, 1),
9193
shared_mem_bytes: (ncols_pad * std::mem::size_of::<u32>()) as u32,
9294
};
9395
let stream = dev.cuda_stream();

0 commit comments

Comments
 (0)