Skip to content

Commit

Permalink
feat: Added metal support fully
Browse files Browse the repository at this point in the history
  • Loading branch information
uttarayan21 committed Dec 23, 2024
1 parent 857fbe0 commit 88f8523
Showing 1 changed file with 37 additions and 3 deletions.
40 changes: 37 additions & 3 deletions mnn-sys/build.rs
Original file line number Diff line number Diff line change
Expand Up @@ -708,7 +708,8 @@ pub fn mnn_cpp_build(vendor: impl AsRef<Path>) -> Result<()> {
}

#[cfg(feature = "metal")]
metal(&vendor);
// metal(&vendor, &includes);
let build = metal_builder(build, vendor).change_context(Error)?;

build
.try_compile("mnn")
Expand Down Expand Up @@ -983,7 +984,38 @@ impl<P: AsRef<Path>> HasExtension for P {
}
}

pub fn metal(source: impl AsRef<Path>) {
pub fn metal<P: AsRef<Path>>(source: impl AsRef<Path>, includes: impl AsRef<[P]>) {
let metal_source_dir = source.as_ref().join("source/backend/metal");
let metal_files = ignore::Walk::new(&metal_source_dir)
.flatten()
.filter(|e| e.path().has_extension(["cpp"]))
.map(ignore::DirEntry::into_path);
// .chain(core::iter::once(
// metal_source_dir.join("MetalOPRegister.mm"),
// ));

// if cfg!(feature = "support_render") {
// let render_dir = current_list_dir.join("render");
// for entry in fs::read_dir(render_dir).unwrap() {
// let path = entry.unwrap().path();
// if let Some(ext) = path.extension() {
// if ext == "mm" || ext == "hpp" || ext == "cpp" {
// metal_files.push(path);
// }
// }
// }
// }

cc::Build::new()
.files(metal_files)
.include(source.as_ref().join("source"))
.includes(includes.as_ref())
.flag("-fobjc-arc")
.flag("-DMNN_METAL_ENABLED=1")
.compile("MNNMetal");
}

pub fn metal_builder(mut build: cc::Build, source: impl AsRef<Path>) -> Result<cc::Build> {
let metal_source_dir = source.as_ref().join("source/backend/metal");
let metal_files = ignore::Walk::new(&metal_source_dir)
.flatten()
Expand All @@ -1005,10 +1037,12 @@ pub fn metal(source: impl AsRef<Path>) {
// }
// }

cc::Build::new()
build
.clone()
.files(metal_files)
.include(source.as_ref().join("source"))
.flag("-fobjc-arc")
.flag("-DMNN_METAL_ENABLED=1")
.compile("MNNMetal");
Ok(build)
}

0 comments on commit 88f8523

Please sign in to comment.