From 3f337bf4e9c604b8c00fea43d3b0a5baf08e4b28 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 6 May 2024 13:40:41 +0200 Subject: [PATCH 1/5] Work on GPU support --- Cargo.lock | 1427 ++++++++++++++++++--------------- Cargo.toml | 10 +- python/nutpie/compile_pymc.py | 1 + src/iree.rs | 723 +++++++++++++++++ src/lib.rs | 6 + src/ort.rs | 215 +++++ src/torch.rs | 147 ++++ src/wrapper.rs | 42 + 8 files changed, 1923 insertions(+), 648 deletions(-) create mode 100644 src/iree.rs create mode 100644 src/ort.rs create mode 100644 src/torch.rs diff --git a/Cargo.lock b/Cargo.lock index 6dd8489..28a79fe 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,23 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "adler" +version = "1.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe" + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + [[package]] name = "ahash" version = "0.8.11" @@ -9,7 +26,6 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", - "const-random", "getrandom", "once_cell", "version_check", @@ -25,21 +41,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - -[[package]] -name = "android_system_properties" -version = "0.1.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" -dependencies = [ - "libc", -] - [[package]] name = "anes" version = "0.1.6" @@ -54,179 +55,29 @@ checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" [[package]] name = "anyhow" -version = "1.0.86" +version = "1.0.82" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" +checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" [[package]] -name = "arrow" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" -dependencies = [ - "arrow-arith", - "arrow-array", - "arrow-buffer", - "arrow-cast", - "arrow-data", - "arrow-ord", - "arrow-row", - "arrow-schema", - "arrow-select", - "arrow-string", -] - -[[package]] -name = "arrow-arith" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "num", -] - -[[package]] -name = "arrow-array" -version = "52.1.0" +name = "arrow2" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" +checksum = "963fef509b757bcbbf9e5ffa23bcb345614d99f4f6f531f97417b27b8604d389" dependencies = [ "ahash", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "chrono", - "half", - "hashbrown", - "num", -] - -[[package]] -name = "arrow-buffer" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" -dependencies = [ - "bytes", - "half", - "num", -] - -[[package]] -name = "arrow-cast" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "atoi", - "base64", + "bytemuck", "chrono", - "half", - "lexical-core", - "num", - "ryu", -] - -[[package]] -name = "arrow-data" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" -dependencies = [ - "arrow-buffer", - "arrow-schema", - "half", - "num", -] - -[[package]] -name = "arrow-ord" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "half", - "num", -] - -[[package]] -name = "arrow-row" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "half", + "dyn-clone", + "either", + "ethnum", + "foreign_vec", + "getrandom", + "hash_hasher", "hashbrown", -] - -[[package]] -name = "arrow-schema" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" -dependencies = [ - "bitflags 2.6.0", -] - -[[package]] -name = "arrow-select" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" -dependencies = [ - "ahash", - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "num", -] - -[[package]] -name = "arrow-string" -version = "52.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" -dependencies = [ - "arrow-array", - "arrow-buffer", - "arrow-data", - "arrow-schema", - "arrow-select", - "memchr", - "num", - "regex", - "regex-syntax", -] - -[[package]] -name = "atoi" -version = "2.0.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" -dependencies = [ "num-traits", + "rustc_version", + "simdutf8", ] [[package]] @@ -241,13 +92,19 @@ version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" +[[package]] +name = "base64ct" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8c3c1a368f70d6cf7302d78f8f7093da241fb8e8807c05cc9e51a125895a6d5b" + [[package]] name = "bindgen" version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.5.0", "cexpr", "clang-sys", "itertools 0.12.1", @@ -260,7 +117,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.72", + "syn 2.0.60", "which", ] @@ -272,9 +129,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.6.0" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" [[package]] name = "block-buffer" @@ -287,14 +144,12 @@ dependencies = [ [[package]] name = "bridgestan" -version = "2.5.0" +version = "2.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc8db213e11ba8b22c444912e269f8164d4af17292e6e78d9d6a4162225e929b" +checksum = "9f9cac326d10621223fcf840aee74a9b8087ce93921ecf1ae82ede1581d4366c" dependencies = [ "bindgen", "libloading", - "log", - "path-absolutize", "thiserror", ] @@ -306,22 +161,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.16.1" +version = "1.15.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.7.0" +version = "1.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" +checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] @@ -331,10 +186,25 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] -name = "bytes" -version = "1.6.1" +name = "bzip2" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" +checksum = "bdb116a6ef3f6c3698828873ad02c3014b3c85cadb88496095628e3ef1e347f8" +dependencies = [ + "bzip2-sys", + "libc", +] + +[[package]] +name = "bzip2-sys" +version = "0.1.11+1.0.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "736a955f3fa7875102d57c82b8cac37ec45224a07fd32d58f9f7a186b6cd4cdc" +dependencies = [ + "cc", + "libc", + "pkg-config", +] [[package]] name = "cast" @@ -344,9 +214,14 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.1.6" +version = "1.0.97" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" +checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" +dependencies = [ + "jobserver", + "libc", + "once_cell", +] [[package]] name = "cexpr" @@ -369,10 +244,7 @@ version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ - "android-tzdata", - "iana-time-zone", "num-traits", - "windows-targets", ] [[package]] @@ -402,11 +274,21 @@ dependencies = [ "half", ] +[[package]] +name = "cipher" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "773f3b9af64447d2ce9850330c473515014aa235e6a783b02db81ff39e4a3dad" +dependencies = [ + "crypto-common", + "inout", +] + [[package]] name = "clang-sys" -version = "1.8.1" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" +checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" dependencies = [ "glob", "libc", @@ -415,18 +297,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.10" +version = "4.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8f6b81fb3c84f5563d509c59b5a48d935f689e993afa90fe39047f05adef9142" +checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.10" +version = "4.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ca6706fd5224857d9ac5eb9355f6683563cc0541c7cd9d014043b57cbec78ac" +checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" dependencies = [ "anstyle", "clap_lex", @@ -434,54 +316,30 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" - -[[package]] -name = "coe-rs" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" - -[[package]] -name = "console" -version = "0.15.8" +version = "0.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" -dependencies = [ - "encode_unicode", - "lazy_static", - "libc", - "unicode-width", - "windows-sys", -] +checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" [[package]] -name = "const-random" -version = "0.1.18" +name = "cmake" +version = "0.1.50" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" dependencies = [ - "const-random-macro", + "cc", ] [[package]] -name = "const-random-macro" -version = "0.1.16" +name = "coe-rs" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" -dependencies = [ - "getrandom", - "once_cell", - "tiny-keccak", -] +checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" [[package]] -name = "core-foundation-sys" -version = "0.8.6" +name = "constant_time_eq" +version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" +checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" [[package]] name = "cpufeatures" @@ -492,6 +350,15 @@ dependencies = [ "libc", ] +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + [[package]] name = "criterion" version = "0.5.1" @@ -549,9 +416,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.20" +version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" [[package]] name = "crunchy" @@ -575,6 +442,15 @@ version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e6ca96b45ca70b8045e0462f191bd209fcb3c3bfe8dbfb1257ada54c4dd59169" +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + [[package]] name = "digest" version = "0.10.7" @@ -583,8 +459,15 @@ checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", + "subtle", ] +[[package]] +name = "dyn-clone" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" + [[package]] name = "dyn-stack" version = "0.10.0" @@ -596,16 +479,31 @@ dependencies = [ ] [[package]] -name = "either" -version = "1.13.0" +name = "eerie" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4f697edddf865085bb1238c234ff46713c2baf0a353687b1ae3d15e5eaa91c6b" +dependencies = [ + "eerie-sys", + "log", + "thiserror", +] + +[[package]] +name = "eerie-sys" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" +checksum = "470184b575c1acd2895b719e7510efd844b8fac49a00a6b671905888664fda89" +dependencies = [ + "bindgen", + "cmake", +] [[package]] -name = "encode_unicode" -version = "0.3.6" +name = "either" +version = "1.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" [[package]] name = "enum-as-inner" @@ -616,44 +514,50 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] name = "equator" -version = "0.2.2" +version = "0.1.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" +checksum = "a3b0a88aa91d0ad2b9684e4479aed31a17d3f9051bdbbc634bd2c01bc5a5eee8" dependencies = [ "equator-macro", ] [[package]] name = "equator-macro" -version = "0.2.1" +version = "0.1.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" +checksum = "60d08acb9849f7fb4401564f251be5a526829183a3645a90197dea8e786cf3ae" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] name = "errno" -version = "0.3.9" +version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" dependencies = [ "libc", "windows-sys", ] +[[package]] +name = "ethnum" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" + [[package]] name = "faer" -version = "0.19.1" +version = "0.18.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "41543c4de4bfb32efdffdd75cbcca5ef41b800e8a811ea4a41fb9393c6ef3bc0" +checksum = "e547492d9b55c4ea882584e691ed092228981e337d0c800bc721301d7e61e40a" dependencies = [ "bytemuck", "coe-rs", @@ -665,7 +569,6 @@ dependencies = [ "libm", "matrixcompare", "matrixcompare-core", - "nano-gemm", "npyz", "num-complex", "num-traits", @@ -679,9 +582,9 @@ dependencies = [ [[package]] name = "faer-entity" -version = "0.19.0" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ab968a02be27be95de0f1ad0af901b865fa0866b6a9b553a6cc9cf7f19c2ce71" +checksum = "22ea5c06233193392c614a46aa3bbe3de29c1404692c8053abd9c2765a1cd159" dependencies = [ "bytemuck", "coe-rs", @@ -692,11 +595,48 @@ dependencies = [ "reborrow", ] +[[package]] +name = "filetime" +version = "0.2.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ee447700ac8aa0b2f2bd7bc4462ad686ba06baa6727ac149a2d6277f0d240fd" +dependencies = [ + "cfg-if", + "libc", + "redox_syscall 0.4.1", + "windows-sys", +] + +[[package]] +name = "flate2" +version = "1.0.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5f54427cfd1c7829e2a139fcefea601bf088ebca651d2bf53ebc600eac295dae" +dependencies = [ + "crc32fast", + "miniz_oxide", +] + +[[package]] +name = "foreign_vec" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" + +[[package]] +name = "form_urlencoded" +version = "1.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" +dependencies = [ + "percent-encoding", +] + [[package]] name = "gemm" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e400f2ffd14e7548356236c35dc39cad6666d833a852cb8a8f3f28029359bb03" +checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" dependencies = [ "dyn-stack", "gemm-c32", @@ -714,9 +654,9 @@ dependencies = [ [[package]] name = "gemm-c32" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "10dc4a6176c8452d60eac1a155b454c91c668f794151a303bf3c75ea2874812d" +checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" dependencies = [ "dyn-stack", "gemm-common", @@ -729,9 +669,9 @@ dependencies = [ [[package]] name = "gemm-c64" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cc2032ce2c0bb150da0256338759a6fb01ca056f6dfe28c4d14af32d7f878f6f" +checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" dependencies = [ "dyn-stack", "gemm-common", @@ -744,9 +684,9 @@ dependencies = [ [[package]] name = "gemm-common" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90fd234fc525939654f47b39325fd5f55e552ceceea9135f3aa8bdba61eabef6" +checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" dependencies = [ "bytemuck", "dyn-stack", @@ -764,9 +704,9 @@ dependencies = [ [[package]] name = "gemm-f16" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3fc3652651f96a711d46b8833e1fac27a864be4bdfa81a374055f33ddd25c0c6" +checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" dependencies = [ "dyn-stack", "gemm-common", @@ -782,9 +722,9 @@ dependencies = [ [[package]] name = "gemm-f32" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "acbc51c44ae3defd207e6d9416afccb3c4af1e7cef5e4960e4c720ac4d6f998e" +checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" dependencies = [ "dyn-stack", "gemm-common", @@ -797,9 +737,9 @@ dependencies = [ [[package]] name = "gemm-f64" -version = "0.18.0" +version = "0.17.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3f37fc86e325c2415a4d0cab8324a0c5371ec06fc7d2f9cb1636fcfc9536a8d8" +checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" dependencies = [ "dyn-stack", "gemm-common", @@ -822,13 +762,15 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.15" +version = "0.2.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" dependencies = [ "cfg-if", + "js-sys", "libc", "wasi", + "wasm-bindgen", ] [[package]] @@ -849,11 +791,20 @@ dependencies = [ "num-traits", ] +[[package]] +name = "hash_hasher" +version = "2.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" + [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" +dependencies = [ + "ahash", +] [[package]] name = "heck" @@ -868,48 +819,31 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] -name = "home" -version = "0.5.9" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" -dependencies = [ - "windows-sys", -] - -[[package]] -name = "iana-time-zone" -version = "0.1.60" +name = "hmac" +version = "0.12.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +checksum = "6c49c37c09c17a53d937dfbb742eb3a961d65a994e6bcdcf37e7399d0cc8ab5e" dependencies = [ - "android_system_properties", - "core-foundation-sys", - "iana-time-zone-haiku", - "js-sys", - "wasm-bindgen", - "windows-core", + "digest", ] [[package]] -name = "iana-time-zone-haiku" -version = "0.1.2" +name = "home" +version = "0.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +checksum = "e3d1354bf6b7235cb4a0576c2619fd4ed18183f689b12b006a0ee7329eeff9a5" dependencies = [ - "cc", + "windows-sys", ] [[package]] -name = "indicatif" -version = "0.17.8" +name = "idna" +version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +checksum = "634d9b1461af396cad843f47fdba5597a4f9e6ddd4bfb6ff5d85028c25cb12f6" dependencies = [ - "console", - "instant", - "number_prefix", - "portable-atomic", - "unicode-width", + "unicode-bidi", + "unicode-normalization", ] [[package]] @@ -919,12 +853,12 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b248f5224d1d606005e02c97f5aa4e88eeb230488bcc03bc9ca4d7991399f2b5" [[package]] -name = "instant" -version = "0.1.13" +name = "inout" +version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +checksum = "a0c10553d664a4d0bcff9f4215d0aac67a639cc68ef660840afe309b807bc9f5" dependencies = [ - "cfg-if", + "generic-array", ] [[package]] @@ -956,21 +890,21 @@ dependencies = [ "either", ] -[[package]] -name = "itertools" -version = "0.13.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" -dependencies = [ - "either", -] - [[package]] name = "itoa" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" +[[package]] +name = "jobserver" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2b099aaa34a9751c5bf0878add70444e1ed2dd73f347be99003d4577277de6e" +dependencies = [ + "libc", +] + [[package]] name = "js-sys" version = "0.3.69" @@ -982,9 +916,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.5.0" +version = "1.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" [[package]] name = "lazycell" @@ -992,81 +926,17 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" -[[package]] -name = "lexical-core" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" -dependencies = [ - "lexical-parse-float", - "lexical-parse-integer", - "lexical-util", - "lexical-write-float", - "lexical-write-integer", -] - -[[package]] -name = "lexical-parse-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" -dependencies = [ - "lexical-parse-integer", - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-parse-integer" -version = "0.8.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" -dependencies = [ - "lexical-util", - "static_assertions", -] - -[[package]] -name = "lexical-util" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" -dependencies = [ - "static_assertions", -] - -[[package]] -name = "lexical-write-float" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" -dependencies = [ - "lexical-util", - "lexical-write-integer", - "static_assertions", -] - -[[package]] -name = "lexical-write-integer" -version = "0.8.5" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" -dependencies = [ - "lexical-util", - "static_assertions", -] - [[package]] name = "libc" -version = "0.2.155" +version = "0.2.154" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" +checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" [[package]] name = "libloading" -version = "0.8.5" +version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" +checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" dependencies = [ "cfg-if", "windows-targets", @@ -1080,9 +950,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "linux-raw-sys" -version = "0.4.14" +version = "0.4.13" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" [[package]] name = "lock_api" @@ -1096,9 +966,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.22" +version = "0.4.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" [[package]] name = "matrixcompare" @@ -1128,9 +998,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.4" +version = "2.7.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" [[package]] name = "memoffset" @@ -1147,6 +1017,15 @@ version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" +[[package]] +name = "miniz_oxide" +version = "0.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +dependencies = [ + "adler", +] + [[package]] name = "multiversion" version = "0.7.4" @@ -1169,76 +1048,6 @@ dependencies = [ "target-features", ] -[[package]] -name = "nano-gemm" -version = "0.1.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f563548d38f390ef9893e4883ec38c1fb312f569e98d76bededdd91a3b41a043" -dependencies = [ - "equator", - "nano-gemm-c32", - "nano-gemm-c64", - "nano-gemm-codegen", - "nano-gemm-core", - "nano-gemm-f32", - "nano-gemm-f64", - "num-complex", -] - -[[package]] -name = "nano-gemm-c32" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a40449e57a5713464c3a1208c4c3301c8d29ee1344711822cf022bc91373a91b" -dependencies = [ - "nano-gemm-codegen", - "nano-gemm-core", - "num-complex", -] - -[[package]] -name = "nano-gemm-c64" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "743a6e6211358fba85d1009616751e4107da86f4c95b24e684ce85f25c25b3bf" -dependencies = [ - "nano-gemm-codegen", - "nano-gemm-core", - "num-complex", -] - -[[package]] -name = "nano-gemm-codegen" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "963bf7c7110d55430169dc74c67096375491ed580cd2ef84842550ac72e781fa" - -[[package]] -name = "nano-gemm-core" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fe3fc4f83ae8861bad79dc3c016bd6b0220da5f9de302e07d3112d16efc24aa6" - -[[package]] -name = "nano-gemm-f32" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e3681b7ce35658f79da94b7f62c60a005e29c373c7111ed070e3bf64546a8bb" -dependencies = [ - "nano-gemm-codegen", - "nano-gemm-core", -] - -[[package]] -name = "nano-gemm-f64" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc1e619ed04d801809e1f63e61b669d380c4119e8b0cdd6ed184c6b111f046d8" -dependencies = [ - "nano-gemm-codegen", - "nano-gemm-core", -] - [[package]] name = "ndarray" version = "0.15.6" @@ -1273,69 +1082,39 @@ dependencies = [ "py_literal", ] -[[package]] -name = "num" -version = "0.4.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" -dependencies = [ - "num-bigint", - "num-complex", - "num-integer", - "num-iter", - "num-rational", - "num-traits", -] - [[package]] name = "num-bigint" -version = "0.4.6" +version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" +checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ + "autocfg", "num-integer", "num-traits", ] [[package]] name = "num-complex" -version = "0.4.6" +version = "0.4.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" dependencies = [ "bytemuck", "num-traits", - "rand", -] - -[[package]] -name = "num-integer" -version = "0.1.46" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" -dependencies = [ - "num-traits", ] [[package]] -name = "num-iter" -version = "0.1.45" +name = "num-conv" +version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" -dependencies = [ - "autocfg", - "num-integer", - "num-traits", -] +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] -name = "num-rational" -version = "0.4.2" +name = "num-integer" +version = "0.1.46" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +checksum = "7969661fd2958a5cb096e56c8e1ad0444ac2bbcd0061bd28660485a44879858f" dependencies = [ - "num-bigint", - "num-integer", "num-traits", ] @@ -1347,13 +1126,7 @@ checksum = "071dfc062690e90b734c0b2273ce72ad0ffa95f0c74596bc250dcfd960262841" dependencies = [ "autocfg", "libm", -] - -[[package]] -name = "number_prefix" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" +] [[package]] name = "numpy" @@ -1372,22 +1145,25 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.13.1" +version = "0.10.0" dependencies = [ "anyhow", - "arrow", + "arrow2", "bridgestan", "criterion", - "indicatif", - "itertools 0.13.0", + "eerie", + "itertools 0.12.1", + "ndarray", "numpy", "nuts-rs", + "ort", "pyo3", "rand", "rand_chacha", "rand_distr", "rayon", "smallvec", + "tch", "thiserror", "time-humanize", "upon", @@ -1395,14 +1171,12 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.12.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d8573e3b5c83e8ec0570ebbd75dd6fdc7dfcfa5da9b5f9d9d63fedefebbd9cf8" +version = "0.9.0" dependencies = [ "anyhow", - "arrow", + "arrow2", "faer", - "itertools 0.13.0", + "itertools 0.12.1", "multiversion", "pulp", "rand", @@ -1420,15 +1194,43 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "oorandom" -version = "11.1.4" +version = "11.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" + +[[package]] +name = "ort" +version = "2.0.0-rc.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" +dependencies = [ + "half", + "js-sys", + "libloading", + "ndarray", + "ort-sys", + "thiserror", + "tracing", + "web-sys", +] + +[[package]] +name = "ort-sys" +version = "2.0.0-rc.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" +checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" +dependencies = [ + "flate2", + "sha2", + "tar", + "ureq", +] [[package]] name = "parking_lot" -version = "0.12.3" +version = "0.12.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" +checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" dependencies = [ "lock_api", "parking_lot_core", @@ -1442,40 +1244,51 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall", + "redox_syscall 0.5.1", "smallvec", "windows-targets", ] +[[package]] +name = "password-hash" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7676374caaee8a325c9e7a2ae557f216c5563a171d6997b0ef8a65af35147700" +dependencies = [ + "base64ct", + "rand_core", + "subtle", +] + [[package]] name = "paste" -version = "1.0.15" +version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] -name = "path-absolutize" -version = "3.1.1" +name = "pbkdf2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e4af381fe79fa195b4909485d99f73a80792331df0625188e707854f0b3383f5" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" dependencies = [ - "path-dedot", + "digest", + "hmac", + "password-hash", + "sha2", ] [[package]] -name = "path-dedot" -version = "3.1.1" +name = "percent-encoding" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "07ba0ad7e047712414213ff67533e6dd477af0a4e1d14fb52343e53d30ea9397" -dependencies = [ - "once_cell", -] +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" -version = "2.7.11" +version = "2.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cd53dff83f26735fdc1ca837098ccf133605d794cdae66acfc2bfac3ec809d95" +checksum = "560131c633294438da9f7c4b08189194b20946c8274c6b9e38881a7874dc8ee8" dependencies = [ "memchr", "thiserror", @@ -1484,9 +1297,9 @@ dependencies = [ [[package]] name = "pest_derive" -version = "2.7.11" +version = "2.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a548d2beca6773b1c244554d36fcf8548a8a58e74156968211567250e48e49a" +checksum = "26293c9193fbca7b1a3bf9b79dc1e388e927e6cacaa78b4a3ab705a1d3d41459" dependencies = [ "pest", "pest_generator", @@ -1494,33 +1307,45 @@ dependencies = [ [[package]] name = "pest_generator" -version = "2.7.11" +version = "2.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3c93a82e8d145725dcbaf44e5ea887c8a869efdcc28706df2d08c69e17077183" +checksum = "3ec22af7d3fb470a85dd2ca96b7c577a1eb4ef6f1683a9fe9a8c16e136c04687" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] name = "pest_meta" -version = "2.7.11" +version = "2.7.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a941429fea7e08bedec25e4f6785b6ffaacc6b755da98df5ef3e7dcf4a124c4f" +checksum = "d7a240022f37c361ec1878d646fc5b7d7c4d28d5946e1a80ad5a7a4f4ca0bdcd" dependencies = [ "once_cell", "pest", "sha2", ] +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "pkg-config" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" + [[package]] name = "plotters" -version = "0.3.6" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" +checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" dependencies = [ "num-traits", "plotters-backend", @@ -1531,24 +1356,30 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.6" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" +checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" [[package]] name = "plotters-svg" -version = "0.3.6" +version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" +checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" dependencies = [ "plotters-backend", ] [[package]] name = "portable-atomic" -version = "1.7.0" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" + +[[package]] +name = "powerfmt" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "ppv-lite86" @@ -1558,28 +1389,28 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.20" +version = "0.2.19" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" +checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550" dependencies = [ "proc-macro2", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] name = "proc-macro2" -version = "1.0.86" +version = "1.0.81" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" dependencies = [ "unicode-ident", ] [[package]] name = "pulp" -version = "0.18.21" +version = "0.18.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ec8d02258294f59e4e223b41ad7e81c874aa6b15bc4ced9ba3965826da0eed5" +checksum = "e14989307e408d9f4245d4fda09a7b144a08114ba124e26cab60ab83dc98db10" dependencies = [ "bytemuck", "libm", @@ -1648,7 +1479,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] @@ -1661,7 +1492,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] @@ -1756,18 +1587,27 @@ checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" [[package]] name = "redox_syscall" -version = "0.5.3" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "redox_syscall" +version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" +checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.5.0", ] [[package]] name = "regex" -version = "1.10.5" +version = "1.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" +checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" dependencies = [ "aho-corasick", "memchr", @@ -1777,9 +1617,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.7" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" +checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" dependencies = [ "aho-corasick", "memchr", @@ -1788,9 +1628,24 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.4" +version = "0.8.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" + +[[package]] +name = "ring" +version = "0.17.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" +checksum = "c17fa4cb658e3583423e915b9f3acc01cceaee1860e33d59ebae66adc3a2dc0d" +dependencies = [ + "cc", + "cfg-if", + "getrandom", + "libc", + "spin", + "untrusted", + "windows-sys", +] [[package]] name = "rustc-hash" @@ -1798,24 +1653,74 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + [[package]] name = "rustix" version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.5.0", "errno", "libc", "linux-raw-sys", "windows-sys", ] +[[package]] +name = "rustls" +version = "0.22.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +dependencies = [ + "log", + "ring", + "rustls-pki-types", + "rustls-webpki", + "subtle", + "zeroize", +] + +[[package]] +name = "rustls-pki-types" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" + +[[package]] +name = "rustls-webpki" +version = "0.102.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" +dependencies = [ + "ring", + "rustls-pki-types", + "untrusted", +] + [[package]] name = "ryu" -version = "1.0.18" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + +[[package]] +name = "safetensors" +version = "0.3.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" +checksum = "d93279b86b3de76f820a8854dd06cbc33cfa57a417b19c47f6a25280112fb1df" +dependencies = [ + "serde", + "serde_json", +] [[package]] name = "same-file" @@ -1832,6 +1737,12 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" +[[package]] +name = "semver" +version = "1.0.22" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" + [[package]] name = "seq-macro" version = "0.3.5" @@ -1840,35 +1751,46 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.204" +version = "1.0.200" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" +checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.204" +version = "1.0.200" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" +checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] name = "serde_json" -version = "1.0.120" +version = "1.0.116" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" dependencies = [ "itoa", "ryu", "serde", ] +[[package]] +name = "sha1" +version = "0.10.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" +dependencies = [ + "cfg-if", + "cpufeatures", + "digest", +] + [[package]] name = "sha2" version = "0.10.8" @@ -1886,6 +1808,12 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" +[[package]] +name = "simdutf8" +version = "0.1.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" + [[package]] name = "smallvec" version = "1.13.2" @@ -1893,10 +1821,16 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] -name = "static_assertions" -version = "1.1.0" +name = "spin" +version = "0.9.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" + +[[package]] +name = "subtle" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" [[package]] name = "syn" @@ -1911,9 +1845,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.72" +version = "2.0.60" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" dependencies = [ "proc-macro2", "quote", @@ -1926,7 +1860,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.6.0", + "bitflags 2.5.0", "byteorder", "enum-as-inner", "libc", @@ -1934,6 +1868,17 @@ dependencies = [ "walkdir", ] +[[package]] +name = "tar" +version = "0.4.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +dependencies = [ + "filetime", + "libc", + "xattr", +] + [[package]] name = "target-features" version = "0.1.6" @@ -1942,44 +1887,69 @@ checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.15" +version = "0.12.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" +checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" + +[[package]] +name = "tch" +version = "0.16.0" +dependencies = [ + "half", + "lazy_static", + "libc", + "ndarray", + "rand", + "safetensors", + "thiserror", + "torch-sys", + "zip", +] [[package]] name = "thiserror" -version = "1.0.63" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" +checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.63" +version = "1.0.59" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" +checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", ] [[package]] -name = "time-humanize" -version = "0.1.3" +name = "time" +version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e32d019b4f7c100bcd5494e40a27119d45b71fba2b07a4684153129279a4647" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "num-conv", + "powerfmt", + "serde", + "time-core", +] [[package]] -name = "tiny-keccak" -version = "2.0.2" +name = "time-core" +version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" -dependencies = [ - "crunchy", -] +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-humanize" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e32d019b4f7c100bcd5494e40a27119d45b71fba2b07a4684153129279a4647" [[package]] name = "tinytemplate" @@ -1991,6 +1961,61 @@ dependencies = [ "serde_json", ] +[[package]] +name = "tinyvec" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +dependencies = [ + "tinyvec_macros", +] + +[[package]] +name = "tinyvec_macros" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" + +[[package]] +name = "torch-sys" +version = "0.16.0" +dependencies = [ + "anyhow", + "cc", + "libc", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-attributes", + "tracing-core", +] + +[[package]] +name = "tracing-attributes" +version = "0.1.27" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" +dependencies = [ + "proc-macro2", + "quote", + "syn 2.0.60", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" +dependencies = [ + "once_cell", +] + [[package]] name = "typenum" version = "1.17.0" @@ -2003,6 +2028,12 @@ version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" +[[package]] +name = "unicode-bidi" +version = "0.3.15" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "08f95100a766bf4f8f28f90d77e0a5461bbdb219042e7679bebe79004fed8d75" + [[package]] name = "unicode-ident" version = "1.0.12" @@ -2010,10 +2041,13 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] -name = "unicode-width" -version = "0.1.13" +name = "unicode-normalization" +version = "0.1.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" +checksum = "a56d1686db2308d901306f92a263857ef59ea39678a5458e7cb17f01415101f5" +dependencies = [ + "tinyvec", +] [[package]] name = "unindent" @@ -2021,12 +2055,45 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c7de7d73e1754487cb58364ee906a499937a0dfabd86bcb980fa99ec8c8fa2ce" +[[package]] +name = "untrusted" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" + [[package]] name = "upon" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" +[[package]] +name = "ureq" +version = "2.9.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +dependencies = [ + "base64", + "log", + "once_cell", + "rustls", + "rustls-pki-types", + "rustls-webpki", + "url", + "webpki-roots", +] + +[[package]] +name = "url" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +dependencies = [ + "form_urlencoded", + "idna", + "percent-encoding", +] + [[package]] name = "version_check" version = "0.9.4" @@ -2070,7 +2137,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", "wasm-bindgen-shared", ] @@ -2092,7 +2159,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2113,6 +2180,15 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "webpki-roots" +version = "0.26.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +dependencies = [ + "rustls-pki-types", +] + [[package]] name = "which" version = "4.4.2" @@ -2134,15 +2210,6 @@ dependencies = [ "windows-sys", ] -[[package]] -name = "windows-core" -version = "0.52.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" -dependencies = [ - "windows-targets", -] - [[package]] name = "windows-sys" version = "0.52.0" @@ -2154,9 +2221,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -2170,68 +2237,134 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" [[package]] name = "windows_aarch64_msvc" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" [[package]] name = "windows_i686_gnu" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" [[package]] name = "windows_i686_gnullvm" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" [[package]] name = "windows_i686_msvc" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" [[package]] name = "windows_x86_64_gnu" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" [[package]] name = "windows_x86_64_msvc" -version = "0.52.6" +version = "0.52.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "xattr" +version = "1.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8da84f1a25939b27f6820d92aed108f83ff920fdf11a7b19366c27c4cda81d4f" +dependencies = [ + "libc", + "linux-raw-sys", + "rustix", +] [[package]] name = "zerocopy" -version = "0.7.35" +version = "0.7.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" +checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.35" +version = "0.7.33" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" +checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393" dependencies = [ "proc-macro2", "quote", - "syn 2.0.72", + "syn 2.0.60", +] + +[[package]] +name = "zeroize" +version = "1.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "aes", + "byteorder", + "bzip2", + "constant_time_eq", + "crc32fast", + "crossbeam-utils", + "flate2", + "hmac", + "pbkdf2", + "sha1", + "time", + "zstd", +] + +[[package]] +name = "zstd" +version = "0.11.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "20cc960326ece64f010d2d2107537f26dc589a6573a316bd5b1dba685fa5fde4" +dependencies = [ + "zstd-safe", +] + +[[package]] +name = "zstd-safe" +version = "5.0.2+zstd.1.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1d2a5585e04f9eea4b2a3d1eca508c4dee9592a89ef6f450c11719da0726f4db" +dependencies = [ + "libc", + "zstd-sys", +] + +[[package]] +name = "zstd-sys" +version = "2.0.10+zstd.1.5.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +dependencies = [ + "cc", + "pkg-config", ] diff --git a/Cargo.toml b/Cargo.toml index 48634d8..8c2dc02 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -14,8 +14,11 @@ rust-version = "1.76" [features] extension-module = ["pyo3/extension-module"] -default = ["extension-module"] +default = ["extension-module", "onnx"] simd_support = ["nuts-rs/simd_support"] +iree = ["dep:eerie"] +torch = ["dep:tch"] +onnx = ["dep:ort"] [lib] name = "_lib" @@ -38,6 +41,11 @@ smallvec = "1.11.0" upon = { version = "0.8.1", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" +tch = { version = "0.16.0", optional = true } +ort = { version = "2.0.0-rc.2", optional = true, features = [ + "cuda", + "load-dynamic", +] } [dependencies.pyo3] version = "0.21.0" diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index aaee456..435d16d 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -427,6 +427,7 @@ def _compute_shapes(model): def _make_functions(model, *, mode, compute_grad, join_expanded): + # TODO do we want to freeze the model? import pytensor import pytensor.link.numba.dispatch import pytensor.tensor as pt diff --git a/src/iree.rs b/src/iree.rs new file mode 100644 index 0000000..7bd0ea6 --- /dev/null +++ b/src/iree.rs @@ -0,0 +1,723 @@ +use std::{ + io::{stderr, stdout, Write}, + mem::{forget, transmute, ManuallyDrop}, + sync::{ + mpsc::{sync_channel, Receiver, SyncSender}, + Arc, Mutex, OnceLock, + }, + thread::{spawn, JoinHandle}, +}; + +use anyhow::{anyhow, Context, Result}; +use arrow2::{ + array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, + datatypes::{DataType, Field}, +}; +use eerie::runtime::{ + api::{Call, Instance, InstanceOptions, Session, SessionOptions}, + hal::{BufferMapping, BufferView, Device, DriverRegistry, EncodingType}, + vm::{DynamicList, Function, List, Ref, ToRef, Undefined, Value}, +}; +use numpy::{PyArray1, PyReadonlyArray1, PyReadwriteArray1}; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Math, Model}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, Py, Python, +}; +use rand_distr::{num_traits::ToPrimitive, Distribution, StandardNormal}; +use thiserror::Error; + +static INSTANCE: OnceLock> = OnceLock::new(); + +fn get_instance() -> Result<&'static Instance> { + match INSTANCE.get_or_init(|| { + let mut registry = DriverRegistry::new(); + let options = InstanceOptions::new(&mut registry).use_all_available_drivers(); + let instance = Instance::new(&options)?; + + Ok(instance) + }) { + &Ok(ref instance) => Ok(instance), + &Err(ref err) => Err(anyhow!("Could not access iree instance: {}", err)), + } +} + +#[pyclass] +#[derive(Clone, Debug)] +pub struct IreeModel { + //logp_module: Box<[u8]>, + //expand_module: Box<[u8]>, + //devices: Arc]>>, + //devices: Box<[String]>, + ndim: usize, + session_maker: Arc>>>>>, + maker_thread: Arc>>, +} + +#[pymethods] +impl IreeModel { + #[new] + pub fn new_py<'py>( + device: String, + logp_module: Bound<'py, PyBytes>, + expand_module: Bound<'py, PyBytes>, + ndim: usize, + ) -> Result { + let logp_module: Box<[u8]> = logp_module.as_bytes().into(); + let expand_module: Box<[u8]> = expand_module.as_bytes().into(); + + Self::new(device, logp_module, expand_module, ndim) + } + + pub fn call_logp( + &self, + position: PyReadonlyArray1, + mut gradient: PyReadwriteArray1, + ) -> Result { + let mut math = self.math()?; + let logp = math.logp(&position.as_slice()?, gradient.as_slice_mut()?)?; + Ok(logp) + } +} + +impl IreeModel { + fn new( + device: String, + logp_module: Box<[u8]>, + expand_module: Box<[u8]>, + ndim: usize, + ) -> Result { + let (session_maker_sender, session_maker) = sync_channel(0); + + let maker_thread = spawn(move || { + let run_loop = move || { + let instance = get_instance()?; + let device = instance.try_create_default_device(&device)?; + + // FIXME + let device: Device<'static> = unsafe { transmute(device) }; + let devices = vec![device]; + + let logp_module: Arc<[u8]> = Arc::from(logp_module); + let expand_module: Arc<[u8]> = Arc::from(expand_module); + + for device in devices.iter().cycle() { + let make_math = || { + let logp_func = LogpFunc::new( + ndim, + logp_module.clone(), + expand_module.clone(), + device, + )?; + Ok(CpuMath::new(logp_func)) + }; + + let math_result = make_math(); + session_maker_sender + .send(math_result) + .map_err(|_| anyhow!("Could not send iree math"))?; + } + Ok(()) + }; + let res = run_loop(); + dbg!(res) + }); + + let session_maker = Arc::new(Mutex::new(session_maker)); + let maker_thread = Arc::new(maker_thread); + + Ok(IreeModel { + ndim, + //devices: vec![device].into(), + //logp_module: logp_module.as_bytes().into(), + //expand_module: expand_module.as_bytes().into(), + maker_thread, + session_maker, + }) + } +} + +#[derive(Debug)] +pub struct LogpFunc<'model> { + pub outputs: DynamicList<'model, Undefined>, + //pub inputs: DynamicList<'model, Ref<'model, BufferView<'model, f32>>>, + pub inputs: DynamicList<'model, Undefined>, + pub logp_func: ManuallyDrop>, + pub session: ManuallyDrop>, + //pub device: ManuallyDrop>, + //pub device: &'model Device<'model>, + logp_compiled: Arc<[u8]>, + expand_compiled: Arc<[u8]>, + pub ndim: usize, + pub buffer: Box<[f32]>, +} + +impl<'model> LogpFunc<'model> { + pub fn new( + ndim: usize, + //device: &'model Device<'model>, + logp_compiled: Arc<[u8]>, + expand_compiled: Arc<[u8]>, + //session: Session<'model>, + //device: &'model str, + device: &Device<'static>, + ) -> Result { + let instance = get_instance()?; + //let device = instance.try_create_default_device(device)?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let session: Session<'static> = unsafe { transmute(session) }; + + // TODO fix the lifetime of this reference + unsafe { session.append_module_from_memory(&logp_compiled) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let call = Call::new(&session, &logp_func)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let call: Call<'model> = unsafe { transmute(call) }; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'model> = unsafe { transmute(logp_func) }; + + let inputs = DynamicList::new(2, instance)?; + let outputs = DynamicList::new(2, instance)?; + + Ok(Self { + //device: ManuallyDrop::new(device), + //device, + inputs, + outputs, + logp_compiled, + expand_compiled, + ndim, + session: ManuallyDrop::new(session), + logp_func: ManuallyDrop::new(logp_func), + buffer: vec![0.; ndim].into(), + //call, + }) + } +} + +impl<'model> Drop for LogpFunc<'model> { + fn drop(&mut self) { + unsafe { + drop(ManuallyDrop::take(&mut self.logp_func)); + drop(ManuallyDrop::take(&mut self.session)); + //drop(ManuallyDrop::take(&mut self.device)); + } + } +} + +#[derive(Error, Debug)] +pub enum IreeLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for IreeLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +impl<'model> CpuLogpFunc for LogpFunc<'model> { + type LogpError = IreeLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + let instance = get_instance()?; + + self.buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as f32); + + let input_buffer = BufferView::::new( + &self.session, + &[position.len()], + EncodingType::DenseRowMajor, + &self.buffer, + ) + .context("Could not create buffer view")?; + + let input_buffer_ref = input_buffer + .to_ref(instance) + .context("Could not create iree ref")?; + + //dbg!(&input_buffer_ref); + + self.inputs + .push_ref(&input_buffer_ref) + .context("Could not push input buffer to inputs")?; + + let logp_func = self + .session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //dbg!(&self.inputs.get_ref::>(0)); + //stderr().lock().flush(); + //stdout().lock().flush(); + + logp_func + .invoke(&self.inputs, &self.outputs) + .context("Could not invoke logp function")?; + //let mut call = Call::new(&self.session, &self.logp_func).context("Could not create iree Call")?; + + //let inputs = call.input_list(); + + //inputs.push_ref(&input_buffer_ref).context("Could not push input")?; + //drop(input_buffer_ref); + //drop(input_buffer); + //drop(inputs); + + //call.invoke().context("Could not invoke iree function")?; + + let output_val: Value = self + .outputs + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + + /* + let logp_buffer_ref: Ref> = self + .outputs + .get_ref(0) + .context("Could not get logp buffer")?; + let logp_buffer = logp_buffer_ref.to_buffer_view(&self.session); + */ + + let gradient_buffer_ref: Ref> = self + .outputs + .get_ref(1) + .context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer_ref.to_buffer_view(&self.session); + + gradient_buffer + .copy_to_host(&mut self.buffer) + .context("Could not copy gradient buffer from iree device")?; + + //let mut logp_array = [0f32]; + //logp_buffer.copy_to_host(&self.device, &mut logp_array).context("Could not copy logp value")?; + //let logp = logp_array[0]; + + drop(input_buffer_ref); + drop(input_buffer); + + drop(gradient_buffer_ref); + drop(gradient_buffer); + + //drop(logp_buffer_ref); + //drop(logp_buffer); + + self.inputs.clear(); + self.outputs.clear(); + + let mut has_bad_grad = false; + gradient + .iter_mut() + .zip(self.buffer.iter()) + .for_each(|(out, &val)| { + *out = val as f64; + if !val.is_finite() { + has_bad_grad = true; + } + }); + + if (!logp.is_finite()) | has_bad_grad { + return Err(IreeLogpError::BadLogp()); + } + + Ok(logp as f64) + } +} + +#[derive(Clone)] +pub struct IreeTrace { + trace: MutableFixedSizeListArray>, +} + +impl DrawStorage for IreeTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + Ok(()) + } + + fn finalize(mut self) -> Result> { + let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + let fields = vec![field]; + let data_type = DataType::Struct(fields); + let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); + Ok(Box::new(struct_array)) + } + + fn inspect(&mut self) -> Result> { + self.clone().finalize() + } +} + +impl Model for IreeModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = IreeTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = MutablePrimitiveArray::new(); + let trace = MutableFixedSizeListArray::new(items, self.ndim); + + Ok(IreeTrace { trace }) + } + + fn math(&self) -> Result> { + self.session_maker + .lock() + .expect("Poisoned mutex") + .recv() + .context("Could not create iree session")? + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = StandardNormal; + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use std::{ + fs::File, + io::Read, + mem::{transmute, ManuallyDrop}, + path::Path, + }; + + use anyhow::{Context, Result}; + use eerie::runtime::{ + api::{Call, Session, SessionOptions}, + hal::{BufferView, EncodingType}, + vm::{DynamicList, Function, List, Ref, ToRef, Undefined, Value}, + }; + use nuts_rs::{Math, Model}; + + use super::{get_instance, IreeModel, LogpFunc}; + + #[test] + fn run_logp_manual1() -> Result<()> { + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let instance = get_instance()?; + let device = instance.try_create_default_device("local-task")?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let session: Session<'static> = unsafe { transmute(session) }; + + unsafe { session.append_module_from_memory(&logp_compiled) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let inputs: DynamicList>> = DynamicList::new(2, instance)?; + //let inputs: DynamicList = DynamicList::new(2, instance)?; + //let outputs: DynamicList = DynamicList::new(2, instance)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'static> = unsafe { transmute(logp_func) }; + + let mut call = Call::new(&session, &logp_func)?; + + let mut buffer: Box<[f32]> = vec![0.; 2].into(); + let position = vec![1., 2.]; + let mut gradient: Box<[f64]> = vec![-1., -1.].into(); + + buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as _); + + let input_buffer = BufferView::::new( + &session, + &[position.len()], + EncodingType::DenseRowMajor, + &buffer, + ) + .context("Could not create buffer view")?; + + let input_buffer_ref = input_buffer + .to_ref(instance) + .context("Could not create iree ref")?; + + let inputs = call.input_list(); + inputs + .push_ref(&input_buffer_ref) + .context("Could not push input buffer to inputs")?; + + //dbg!(&inputs.get_ref::>(0)); + //dbg!(&input_buffer_ref); + + /* + logp_func + .invoke(&inputs, &outputs) + .context("Could not invoke logp function")?; + */ + drop(inputs); + call.invoke().context("Could not invoke iree function")?; + + drop(input_buffer_ref); + drop(input_buffer); + + // TODO For some reason it seems we need to keep this alive until after the call... + // Maybe a missing refcount increase somewhere? + + let outputs = call.output_list(); + + let output_val: Value = outputs + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + dbg!(logp); + + let gradient_buffer: Ref> = + outputs.get_ref(1).context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer.to_buffer_view(&session); + + gradient_buffer + .copy_to_host(&mut buffer) + .context("Could not copy gradient buffer from iree device")?; + + gradient + .iter_mut() + .zip(buffer.iter()) + .for_each(|(out, &val)| *out = val as _); + + dbg!(gradient); + + Ok(()) + } + + #[test] + fn run_logp_seg() -> Result<()> { + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let logp_expand = vec![]; + + let model = IreeModel::new( + "local-task".into(), + logp_compiled.into(), + logp_expand.into(), + 2, + )?; + + let mut math = model.math()?; + + let position = vec![1., 2.]; + let mut gradient = vec![-1., -1.]; + math.logp(&position, &mut gradient)?; + + drop(math); + drop(model); + + Ok(()) + } + + #[test] + fn run_logp_manual2() -> Result<()> { + /* + let path = Path::new(env!("CARGO_MANIFEST_DIR")) + .join("example-iree") + .join("example-logp.fbvm"); + let mut logp_compiled = Vec::new(); + File::open(path)?.read_to_end(&mut logp_compiled)?; + + let logp_expand = vec![]; + + let model = IreeModel::new( + "cuda".into(), + logp_compiled.into(), + logp_expand.into(), + 2, + ); + + let instance = get_instance()?; + + let device = instance.try_create_default_device(&model.devices[0])?; + + //let mut math_obj = LogpFunc::new(model.ndim, &model.logp_module, &model.expand_module, device)?; + + let mut math_obj = { + + let instance = get_instance()?; + + let options = SessionOptions::default(); + let session = Session::create_with_device(instance, &options, &device) + .context("Could not create session")?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let session: Session<'static> = unsafe { transmute(session) }; + + unsafe { session.append_module_from_memory(&model.logp_module) } + .context("Could not load iree logp function")?; + //unsafe { session.append_module_from_memory(expand_compiled) }.context("Coxd not load iree expand function")?; + + let logp_func = session + .lookup_function("jit_jax_funcified_fgraph.logp") + .context("Could not find gradient function in module")?; + + //let call = Call::new(&session, &logp_func)?; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + //let call: Call<'model> = unsafe { transmute(call) }; + + // TODO iree things are ref counted internall, so this is probably fine, but I hate it... + let logp_func: Function<'static> = unsafe { transmute(logp_func) }; + + let inputs = DynamicList::new(2, instance)?; + let outputs = DynamicList::new(2, instance)?; + + LogpFunc { + device: ManuallyDrop::new(device), + inputs, + outputs, + logp_compiled: &model.logp_module, + expand_compiled: &model.expand_module, + ndim: 2, + session: ManuallyDrop::new(session), + //logp_func: ManuallyDrop::new(logp_func), + buffer: vec![0.; 2].into(), + } + + }; + + let math = &mut math_obj; + + let position = vec![1., 2.]; + let mut gradient = vec![-1., -1.]; + + let instance = get_instance()?; + + math.buffer + .iter_mut() + .zip(position.iter()) + .for_each(|(out, &val)| *out = val as _); + + let input_buffer = BufferView::::new( + &math.session, + &[position.len()], + EncodingType::DenseRowMajor, + &math.buffer, + ) + .context("Could not create buffer view")?; + + (math.inputs) + .push_ref( + &input_buffer + .to_ref(instance) + .context("Could not dereference input buffer")?, + ) + .context("Could not push input buffer to inputs")?; + + (math.logp_func) + .invoke(&math.inputs, &math.outputs) + .context("Could not invoke logp function")?; + + drop(input_buffer); + + let output_val: Value = (math.outputs) + .get_value(0) + .context("Could not extract logp value")?; + let logp: f64 = output_val.from_value().into(); + + let gradient_buffer_ref: Ref> = (&math.outputs) + .get_ref(1) + .context("Could not get output buffer")?; + let gradient_buffer = gradient_buffer_ref.to_buffer_view(&math.session); + + gradient_buffer + .copy_into(&math.device, &mut math.buffer) + .context("Could not copy gradient buffer from iree device")?; + + let mut has_bad_grad = false; + gradient + .iter_mut() + .zip(math.buffer.iter()) + .for_each(|(out, &val)| { + *out = val as f64; + if !val.is_finite() { + has_bad_grad = true; + } + }); + + drop(gradient_buffer_ref); + drop(gradient_buffer); + + math.inputs.clear(); + math.outputs.clear(); + + drop(math); + drop(math_obj); + + /* + drop(gradient_buffer); + drop(gradient_buffer_ref); + drop(math); + drop(math_obj); + drop(model); + */ + + */ + Ok(()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 6154f92..8e956ae 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,13 @@ +#[cfg(feature = "iree")] +mod iree; +#[cfg(feature = "onnx")] +mod ort; mod progress; mod pyfunc; mod pymc; mod stan; +#[cfg(feature = "torch")] +mod torch; mod wrapper; pub use wrapper::_lib; diff --git a/src/ort.rs b/src/ort.rs new file mode 100644 index 0000000..37189a4 --- /dev/null +++ b/src/ort.rs @@ -0,0 +1,215 @@ +use anyhow::{anyhow, Context}; +use anyhow::{bail, Result}; +use arrow2::{ + array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, + datatypes::{DataType, Field}, +}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use ort::{ + inputs, CPUExecutionProvider, CUDAExecutionProvider, ExecutionProviderDispatch, + InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, Value, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, +}; +use rand_distr::{Distribution, Uniform}; +use thiserror::Error; + +#[pyclass] +#[derive(Clone, Debug)] +pub struct OnnxModel { + ndim: usize, + logp_model: Box<[u8]>, + providers: Vec, +} + +impl OnnxModel { + fn make_logp_session<'a>(&'a self) -> Result> { + let logp_session = Session::builder()? + .with_execution_providers(self.providers.iter().cloned())? + .with_memory_pattern(true)? + .commit_from_memory_directly(&self.logp_model)?; + + Ok(logp_session) + } +} + +#[pymethods] +impl OnnxModel { + #[new] + pub fn new_py<'py>( + ndim: usize, + logp_model: Bound<'py, PyBytes>, + providers: &OnnxProviders, + ) -> Result { + Ok(Self { + ndim, + providers: providers.providers.iter().cloned().collect(), + logp_model: logp_model.as_bytes().into(), + }) + } +} + +#[derive(Clone)] +pub struct OnnxTrace { + trace: MutableFixedSizeListArray>, +} + +impl DrawStorage for OnnxTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + Ok(()) + } + + fn finalize(mut self) -> Result> { + let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + let fields = vec![field]; + let data_type = DataType::Struct(fields); + let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); + Ok(Box::new(struct_array)) + } + + fn inspect(&mut self) -> Result> { + self.clone().finalize() + } +} + +#[derive(Error, Debug)] +pub enum OnnxLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for OnnxLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +pub struct OnnxLogpFunc<'model> { + session: InMemorySession<'model>, + ndim: usize, +} + +impl<'model> OnnxLogpFunc<'model> { + fn new(ndim: usize, session: InMemorySession<'model>) -> Result { + Ok(Self { session, ndim }) + } +} + +impl<'model> CpuLogpFunc for OnnxLogpFunc<'model> { + type LogpError = OnnxLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + let position = position.iter().map(|&x| x as f32).collect_vec(); + let position = + Value::from_array(([position.len()], position)).context("Could not create input")?; + let inputs = SessionInputs::ValueArray([position.into()]); + let mut outputs = self + .session + .run(inputs) + .context("Could not run logp function")?; + let logp = outputs + .pop_first() + .context("Could not extract first output")?; + let grad = outputs + .pop_first() + .context("Could not extract second output")?; + let logp: f32 = logp + .1 + .try_extract_raw_tensor() + .context("Could not read logp value")? + .1[0]; + let vals = grad + .1 + .try_extract_raw_tensor::() + .context("Could not read grad value")? + .1; + if vals.len() != gradient.len() { + Err(anyhow!("Logp return gradient with incorrect length"))?; + } + gradient + .iter_mut() + .zip(vals.iter()) + .for_each(|(mut out, &val)| *out = val as f64); + Ok(logp as f64) + } +} + +impl Model for OnnxModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = OnnxTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = MutablePrimitiveArray::new(); + let trace = MutableFixedSizeListArray::new(items, self.ndim); + + Ok(OnnxTrace { trace }) + } + + fn math(&self) -> Result> { + let session = self.make_logp_session()?; + Ok(CpuMath::new(OnnxLogpFunc::new(self.ndim, session)?)) + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = Uniform::new(-2., 2.); + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} + +#[pyclass] +pub struct OnnxProviders { + providers: Vec, +} + +#[pymethods] +impl OnnxProviders { + #[new] + pub fn new() -> Self { + Self { providers: vec![] } + } + + pub fn add_cpu(&mut self) -> Result<()> { + self.providers.push(CPUExecutionProvider::default().into()); + Ok(()) + } + + pub fn add_cuda(&mut self) -> Result<()> { + self.providers.push(CUDAExecutionProvider::default().into()); + Ok(()) + } +} diff --git a/src/torch.rs b/src/torch.rs new file mode 100644 index 0000000..d11cecb --- /dev/null +++ b/src/torch.rs @@ -0,0 +1,147 @@ +use anyhow::{anyhow, Context}; +use anyhow::{bail, Result}; +use arrow2::{ + array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, + datatypes::{DataType, Field}, +}; +use itertools::Itertools; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use ort::{ + inputs, CPUExecutionProvider, CUDAExecutionProvider, ExecutionProviderDispatch, + InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, Value, +}; +use pyo3::{ + pyclass, pymethods, + types::{PyBytes, PyBytesMethods}, + Bound, +}; +use rand_distr::{Distribution, Uniform}; +use thiserror::Error; + +#[pyclass] +#[derive(Clone, Debug)] +pub struct TorchModel { + ndim: usize, + logp_model: Box<[u8]>, + providers: Vec, +} + +impl TorchModel { + fn make_logp_session<'a>(&'a self) -> Result<()> { + todo!() + } +} + +#[pymethods] +impl TorchModel { + #[new] + pub fn new_py<'py>(ndim: usize, logp_model: Bound<'py, PyBytes>) -> Result { + todo!() + } +} + +#[derive(Clone)] +pub struct TorchTrace { + trace: MutableFixedSizeListArray>, +} + +impl DrawStorage for TorchTrace { + fn append_value(&mut self, point: &[f64]) -> Result<()> { + self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + Ok(()) + } + + fn finalize(mut self) -> Result> { + let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + let fields = vec![field]; + let data_type = DataType::Struct(fields); + let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); + Ok(Box::new(struct_array)) + } + + fn inspect(&mut self) -> Result> { + self.clone().finalize() + } +} + +#[derive(Error, Debug)] +pub enum TorchLogpError { + #[error("Error while computing logp and gradient: {0:?}")] + Iree(#[from] anyhow::Error), + #[error("Bad logp value in gradient evaluation")] + BadLogp(), +} + +impl LogpError for TorchLogpError { + fn is_recoverable(&self) -> bool { + match self { + Self::BadLogp() => true, + _ => false, + } + } +} + +pub struct TorchLogpFunc<'model> { + ndim: usize, +} + +impl<'model> TorchLogpFunc<'model> { + fn new(ndim: usize) -> Result { + todo!() + } +} + +impl<'model> CpuLogpFunc for TorchLogpFunc<'model> { + type LogpError = TorchLogpError; + + fn dim(&self) -> usize { + self.ndim + } + + fn logp( + &mut self, + position: &[f64], + gradient: &mut [f64], + ) -> std::result::Result { + todo!() + } +} + +impl Model for TorchModel { + type Math<'model> = CpuMath> + where + Self: 'model; + + type DrawStorage<'model, S: nuts_rs::Settings> = TorchTrace + where + Self: 'model; + + fn new_trace<'model, S: nuts_rs::Settings, R: rand::prelude::Rng + ?Sized>( + &'model self, + rng: &mut R, + chain_id: u64, + settings: &'model S, + ) -> Result> { + let items = MutablePrimitiveArray::new(); + let trace = MutableFixedSizeListArray::new(items, self.ndim); + + Ok(OnnxTrace { trace }) + } + + fn math(&self) -> Result> { + let session = self.make_logp_session()?; + Ok(CpuMath::new(OnnxLogpFunc::new(self.ndim, session)?)) + } + + fn init_position( + &self, + rng: &mut R, + position: &mut [f64], + ) -> Result<()> { + let dist = Uniform::new(-2., 2.); + dist.sample_iter(rng) + .zip(position.iter_mut()) + .for_each(|(val, pos)| *pos = val); + Ok(()) + } +} diff --git a/src/wrapper.rs b/src/wrapper.rs index 6f9ad49..4a6c699 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -4,7 +4,12 @@ use std::{ time::{Duration, Instant}, }; +#[cfg(feature = "onnx")] +use crate::ort::OnnxModel; + use crate::{ + ort::OnnxProviders, + progress::ProgressHandler, progress::{IndicatifHandler, ProgressHandler}, pyfunc::{ExpandDtype, PyModel, PyVariable, TensorShape}, pymc::{ExpandFunc, LogpFunc, PyMcModel}, @@ -569,6 +574,38 @@ impl PySampler { } } + #[cfg(feature = "onnx")] + #[staticmethod] + fn from_onnx( + settings: PyDiagGradNutsSettings, + cores: usize, + model: OnnxModel, + template: String, + rate: u64, + callback: Option>, + ) -> PyResult { + let rate = Duration::from_millis(rate); + let callback = make_callback(template, cores, rate, callback)?; + let sampler = Sampler::new(model, settings.0, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + + #[cfg(feature = "iree")] + #[staticmethod] + fn from_iree( + settings: PyDiagGradNutsSettings, + cores: usize, + model: IreeModel, + template: String, + rate: u64, + callback: Option>, + ) -> PyResult { + let rate = Duration::from_millis(rate); + let callback = make_callback(template, cores, rate, callback)?; + let sampler = Sampler::new(model, settings.0, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + fn is_finished(&mut self, py: Python<'_>) -> PyResult { py.allow_threads(|| { let state = std::mem::replace(&mut self.0, SamplerState::Empty); @@ -773,6 +810,11 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; m.add_class::()?; m.add_class::()?; + #[cfg(feature = "onnx")] + m.add_class::()?; + #[cfg(feature = "onnx")] + m.add_class::()?; + m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?; From 5271ec64c4e7dbf494519697a274e5347cbe63d9 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 6 May 2024 15:48:38 +0200 Subject: [PATCH 2/5] Update deps --- Cargo.lock | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 28a79fe..44419ad 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -762,9 +762,9 @@ dependencies = [ [[package]] name = "getrandom" -version = "0.2.14" +version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", "js-sys", @@ -1172,6 +1172,8 @@ dependencies = [ [[package]] name = "nuts-rs" version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ffa91fe8bbfc18a0d3d2068f4d3ad9edf1c19d4f36b61cb3cd1e7f0ed71079c" dependencies = [ "anyhow", "arrow2", @@ -1894,6 +1896,8 @@ checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" [[package]] name = "tch" version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "61fd89a98303b22acd6d4969b4c8940f7a30ba79af32b744a2028375d156e95a" dependencies = [ "half", "lazy_static", @@ -1979,10 +1983,13 @@ checksum = "1f3ccbac311fea05f86f61904b462b55fb3df8837a366dfc601a0161d0532f20" [[package]] name = "torch-sys" version = "0.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5997681f7f3700fa475f541fcda44c8959ea42a724194316fe7297cb96ebb08" dependencies = [ "anyhow", "cc", "libc", + "zip", ] [[package]] From e973b8f3680e41a582b0360b8340cc1e22cdccc0 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 6 May 2024 15:48:47 +0200 Subject: [PATCH 3/5] Add tensorrt --- src/ort.rs | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/src/ort.rs b/src/ort.rs index 37189a4..4450d3b 100644 --- a/src/ort.rs +++ b/src/ort.rs @@ -8,7 +8,8 @@ use itertools::Itertools; use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; use ort::{ inputs, CPUExecutionProvider, CUDAExecutionProvider, ExecutionProviderDispatch, - InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, Value, + InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, + TensorRTExecutionProvider, Value, }; use pyo3::{ pyclass, pymethods, @@ -212,4 +213,10 @@ impl OnnxProviders { self.providers.push(CUDAExecutionProvider::default().into()); Ok(()) } + + pub fn add_tensorrt(&mut self) -> Result<()> { + self.providers + .push(TensorRTExecutionProvider::default().into()); + Ok(()) + } } From 74924a0bc4a1f008c5497137e7cd29fd515c14e0 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Mon, 6 May 2024 17:47:52 +0200 Subject: [PATCH 4/5] Add python onnx code --- python/nutpie/__init__.py | 11 +++- python/nutpie/compile_onnx.py | 67 ++++++++++++++++++++++++ python/nutpie/compile_pymc.py | 2 +- python/nutpie/compile_stan.py | 2 +- python/nutpie/{sample.py => sampling.py} | 0 5 files changed, 78 insertions(+), 4 deletions(-) create mode 100644 python/nutpie/compile_onnx.py rename python/nutpie/{sample.py => sampling.py} (100%) diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 980b7e5..8707d90 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,7 +1,14 @@ from nutpie import _lib from nutpie.compile_pymc import compile_pymc_model from nutpie.compile_stan import compile_stan_model -from nutpie.sample import sample +from nutpie.compile_onnx import compile_pytensor_module +from nutpie.sampling import sample __version__: str = _lib.__version__ -__all__ = ["__version__", "sample", "compile_pymc_model", "compile_stan_model"] +__all__ = [ + "__version__", + "sample", + "compile_pymc_model", + "compile_stan_model", + "compile_pytensor_module", +] diff --git a/python/nutpie/compile_onnx.py b/python/nutpie/compile_onnx.py new file mode 100644 index 0000000..13f430a --- /dev/null +++ b/python/nutpie/compile_onnx.py @@ -0,0 +1,67 @@ +from typing import Any +import dataclasses +import io + +from nutpie.sampling import CompiledModel +from nutpie import _lib + + +def compile_pytensor_module(module, n_dim): + import torch + + x = torch.zeros(n_dim) + exported = torch.onnx.dynamo_export(module, x) + + exported_bytes = io.BytesIO() + exported.save(exported_bytes) + exported_bytes = exported_bytes.getvalue() + + compiled = CompiledOnnx( + _n_dim=n_dim, + providers=None, + logp_module_bytes=exported_bytes, + dims={"unconstrained_draw": ("unconstrained_parameter",)}, + ) + + return compiled.with_providers(["cpu"]) + + +@dataclasses.dataclass(frozen=True) +class CompiledOnnx(CompiledModel): + logp_module_bytes: Any + providers: Any + _n_dim: int + + @property + def shapes(self): + return {"unconstrained_draw": (self.n_dim,)} + + @property + def coords(self): + return {} + + @property + def n_dim(self): + return self._n_dim + + def _make_model(self, init_mean): + return _lib.OnnxModel(self.n_dim, self.logp_module_bytes, self.providers) + + def _make_sampler(self, settings, init_mean, cores, template, rate, callback=None): + model = self._make_model(init_mean) + return _lib.PySampler.from_onnx( + settings, cores, model, template, rate, callback + ) + + def with_providers(self, provider_names): + providers = _lib.OnnxProviders() + for name in provider_names: + if name == "cuda": + providers.add_cuda() + elif name == "tensorrt": + providers.add_tensorrt() + elif name == "cpu": + providers.add_cpu() + else: + raise ValueError(f"Unknown provider {name}") + return dataclasses.replace(self, providers=providers) diff --git a/python/nutpie/compile_pymc.py b/python/nutpie/compile_pymc.py index 435d16d..cf439e3 100644 --- a/python/nutpie/compile_pymc.py +++ b/python/nutpie/compile_pymc.py @@ -12,7 +12,7 @@ from nutpie import _lib from nutpie.compiled_pyfunc import from_pyfunc -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel try: from numba.extending import intrinsic diff --git a/python/nutpie/compile_stan.py b/python/nutpie/compile_stan.py index 7a28052..26e74f6 100644 --- a/python/nutpie/compile_stan.py +++ b/python/nutpie/compile_stan.py @@ -10,7 +10,7 @@ from numpy.typing import NDArray from nutpie import _lib -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel class _NumpyArrayEncoder(json.JSONEncoder): diff --git a/python/nutpie/sample.py b/python/nutpie/sampling.py similarity index 100% rename from python/nutpie/sample.py rename to python/nutpie/sampling.py From 28bd30403bc562b97eb154185978541928518aab Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 24 Jul 2024 11:26:27 +0200 Subject: [PATCH 5/5] WIP onnx --- Cargo.lock | 1079 ++++++++++++++++++++---------- Cargo.toml | 11 +- python/nutpie/__init__.py | 2 +- python/nutpie/compile_onnx.py | 4 +- python/nutpie/compiled_pyfunc.py | 2 +- src/lib.rs | 2 - src/ort.rs | 252 +++++-- src/wrapper.rs | 39 +- 8 files changed, 960 insertions(+), 431 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 44419ad..6ad4fe6 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -26,6 +26,7 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011" dependencies = [ "cfg-if", + "const-random", "getrandom", "once_cell", "version_check", @@ -41,6 +42,21 @@ dependencies = [ "memchr", ] +[[package]] +name = "android-tzdata" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" + +[[package]] +name = "android_system_properties" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "819e7219dbd41043ac279b19830f2efc897156490d7fd6ea916720117ee66311" +dependencies = [ + "libc", +] + [[package]] name = "anes" version = "0.1.6" @@ -55,29 +71,179 @@ checksum = "038dfcf04a5feb68e9c60b21c9625a54c2c0616e79b72b0fd87075a056ae1d1b" [[package]] name = "anyhow" -version = "1.0.82" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f538837af36e6f6a9be0faa67f9a314f8119e4e4b5867c6ab40ed60360142519" +checksum = "b3d1d046238990b9cf5bcde22a3fb3584ee5cf65fb2765f454ed428c7a0063da" [[package]] -name = "arrow2" -version = "0.18.0" +name = "arrow" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6127ea5e585a12ec9f742232442828ebaf264dfa5eefdd71282376c599562b77" +dependencies = [ + "arrow-arith", + "arrow-array", + "arrow-buffer", + "arrow-cast", + "arrow-data", + "arrow-ord", + "arrow-row", + "arrow-schema", + "arrow-select", + "arrow-string", +] + +[[package]] +name = "arrow-arith" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7add7f39210b7d726e2a8efc0083e7bf06e8f2d15bdb4896b564dce4410fbf5d" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "chrono", + "half", + "num", +] + +[[package]] +name = "arrow-array" +version = "52.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "963fef509b757bcbbf9e5ffa23bcb345614d99f4f6f531f97417b27b8604d389" +checksum = "81c16ec702d3898c2f5cfdc148443c6cd7dbe5bac28399859eb0a3d38f072827" dependencies = [ "ahash", - "bytemuck", + "arrow-buffer", + "arrow-data", + "arrow-schema", "chrono", - "dyn-clone", - "either", - "ethnum", - "foreign_vec", - "getrandom", - "hash_hasher", + "half", + "hashbrown", + "num", +] + +[[package]] +name = "arrow-buffer" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cae6970bab043c4fbc10aee1660ceb5b306d0c42c8cc5f6ae564efcd9759b663" +dependencies = [ + "bytes", + "half", + "num", +] + +[[package]] +name = "arrow-cast" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1c7ef44f26ef4f8edc392a048324ed5d757ad09135eff6d5509e6450d39e0398" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "atoi", + "base64", + "chrono", + "half", + "lexical-core", + "num", + "ryu", +] + +[[package]] +name = "arrow-data" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a769666ffac256dd301006faca1ca553d0ae7cffcf4cd07095f73f95eb226514" +dependencies = [ + "arrow-buffer", + "arrow-schema", + "half", + "num", +] + +[[package]] +name = "arrow-ord" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e8008370e624e8e3c68174faaf793540287106cfda8ad1da862fdc53d8e096b4" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "half", + "num", +] + +[[package]] +name = "arrow-row" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ca5e3a6b7fda8d9fe03f3b18a2d946354ea7f3c8e4076dbdb502ad50d9d44824" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "half", "hashbrown", +] + +[[package]] +name = "arrow-schema" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dab1c12b40e29d9f3b699e0203c2a73ba558444c05e388a4377208f8f9c97eee" +dependencies = [ + "bitflags 2.6.0", +] + +[[package]] +name = "arrow-select" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e80159088ffe8c48965cb9b1a7c968b2729f29f37363df7eca177fc3281fe7c3" +dependencies = [ + "ahash", + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "num", +] + +[[package]] +name = "arrow-string" +version = "52.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fd04a6ea7de183648edbcb7a6dd925bbd04c210895f6384c780e27a9b54afcd" +dependencies = [ + "arrow-array", + "arrow-buffer", + "arrow-data", + "arrow-schema", + "arrow-select", + "memchr", + "num", + "regex", + "regex-syntax", +] + +[[package]] +name = "atoi" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f28d99ec8bfea296261ca1af174f24225171fea9664ba9003cbebee704810528" +dependencies = [ "num-traits", - "rustc_version", - "simdutf8", ] [[package]] @@ -104,7 +270,7 @@ version = "0.69.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a00dc851838a2120612785d195287475a3ac45514741da670b735818822129a0" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "cexpr", "clang-sys", "itertools 0.12.1", @@ -117,7 +283,7 @@ dependencies = [ "regex", "rustc-hash", "shlex", - "syn 2.0.60", + "syn 2.0.72", "which", ] @@ -129,9 +295,9 @@ checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" -version = "2.5.0" +version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" +checksum = "b048fb63fd8b5923fc5aa7b340d8e156aec7ec02f0c78fa8a6ddc2613f6f71de" [[package]] name = "block-buffer" @@ -144,12 +310,14 @@ dependencies = [ [[package]] name = "bridgestan" -version = "2.4.1" +version = "2.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9f9cac326d10621223fcf840aee74a9b8087ce93921ecf1ae82ede1581d4366c" +checksum = "cc8db213e11ba8b22c444912e269f8164d4af17292e6e78d9d6a4162225e929b" dependencies = [ "bindgen", "libloading", + "log", + "path-absolutize", "thiserror", ] @@ -161,22 +329,22 @@ checksum = "79296716171880943b8470b5f8d03aa55eb2e645a4874bdbb28adb49162e012c" [[package]] name = "bytemuck" -version = "1.15.0" +version = "1.16.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" +checksum = "b236fc92302c97ed75b38da1f4917b5cdda4984745740f153a5d3059e48d725e" dependencies = [ "bytemuck_derive", ] [[package]] name = "bytemuck_derive" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" +checksum = "1ee891b04274a59bd38b412188e24b849617b2e45a0fd8d057deb63e7403761b" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] @@ -185,6 +353,12 @@ version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" +[[package]] +name = "bytes" +version = "1.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a12916984aab3fa6e39d655a33e09c0071eb36d6ab3aea5c2d78551f1df6d952" + [[package]] name = "bzip2" version = "0.4.4" @@ -214,13 +388,12 @@ checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cc" -version = "1.0.97" +version = "1.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "099a5357d84c4c61eb35fc8eafa9a79a902c2f76911e5747ced4e032edd8d9b4" +checksum = "2aba8f4e9906c7ce3c73463f62a7f0c65183ada1a2d47e397cc8810827f9694f" dependencies = [ "jobserver", "libc", - "once_cell", ] [[package]] @@ -244,7 +417,10 @@ version = "0.4.38" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a21f936df1771bf62b77f047b726c4625ff2e8aa607c01ec06e5a05bd8463401" dependencies = [ + "android-tzdata", + "iana-time-zone", "num-traits", + "windows-targets", ] [[package]] @@ -286,9 +462,9 @@ dependencies = [ [[package]] name = "clang-sys" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "67523a3b4be3ce1989d607a828d036249522dd9c1c8de7f4dd2dae43a37369d1" +checksum = "0b023947811758c97c59bf9d1c188fd619ad4718dcaa767947df1cadb14f39f4" dependencies = [ "glob", "libc", @@ -297,18 +473,18 @@ dependencies = [ [[package]] name = "clap" -version = "4.5.4" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90bc066a67923782aa8515dbaea16946c5bcc5addbd668bb80af688e53e548a0" +checksum = "64acc1846d54c1fe936a78dc189c34e28d3f5afc348403f28ecf53660b9b8462" dependencies = [ "clap_builder", ] [[package]] name = "clap_builder" -version = "4.5.2" +version = "4.5.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae129e2e766ae0ec03484e609954119f123cc1fe650337e155d03b022f24f7b4" +checksum = "6fb8393d67ba2e7bfaf28a23458e4e2b543cc73a99595511eb207fdb8aede942" dependencies = [ "anstyle", "clap_lex", @@ -316,24 +492,48 @@ dependencies = [ [[package]] name = "clap_lex" -version = "0.7.0" +version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "98cc8fbded0c607b7ba9dd60cd98df59af97e84d24e49c8557331cfc26d301ce" +checksum = "4b82cf0babdbd58558212896d1a4272303a57bdb245c2bf1147185fb45640e70" + +[[package]] +name = "coe-rs" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" [[package]] -name = "cmake" -version = "0.1.50" +name = "console" +version = "0.15.8" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a31c789563b815f77f4250caee12365734369f942439b7defd71e18a48197130" +checksum = "0e1f83fc076bd6dd27517eacdf25fef6c4dfe5f1d7448bafaaf3a26f13b5e4eb" dependencies = [ - "cc", + "encode_unicode", + "lazy_static", + "libc", + "unicode-width", + "windows-sys", ] [[package]] -name = "coe-rs" -version = "0.1.2" +name = "const-random" +version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e8f1e641542c07631228b1e0dc04b69ae3c1d58ef65d5691a439711d805c698" +checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" +dependencies = [ + "const-random-macro", +] + +[[package]] +name = "const-random-macro" +version = "0.1.16" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" +dependencies = [ + "getrandom", + "once_cell", + "tiny-keccak", +] [[package]] name = "constant_time_eq" @@ -341,6 +541,12 @@ version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "245097e9a4535ee1e3e3931fcfcd55a796a44c643e8596ff6566d68f09b87bbc" +[[package]] +name = "core-foundation-sys" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06ea2b9bc92be3c2baa9334a323ebca2d6f074ff852cd1d7b11064035cd3868f" + [[package]] name = "cpufeatures" version = "0.2.12" @@ -352,9 +558,9 @@ dependencies = [ [[package]] name = "crc32fast" -version = "1.4.0" +version = "1.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +checksum = "a97769d94ddab943e4510d138150169a2758b5ef3eb191a9ee688de3e23ef7b3" dependencies = [ "cfg-if", ] @@ -416,9 +622,9 @@ dependencies = [ [[package]] name = "crossbeam-utils" -version = "0.8.19" +version = "0.8.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" +checksum = "22ec99545bb0ed0ea7bb9b8e1e9122ea386ff8a48c0922e43f36d45ab09e0e80" [[package]] name = "crunchy" @@ -462,12 +668,6 @@ dependencies = [ "subtle", ] -[[package]] -name = "dyn-clone" -version = "1.0.17" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0d6ef0072f8a535281e4876be788938b528e9a1d43900b82c2569af7da799125" - [[package]] name = "dyn-stack" version = "0.10.0" @@ -479,31 +679,16 @@ dependencies = [ ] [[package]] -name = "eerie" -version = "0.2.2" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4f697edddf865085bb1238c234ff46713c2baf0a353687b1ae3d15e5eaa91c6b" -dependencies = [ - "eerie-sys", - "log", - "thiserror", -] - -[[package]] -name = "eerie-sys" -version = "0.2.2" +name = "either" +version = "1.13.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "470184b575c1acd2895b719e7510efd844b8fac49a00a6b671905888664fda89" -dependencies = [ - "bindgen", - "cmake", -] +checksum = "60b1af1c220855b6ceac025d3f6ecdd2b7c4894bfe9cd9bda4fbb4bc7c0d4cf0" [[package]] -name = "either" -version = "1.11.0" +name = "encode_unicode" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" +checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" [[package]] name = "enum-as-inner" @@ -514,50 +699,44 @@ dependencies = [ "heck", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] name = "equator" -version = "0.1.10" +version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a3b0a88aa91d0ad2b9684e4479aed31a17d3f9051bdbbc634bd2c01bc5a5eee8" +checksum = "c35da53b5a021d2484a7cc49b2ac7f2d840f8236a286f84202369bd338d761ea" dependencies = [ "equator-macro", ] [[package]] name = "equator-macro" -version = "0.1.9" +version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "60d08acb9849f7fb4401564f251be5a526829183a3645a90197dea8e786cf3ae" +checksum = "3bf679796c0322556351f287a51b49e48f7c4986e727b5dd78c972d30e2e16cc" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] name = "errno" -version = "0.3.8" +version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", "windows-sys", ] -[[package]] -name = "ethnum" -version = "1.5.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b90ca2580b73ab6a1f724b76ca11ab632df820fd6040c336200d2c1df7b3c82c" - [[package]] name = "faer" -version = "0.18.2" +version = "0.19.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e547492d9b55c4ea882584e691ed092228981e337d0c800bc721301d7e61e40a" +checksum = "41543c4de4bfb32efdffdd75cbcca5ef41b800e8a811ea4a41fb9393c6ef3bc0" dependencies = [ "bytemuck", "coe-rs", @@ -569,22 +748,18 @@ dependencies = [ "libm", "matrixcompare", "matrixcompare-core", - "npyz", + "nano-gemm", "num-complex", "num-traits", "paste", - "rand", - "rand_distr", - "rayon", "reborrow", - "serde", ] [[package]] name = "faer-entity" -version = "0.18.0" +version = "0.19.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "22ea5c06233193392c614a46aa3bbe3de29c1404692c8053abd9c2765a1cd159" +checksum = "ab968a02be27be95de0f1ad0af901b865fa0866b6a9b553a6cc9cf7f19c2ce71" dependencies = [ "bytemuck", "coe-rs", @@ -617,12 +792,6 @@ dependencies = [ "miniz_oxide", ] -[[package]] -name = "foreign_vec" -version = "0.1.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ee1b05cbd864bcaecbd3455d6d967862d446e4ebfc3c2e5e5b9841e53cba6673" - [[package]] name = "form_urlencoded" version = "1.2.1" @@ -634,9 +803,9 @@ dependencies = [ [[package]] name = "gemm" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" +checksum = "e400f2ffd14e7548356236c35dc39cad6666d833a852cb8a8f3f28029359bb03" dependencies = [ "dyn-stack", "gemm-c32", @@ -654,9 +823,9 @@ dependencies = [ [[package]] name = "gemm-c32" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" +checksum = "10dc4a6176c8452d60eac1a155b454c91c668f794151a303bf3c75ea2874812d" dependencies = [ "dyn-stack", "gemm-common", @@ -669,9 +838,9 @@ dependencies = [ [[package]] name = "gemm-c64" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" +checksum = "cc2032ce2c0bb150da0256338759a6fb01ca056f6dfe28c4d14af32d7f878f6f" dependencies = [ "dyn-stack", "gemm-common", @@ -684,9 +853,9 @@ dependencies = [ [[package]] name = "gemm-common" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" +checksum = "90fd234fc525939654f47b39325fd5f55e552ceceea9135f3aa8bdba61eabef6" dependencies = [ "bytemuck", "dyn-stack", @@ -697,16 +866,15 @@ dependencies = [ "paste", "pulp", "raw-cpuid", - "rayon", "seq-macro", "sysctl", ] [[package]] name = "gemm-f16" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" +checksum = "3fc3652651f96a711d46b8833e1fac27a864be4bdfa81a374055f33ddd25c0c6" dependencies = [ "dyn-stack", "gemm-common", @@ -716,15 +884,14 @@ dependencies = [ "num-traits", "paste", "raw-cpuid", - "rayon", "seq-macro", ] [[package]] name = "gemm-f32" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" +checksum = "acbc51c44ae3defd207e6d9416afccb3c4af1e7cef5e4960e4c720ac4d6f998e" dependencies = [ "dyn-stack", "gemm-common", @@ -737,9 +904,9 @@ dependencies = [ [[package]] name = "gemm-f64" -version = "0.17.1" +version = "0.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" +checksum = "3f37fc86e325c2415a4d0cab8324a0c5371ec06fc7d2f9cb1636fcfc9536a8d8" dependencies = [ "dyn-stack", "gemm-common", @@ -767,10 +934,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c4567c8db10ae91089c99af84c68c38da3ec2f087c3f82960bcdbf3656b6f4d7" dependencies = [ "cfg-if", - "js-sys", "libc", "wasi", - "wasm-bindgen", ] [[package]] @@ -791,20 +956,11 @@ dependencies = [ "num-traits", ] -[[package]] -name = "hash_hasher" -version = "2.0.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "74721d007512d0cb3338cd20f0654ac913920061a4c4d0d8708edb3f2a698c0c" - [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" -dependencies = [ - "ahash", -] [[package]] name = "heck" @@ -836,6 +992,29 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "iana-time-zone" +version = "0.1.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e7ffbb5a1b541ea2561f8c41c087286cc091e21e556a4f09a8f6cbf17b69b141" +dependencies = [ + "android_system_properties", + "core-foundation-sys", + "iana-time-zone-haiku", + "js-sys", + "wasm-bindgen", + "windows-core", +] + +[[package]] +name = "iana-time-zone-haiku" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f31827a206f56af32e590ba56d5d2d085f558508192593743f16b2306495269f" +dependencies = [ + "cc", +] + [[package]] name = "idna" version = "0.5.0" @@ -846,6 +1025,19 @@ dependencies = [ "unicode-normalization", ] +[[package]] +name = "indicatif" +version = "0.17.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "763a5a8f45087d6bcea4222e7b72c291a054edf80e4ef6efd2a4979878c7bea3" +dependencies = [ + "console", + "instant", + "number_prefix", + "portable-atomic", + "unicode-width", +] + [[package]] name = "indoc" version = "2.0.5" @@ -861,6 +1053,15 @@ dependencies = [ "generic-array", ] +[[package]] +name = "instant" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" +dependencies = [ + "cfg-if", +] + [[package]] name = "is-terminal" version = "0.4.12" @@ -890,6 +1091,15 @@ dependencies = [ "either", ] +[[package]] +name = "itertools" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "413ee7dfc52ee1a4949ceeb7dbc8a33f2d6c088194d9f922fb8318faf1f01186" +dependencies = [ + "either", +] + [[package]] name = "itoa" version = "1.0.11" @@ -916,9 +1126,9 @@ dependencies = [ [[package]] name = "lazy_static" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" +checksum = "bbd2bcb4c963f2ddae06a2efc7e9f3591312473c50c6685e1f298068316e66fe" [[package]] name = "lazycell" @@ -926,17 +1136,81 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "830d08ce1d1d941e6b30645f1a0eb5643013d835ce3779a5fc208261dbe10f55" +[[package]] +name = "lexical-core" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2cde5de06e8d4c2faabc400238f9ae1c74d5412d03a7bd067645ccbc47070e46" +dependencies = [ + "lexical-parse-float", + "lexical-parse-integer", + "lexical-util", + "lexical-write-float", + "lexical-write-integer", +] + +[[package]] +name = "lexical-parse-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "683b3a5ebd0130b8fb52ba0bdc718cc56815b6a097e28ae5a6997d0ad17dc05f" +dependencies = [ + "lexical-parse-integer", + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-parse-integer" +version = "0.8.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d0994485ed0c312f6d965766754ea177d07f9c00c9b82a5ee62ed5b47945ee9" +dependencies = [ + "lexical-util", + "static_assertions", +] + +[[package]] +name = "lexical-util" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5255b9ff16ff898710eb9eb63cb39248ea8a5bb036bea8085b1a767ff6c4e3fc" +dependencies = [ + "static_assertions", +] + +[[package]] +name = "lexical-write-float" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "accabaa1c4581f05a3923d1b4cfd124c329352288b7b9da09e766b0668116862" +dependencies = [ + "lexical-util", + "lexical-write-integer", + "static_assertions", +] + +[[package]] +name = "lexical-write-integer" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e1b6f3d1f4422866b68192d62f77bc5c700bee84f3069f2469d7bc8c77852446" +dependencies = [ + "lexical-util", + "static_assertions", +] + [[package]] name = "libc" -version = "0.2.154" +version = "0.2.155" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ae743338b92ff9146ce83992f766a31066a91a8c84a45e0e9f21e7cf6de6d346" +checksum = "97b3888a4aecf77e811145cadf6eef5901f4782c53886191b2f693f24761847c" [[package]] name = "libloading" -version = "0.8.3" +version = "0.8.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0c2a198fb6b0eada2a8df47933734e6d35d350665a33a3593d7164fa52c75c19" +checksum = "4979f22fdb869068da03c9f7528f8297c6fd2606bc3a4affe42e6a823fdb8da4" dependencies = [ "cfg-if", "windows-targets", @@ -950,9 +1224,9 @@ checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" [[package]] name = "linux-raw-sys" -version = "0.4.13" +version = "0.4.14" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" +checksum = "78b3ae25bc7c8c38cec158d1f2757ee79e9b3740fbc7ccf0e59e4b08d793fa89" [[package]] name = "lock_api" @@ -966,9 +1240,9 @@ dependencies = [ [[package]] name = "log" -version = "0.4.21" +version = "0.4.22" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" +checksum = "a7a70ba024b9dc04c27ea2f0c0548feb474ec5c54bba33a7f72f873a39d07b24" [[package]] name = "matrixcompare" @@ -998,9 +1272,9 @@ dependencies = [ [[package]] name = "memchr" -version = "2.7.2" +version = "2.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" +checksum = "78ca9ab1a0babb1e7d5695e3530886289c18cf2f87ec19a575a0abdce112e3a3" [[package]] name = "memoffset" @@ -1019,9 +1293,9 @@ checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" -version = "0.7.2" +version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9d811f3e15f28568be3407c8e7fdb6514c1cda3cb30683f15b6a1a1dc4ea14a7" +checksum = "b8a240ddb74feaf34a79a7add65a741f3167852fba007066dcac1ca548d89c08" dependencies = [ "adler", ] @@ -1048,6 +1322,76 @@ dependencies = [ "target-features", ] +[[package]] +name = "nano-gemm" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f563548d38f390ef9893e4883ec38c1fb312f569e98d76bededdd91a3b41a043" +dependencies = [ + "equator", + "nano-gemm-c32", + "nano-gemm-c64", + "nano-gemm-codegen", + "nano-gemm-core", + "nano-gemm-f32", + "nano-gemm-f64", + "num-complex", +] + +[[package]] +name = "nano-gemm-c32" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a40449e57a5713464c3a1208c4c3301c8d29ee1344711822cf022bc91373a91b" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", + "num-complex", +] + +[[package]] +name = "nano-gemm-c64" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "743a6e6211358fba85d1009616751e4107da86f4c95b24e684ce85f25c25b3bf" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", + "num-complex", +] + +[[package]] +name = "nano-gemm-codegen" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "963bf7c7110d55430169dc74c67096375491ed580cd2ef84842550ac72e781fa" + +[[package]] +name = "nano-gemm-core" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe3fc4f83ae8861bad79dc3c016bd6b0220da5f9de302e07d3112d16efc24aa6" + +[[package]] +name = "nano-gemm-f32" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e3681b7ce35658f79da94b7f62c60a005e29c373c7111ed070e3bf64546a8bb" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", +] + +[[package]] +name = "nano-gemm-f64" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc1e619ed04d801809e1f63e61b669d380c4119e8b0cdd6ed184c6b111f046d8" +dependencies = [ + "nano-gemm-codegen", + "nano-gemm-core", +] + [[package]] name = "ndarray" version = "0.15.6" @@ -1072,32 +1416,34 @@ dependencies = [ ] [[package]] -name = "npyz" -version = "0.8.3" +name = "num" +version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "13f27ea175875c472b3df61ece89a6d6ef4e0627f43704e400c782f174681ebd" +checksum = "35bd024e8b2ff75562e5f34e7f4905839deb4b22955ef5e73d2fea1b9813cb23" dependencies = [ - "byteorder", "num-bigint", - "py_literal", + "num-complex", + "num-integer", + "num-iter", + "num-rational", + "num-traits", ] [[package]] name = "num-bigint" -version = "0.4.4" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" +checksum = "a5e44f723f1133c9deac646763579fdb3ac745e418f2a7af9cd0c431da1f20b9" dependencies = [ - "autocfg", "num-integer", "num-traits", ] [[package]] name = "num-complex" -version = "0.4.5" +version = "0.4.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +checksum = "73f88a1307638156682bada9d7604135552957b7818057dcef22705b4d509495" dependencies = [ "bytemuck", "num-traits", @@ -1118,6 +1464,28 @@ dependencies = [ "num-traits", ] +[[package]] +name = "num-iter" +version = "0.1.45" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1429034a0490724d0075ebb2bc9e875d6503c3cf69e235a8941aa757d83ef5bf" +dependencies = [ + "autocfg", + "num-integer", + "num-traits", +] + +[[package]] +name = "num-rational" +version = "0.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f83d14da390562dca69fc84082e73e548e1ad308d24accdedd2720017cb37824" +dependencies = [ + "num-bigint", + "num-integer", + "num-traits", +] + [[package]] name = "num-traits" version = "0.2.19" @@ -1128,6 +1496,12 @@ dependencies = [ "libm", ] +[[package]] +name = "number_prefix" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "830b246a0e5f20af87141b25c173cd1b609bd7779a4617d6ec582abaf90870f3" + [[package]] name = "numpy" version = "0.21.0" @@ -1145,15 +1519,14 @@ dependencies = [ [[package]] name = "nutpie" -version = "0.10.0" +version = "0.13.1" dependencies = [ "anyhow", - "arrow2", + "arrow", "bridgestan", "criterion", - "eerie", - "itertools 0.12.1", - "ndarray", + "indicatif", + "itertools 0.13.0", "numpy", "nuts-rs", "ort", @@ -1171,14 +1544,12 @@ dependencies = [ [[package]] name = "nuts-rs" -version = "0.9.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1ffa91fe8bbfc18a0d3d2068f4d3ad9edf1c19d4f36b61cb3cd1e7f0ed71079c" +version = "0.12.1" dependencies = [ "anyhow", - "arrow2", + "arrow", "faer", - "itertools 0.12.1", + "itertools 0.13.0", "multiversion", "pulp", "rand", @@ -1196,15 +1567,15 @@ checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" [[package]] name = "oorandom" -version = "11.1.3" +version = "11.1.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" +checksum = "b410bbe7e14ab526a0e86877eb47c6996a2bd7746f027ba551028c925390e4e9" [[package]] name = "ort" -version = "2.0.0-rc.2" +version = "2.0.0-rc.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0bc80894094c6a875bfac64415ed456fa661081a278a035e22be661305c87e14" +checksum = "86d83095ae3c1258738d70ae7a06195c94d966a8e546f0d3609dc90885fb61f5" dependencies = [ "half", "js-sys", @@ -1218,11 +1589,12 @@ dependencies = [ [[package]] name = "ort-sys" -version = "2.0.0-rc.2" +version = "2.0.0-rc.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3d9c1373fc813d3f024d394f621f4c6dde0734c79b1c17113c3bb5bf0084bbe" +checksum = "0f2f6427193c808010b126bef45ebd33f8dee43770223a1200f84d3734d6c656" dependencies = [ "flate2", + "pkg-config", "sha2", "tar", "ureq", @@ -1230,9 +1602,9 @@ dependencies = [ [[package]] name = "parking_lot" -version = "0.12.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7e4af0ca4f6caed20e900d564c242b8e5d4903fdacf31d3daf527b66fe6f42fb" +checksum = "f1bf18183cf54e8d6059647fc3063646a1801cf30896933ec2311622cc4b9a27" dependencies = [ "lock_api", "parking_lot_core", @@ -1246,7 +1618,7 @@ checksum = "1e401f977ab385c9e4e3ab30627d6f26d00e2c73eef317493c4ec6d468726cf8" dependencies = [ "cfg-if", "libc", - "redox_syscall 0.5.1", + "redox_syscall 0.5.3", "smallvec", "windows-targets", ] @@ -1264,72 +1636,45 @@ dependencies = [ [[package]] name = "paste" -version = "1.0.14" +version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" +checksum = "57c0d7b74b563b49d38dae00a0c37d4d6de9b432382b2892f0574ddcae73fd0a" [[package]] -name = "pbkdf2" -version = "0.11.0" +name = "path-absolutize" +version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" +checksum = "e4af381fe79fa195b4909485d99f73a80792331df0625188e707854f0b3383f5" dependencies = [ - "digest", - "hmac", - "password-hash", - "sha2", + "path-dedot", ] [[package]] -name = "percent-encoding" -version = "2.3.1" +name = "path-dedot" +version = "3.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" - -[[package]] -name = "pest" -version = "2.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "560131c633294438da9f7c4b08189194b20946c8274c6b9e38881a7874dc8ee8" -dependencies = [ - "memchr", - "thiserror", - "ucd-trie", -] - -[[package]] -name = "pest_derive" -version = "2.7.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "26293c9193fbca7b1a3bf9b79dc1e388e927e6cacaa78b4a3ab705a1d3d41459" +checksum = "07ba0ad7e047712414213ff67533e6dd477af0a4e1d14fb52343e53d30ea9397" dependencies = [ - "pest", - "pest_generator", + "once_cell", ] [[package]] -name = "pest_generator" -version = "2.7.10" +name = "pbkdf2" +version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3ec22af7d3fb470a85dd2ca96b7c577a1eb4ef6f1683a9fe9a8c16e136c04687" +checksum = "83a0692ec44e4cf1ef28ca317f14f8f07da2d95ec3fa01f86e4467b725e60917" dependencies = [ - "pest", - "pest_meta", - "proc-macro2", - "quote", - "syn 2.0.60", + "digest", + "hmac", + "password-hash", + "sha2", ] [[package]] -name = "pest_meta" -version = "2.7.10" +name = "percent-encoding" +version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d7a240022f37c361ec1878d646fc5b7d7c4d28d5946e1a80ad5a7a4f4ca0bdcd" -dependencies = [ - "once_cell", - "pest", - "sha2", -] +checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pin-project-lite" @@ -1345,9 +1690,9 @@ checksum = "d231b230927b5e4ad203db57bbcbee2802f6bce620b1e4a9024a07d94e2907ec" [[package]] name = "plotters" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" +checksum = "a15b6eccb8484002195a3e44fe65a4ce8e93a625797a063735536fd59cb01cf3" dependencies = [ "num-traits", "plotters-backend", @@ -1358,24 +1703,24 @@ dependencies = [ [[package]] name = "plotters-backend" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" +checksum = "414cec62c6634ae900ea1c56128dfe87cf63e7caece0852ec76aba307cebadb7" [[package]] name = "plotters-svg" -version = "0.3.5" +version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" +checksum = "81b30686a7d9c3e010b84284bdd26a29f2138574f52f5eb6f794fc0ad924e705" dependencies = [ "plotters-backend", ] [[package]] name = "portable-atomic" -version = "1.6.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7170ef9988bc169ba16dd36a7fa041e5c4cbeb6a35b76d4c03daded371eae7c0" +checksum = "da544ee218f0d287a911e9c99a39a8c9bc8fcad3cb8db5959940044ecfc67265" [[package]] name = "powerfmt" @@ -1391,28 +1736,28 @@ checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" [[package]] name = "prettyplease" -version = "0.2.19" +version = "0.2.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ac2cf0f2e4f42b49f5ffd07dae8d746508ef7526c13940e5f524012ae6c6550" +checksum = "5f12335488a2f3b0a83b14edad48dca9879ce89b2edd10e80237e4e852dd645e" dependencies = [ "proc-macro2", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] name = "proc-macro2" -version = "1.0.81" +version = "1.0.86" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +checksum = "5e719e8df665df0d1c8fbfd238015744736151d4445ec0836b8e628aae103b77" dependencies = [ "unicode-ident", ] [[package]] name = "pulp" -version = "0.18.10" +version = "0.18.21" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e14989307e408d9f4245d4fda09a7b144a08114ba124e26cab60ab83dc98db10" +checksum = "0ec8d02258294f59e4e223b41ad7e81c874aa6b15bc4ced9ba3965826da0eed5" dependencies = [ "bytemuck", "libm", @@ -1420,19 +1765,6 @@ dependencies = [ "reborrow", ] -[[package]] -name = "py_literal" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "102df7a3d46db9d3891f178dcc826dc270a6746277a9ae6436f8d29fd490a8e1" -dependencies = [ - "num-bigint", - "num-complex", - "num-traits", - "pest", - "pest_derive", -] - [[package]] name = "pyo3" version = "0.21.2" @@ -1481,7 +1813,7 @@ dependencies = [ "proc-macro2", "pyo3-macros-backend", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] @@ -1494,7 +1826,7 @@ dependencies = [ "proc-macro2", "pyo3-build-config", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] @@ -1598,18 +1930,18 @@ dependencies = [ [[package]] name = "redox_syscall" -version = "0.5.1" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "469052894dcb553421e483e4209ee581a45100d31b4018de03e5a7ad86374a7e" +checksum = "2a908a6e00f1fdd0dfd9c0eb08ce85126f6d8bbda50017e74bc4a4b7d4a926a4" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", ] [[package]] name = "regex" -version = "1.10.4" +version = "1.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c117dbdfde9c8308975b6a18d71f3f385c89461f7b3fb054288ecf2a2058ba4c" +checksum = "b91213439dad192326a0d7c6ee3955910425f441d7038e0d6933b0aec5c4517f" dependencies = [ "aho-corasick", "memchr", @@ -1619,9 +1951,9 @@ dependencies = [ [[package]] name = "regex-automata" -version = "0.4.6" +version = "0.4.7" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "86b83b8b9847f9bf95ef68afb0b8e6cdb80f498442f5179a29fad448fcc1eaea" +checksum = "38caf58cc5ef2fed281f89292ef23f6365465ed9a41b7a7754eb4e26496c92df" dependencies = [ "aho-corasick", "memchr", @@ -1630,9 +1962,9 @@ dependencies = [ [[package]] name = "regex-syntax" -version = "0.8.3" +version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "adad44e29e4c806119491a7f06f03de4d1af22c3a680dd47f1e6e179439d1f56" +checksum = "7a66a03ae7c801facd77a29370b4faec201768915ac14a721ba36f20bc9c209b" [[package]] name = "ring" @@ -1655,22 +1987,13 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "08d43f7aa6b08d49f382cde6a7982047c3426db949b1424bc4b7ec9ae12c6ce2" -[[package]] -name = "rustc_version" -version = "0.4.0" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" -dependencies = [ - "semver", -] - [[package]] name = "rustix" version = "0.38.34" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "errno", "libc", "linux-raw-sys", @@ -1679,11 +2002,12 @@ dependencies = [ [[package]] name = "rustls" -version = "0.22.4" +version = "0.23.11" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bf4ef73721ac7bcd79b2b315da7779d8fc09718c6b3d2d1b2d94850eb8c18432" +checksum = "4828ea528154ae444e5a642dbb7d5623354030dc9822b83fd9bb79683c7399d0" dependencies = [ "log", + "once_cell", "ring", "rustls-pki-types", "rustls-webpki", @@ -1693,15 +2017,15 @@ dependencies = [ [[package]] name = "rustls-pki-types" -version = "1.5.0" +version = "1.7.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "beb461507cee2c2ff151784c52762cf4d9ff6a61f3e80968600ed24fa837fa54" +checksum = "976295e77ce332211c0d24d92c0e83e50f5c5f046d11082cea19f3df13a3562d" [[package]] name = "rustls-webpki" -version = "0.102.3" +version = "0.102.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f3bce581c0dd41bce533ce695a1437fa16a7ab5ac3ccfa99fe1a620a7885eabf" +checksum = "8e6b52d4fda176fd835fdc55a835d4a89b8499cad995885a21149d5ad62f852e" dependencies = [ "ring", "rustls-pki-types", @@ -1710,9 +2034,9 @@ dependencies = [ [[package]] name = "ryu" -version = "1.0.17" +version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" +checksum = "f3cb5ba0dc43242ce17de99c180e96db90b235b8a9fdc9543c96d2209116bd9f" [[package]] name = "safetensors" @@ -1739,12 +2063,6 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" -[[package]] -name = "semver" -version = "1.0.22" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "92d43fe69e652f3df9bdc2b85b2854a0825b86e4fb76bc44d945137d053639ca" - [[package]] name = "seq-macro" version = "0.3.5" @@ -1753,29 +2071,29 @@ checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" [[package]] name = "serde" -version = "1.0.200" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ddc6f9cc94d67c0e21aaf7eda3a010fd3af78ebf6e096aa6e2e13c79749cce4f" +checksum = "bc76f558e0cbb2a839d37354c575f1dc3fdc6546b5be373ba43d95f231bf7c12" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" -version = "1.0.200" +version = "1.0.204" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "856f046b9400cee3c8c94ed572ecdb752444c24528c035cd35882aad6f492bcb" +checksum = "e0cd7e117be63d3c3678776753929474f3b04a43a080c744d6b0ae2a8c28e222" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] name = "serde_json" -version = "1.0.116" +version = "1.0.120" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +checksum = "4e0d21c9a8cae1235ad58a00c11cb40d4b1e5c784f1ef2c537876ed6ffd8b7c5" dependencies = [ "itoa", "ryu", @@ -1810,29 +2128,40 @@ version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" -[[package]] -name = "simdutf8" -version = "0.1.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f27f6278552951f1f2b8cf9da965d10969b2efdea95a6ec47987ab46edfe263a" - [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" +[[package]] +name = "socks" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c3dbbd9ae980613c6dd8e28a9407b50509d3803b57624d5dfe8315218cd58b" +dependencies = [ + "byteorder", + "libc", + "winapi", +] + [[package]] name = "spin" version = "0.9.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6980e8d7511241f8acf4aebddbb1ff938df5eebe98691418c4468d0b72a96a67" +[[package]] +name = "static_assertions" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" + [[package]] name = "subtle" -version = "2.5.0" +version = "2.6.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "81cdd64d312baedb58e21336b31bc043b77e01cc99033ce76ef539f78e965ebc" +checksum = "13c2bddecc57b384dee18652358fb23172facb8a2c51ccc10d74c157bdea3292" [[package]] name = "syn" @@ -1847,9 +2176,9 @@ dependencies = [ [[package]] name = "syn" -version = "2.0.60" +version = "2.0.72" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +checksum = "dc4b9b9bf2add8093d3f2c0204471e951b2285580335de42f9d2534f3ae7a8af" dependencies = [ "proc-macro2", "quote", @@ -1862,7 +2191,7 @@ version = "0.5.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" dependencies = [ - "bitflags 2.5.0", + "bitflags 2.6.0", "byteorder", "enum-as-inner", "libc", @@ -1872,9 +2201,9 @@ dependencies = [ [[package]] name = "tar" -version = "0.4.40" +version = "0.4.41" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b16afcea1f22891c49a00c751c7b63b2233284064f11a200fc624137c51e2ddb" +checksum = "cb797dad5fb5b76fcf519e702f4a589483b5ef06567f160c392832c1f5e44909" dependencies = [ "filetime", "libc", @@ -1889,9 +2218,9 @@ checksum = "c1bbb9f3c5c463a01705937a24fdabc5047929ac764b2d5b9cf681c1f5041ed5" [[package]] name = "target-lexicon" -version = "0.12.14" +version = "0.12.15" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e1fc403891a21bcfb7c37834ba66a547a8f402146eba7265b5a6d88059c9ff2f" +checksum = "4873307b7c257eddcb50c9bedf158eb669578359fb28428bef438fec8e6ba7c2" [[package]] name = "tch" @@ -1912,22 +2241,22 @@ dependencies = [ [[package]] name = "thiserror" -version = "1.0.59" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +checksum = "c0342370b38b6a11b6cc11d6a805569958d54cfa061a29969c3b5ce2ea405724" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "1.0.59" +version = "1.0.63" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +checksum = "a4558b58466b9ad7ca0f102865eccc95938dca1a74a856f2b57b6629050da261" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] @@ -1955,6 +2284,15 @@ version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e32d019b4f7c100bcd5494e40a27119d45b71fba2b07a4684153129279a4647" +[[package]] +name = "tiny-keccak" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" +dependencies = [ + "crunchy", +] + [[package]] name = "tinytemplate" version = "1.2.1" @@ -1967,9 +2305,9 @@ dependencies = [ [[package]] name = "tinyvec" -version = "1.6.0" +version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87cc5ceb3875bb20c2890005a4e226a4651264a5c75edb2421b52861a0a0cb50" +checksum = "445e881f4f6d382d5f27c034e25eb92edd7c784ceab92a0937db7f2e9471b938" dependencies = [ "tinyvec_macros", ] @@ -2011,7 +2349,7 @@ checksum = "34704c8d6ebcbc939824180af020566b01a7c01f80641264eba0999f6c2b6be7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] @@ -2029,12 +2367,6 @@ version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" -[[package]] -name = "ucd-trie" -version = "0.1.6" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" - [[package]] name = "unicode-bidi" version = "0.3.15" @@ -2056,6 +2388,12 @@ dependencies = [ "tinyvec", ] +[[package]] +name = "unicode-width" +version = "0.1.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0336d538f7abc86d282a4189614dfaa90810dfc2c6f6427eaf88e16311dd225d" + [[package]] name = "unindent" version = "0.2.3" @@ -2076,25 +2414,25 @@ checksum = "9fe29601d1624f104fa9a35ea71a5f523dd8bd1cfc8c31f8124ad2b829f013c0" [[package]] name = "ureq" -version = "2.9.7" +version = "2.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "d11a831e3c0b56e438a28308e7c810799e3c118417f342d30ecec080105395cd" +checksum = "72139d247e5f97a3eff96229a7ae85ead5328a39efe76f8bf5a06313d505b6ea" dependencies = [ "base64", "log", "once_cell", "rustls", "rustls-pki-types", - "rustls-webpki", + "socks", "url", "webpki-roots", ] [[package]] name = "url" -version = "2.5.0" +version = "2.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "31e6302e3bb753d46e83516cae55ae196fc0c309407cf11ab35cc51a4c2a4633" +checksum = "22784dbdf76fdde8af1aeda5622b546b422b6fc585325248a2bf9f5e41e94d6c" dependencies = [ "form_urlencoded", "idna", @@ -2144,7 +2482,7 @@ dependencies = [ "once_cell", "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", "wasm-bindgen-shared", ] @@ -2166,7 +2504,7 @@ checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", "wasm-bindgen-backend", "wasm-bindgen-shared", ] @@ -2189,9 +2527,9 @@ dependencies = [ [[package]] name = "webpki-roots" -version = "0.26.1" +version = "0.26.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b3de34ae270483955a94f4b21bdaaeb83d508bb84a01435f393818edb0012009" +checksum = "bd7c23921eeb1713a4e851530e9b9756e4fb0e89978582942612524cf09f01cd" dependencies = [ "rustls-pki-types", ] @@ -2208,6 +2546,22 @@ dependencies = [ "rustix", ] +[[package]] +name = "winapi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" +dependencies = [ + "winapi-i686-pc-windows-gnu", + "winapi-x86_64-pc-windows-gnu", +] + +[[package]] +name = "winapi-i686-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" + [[package]] name = "winapi-util" version = "0.1.8" @@ -2217,6 +2571,21 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "winapi-x86_64-pc-windows-gnu" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" + +[[package]] +name = "windows-core" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "33ab640c8d7e35bf8ba19b884ba838ceb4fba93a4e8c65a9059d08afcfc683d9" +dependencies = [ + "windows-targets", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -2228,9 +2597,9 @@ dependencies = [ [[package]] name = "windows-targets" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm", "windows_aarch64_msvc", @@ -2244,51 +2613,51 @@ dependencies = [ [[package]] name = "windows_aarch64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" -version = "0.52.5" +version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "xattr" @@ -2303,29 +2672,29 @@ dependencies = [ [[package]] name = "zerocopy" -version = "0.7.33" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "087eca3c1eaf8c47b94d02790dd086cd594b912d2043d4de4bfdd466b3befb7c" +checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" -version = "0.7.33" +version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f4b6c273f496d8fd4eaf18853e6b448760225dc030ff2c485a786859aea6393" +checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", - "syn 2.0.60", + "syn 2.0.72", ] [[package]] name = "zeroize" -version = "1.7.0" +version = "1.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "525b4ec142c6b68a2d10f01f7bbf6755599ca3f81ea53b8431b7dd348f5fdb2d" +checksum = "ced3678a2879b30306d323f4542626697a464a97c0a07c9aebf7ebca65cd4dde" [[package]] name = "zip" @@ -2368,9 +2737,9 @@ dependencies = [ [[package]] name = "zstd-sys" -version = "2.0.10+zstd.1.5.6" +version = "2.0.12+zstd.1.5.6" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c253a4914af5bafc8fa8c86ee400827e83cf6ec01195ec1f1ed8441bf00d65aa" +checksum = "0a4e40c320c3cb459d9a9ff6de98cff88f4751ee9275d140e2be94a2b74e4c13" dependencies = [ "cc", "pkg-config", diff --git a/Cargo.toml b/Cargo.toml index 8c2dc02..387339a 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -16,7 +16,6 @@ rust-version = "1.76" extension-module = ["pyo3/extension-module"] default = ["extension-module", "onnx"] simd_support = ["nuts-rs/simd_support"] -iree = ["dep:eerie"] torch = ["dep:tch"] onnx = ["dep:ort"] @@ -30,20 +29,22 @@ numpy = "0.21.0" rand = "0.8.5" thiserror = "1.0.44" rand_chacha = "0.3.1" -rayon = "1.9.0" +rayon = "1.10.0" # Keep arrow in sync with nuts-rs requirements -arrow = { version = "52.0.0", default-features = false, features = ["ffi"] } +arrow = { version = "52.1.0", default-features = false, features = ["ffi"] } anyhow = "1.0.72" itertools = "0.13.0" bridgestan = "2.5.0" rand_distr = "0.4.3" -smallvec = "1.11.0" +smallvec = "1.13.0" upon = { version = "0.8.1", default-features = false, features = [] } time-humanize = { version = "0.1.3", default-features = false } indicatif = "0.17.8" tch = { version = "0.16.0", optional = true } -ort = { version = "2.0.0-rc.2", optional = true, features = [ +ort = { version = "2.0.0-rc.4", optional = true, features = [ "cuda", + #"tensorrt", + #"openvino", "load-dynamic", ] } diff --git a/python/nutpie/__init__.py b/python/nutpie/__init__.py index 8707d90..bf0fb20 100644 --- a/python/nutpie/__init__.py +++ b/python/nutpie/__init__.py @@ -1,7 +1,7 @@ from nutpie import _lib +from nutpie.compile_onnx import compile_pytensor_module from nutpie.compile_pymc import compile_pymc_model from nutpie.compile_stan import compile_stan_model -from nutpie.compile_onnx import compile_pytensor_module from nutpie.sampling import sample __version__: str = _lib.__version__ diff --git a/python/nutpie/compile_onnx.py b/python/nutpie/compile_onnx.py index 13f430a..4b65701 100644 --- a/python/nutpie/compile_onnx.py +++ b/python/nutpie/compile_onnx.py @@ -1,9 +1,9 @@ -from typing import Any import dataclasses import io +from typing import Any -from nutpie.sampling import CompiledModel from nutpie import _lib +from nutpie.sampling import CompiledModel def compile_pytensor_module(module, n_dim): diff --git a/python/nutpie/compiled_pyfunc.py b/python/nutpie/compiled_pyfunc.py index 4db549c..5debd19 100644 --- a/python/nutpie/compiled_pyfunc.py +++ b/python/nutpie/compiled_pyfunc.py @@ -6,7 +6,7 @@ import numpy as np from nutpie import _lib -from nutpie.sample import CompiledModel +from nutpie.sampling import CompiledModel @dataclass(frozen=True) diff --git a/src/lib.rs b/src/lib.rs index 8e956ae..719e66f 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,5 +1,3 @@ -#[cfg(feature = "iree")] -mod iree; #[cfg(feature = "onnx")] mod ort; mod progress; diff --git a/src/ort.rs b/src/ort.rs index 4450d3b..1754883 100644 --- a/src/ort.rs +++ b/src/ort.rs @@ -1,14 +1,16 @@ +use std::sync::atomic::AtomicUsize; +use std::sync::Arc; + +use anyhow::Result; use anyhow::{anyhow, Context}; -use anyhow::{bail, Result}; -use arrow2::{ - array::{MutableArray, MutableFixedSizeListArray, MutablePrimitiveArray, StructArray, TryPush}, - datatypes::{DataType, Field}, -}; +use arrow::array::{Array, FixedSizeListBuilder, PrimitiveBuilder, StructArray}; +use arrow::datatypes::{Field, Float64Type}; use itertools::Itertools; -use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Model}; +use nuts_rs::{CpuLogpFunc, CpuMath, DrawStorage, LogpError, Math, Model}; use ort::{ - inputs, CPUExecutionProvider, CUDAExecutionProvider, ExecutionProviderDispatch, - InMemorySession, Session, SessionBuilder, SessionInputValue, SessionInputs, + AllocationDevice, Allocator, CPUExecutionProvider, CUDAExecutionProvider, + ExecutionProviderDispatch, InMemorySession, IoBinding, MemoryInfo, MemoryType, + OpenVINOExecutionProvider, Session, SessionInputs, TVMExecutionProvider, Tensor, TensorRTExecutionProvider, Value, }; use pyo3::{ @@ -19,20 +21,42 @@ use pyo3::{ use rand_distr::{Distribution, Uniform}; use thiserror::Error; +#[derive(Debug, Clone)] #[pyclass] -#[derive(Clone, Debug)] pub struct OnnxModel { ndim: usize, logp_model: Box<[u8]>, - providers: Vec, + providers: OnnxProviders, + sessions: Arc>, + count: Arc, } impl OnnxModel { - fn make_logp_session<'a>(&'a self) -> Result> { + fn make_plain_logp_session<'a>(&'a self) -> Result { let logp_session = Session::builder()? - .with_execution_providers(self.providers.iter().cloned())? + .with_optimization_level(ort::GraphOptimizationLevel::Level3)? .with_memory_pattern(true)? - .commit_from_memory_directly(&self.logp_model)?; + //.commit_from_memory_directly(&self.logp_model)?; + .commit_from_memory(&self.logp_model)?; + // + + Ok(logp_session) + } + + fn make_logp_session<'a>(&'a self) -> Result { + let logp_session = Session::builder()? + .with_optimization_level(ort::GraphOptimizationLevel::Level3)? + .with_execution_providers( + self.providers + .clone() + .providers + .into_iter() + .map(|val| val.into()), + )? + .with_memory_pattern(true)? + //.commit_from_memory_directly(&self.logp_model)?; + .commit_from_memory(&self.logp_model)?; + // Ok(logp_session) } @@ -44,37 +68,57 @@ impl OnnxModel { pub fn new_py<'py>( ndim: usize, logp_model: Bound<'py, PyBytes>, - providers: &OnnxProviders, + providers: OnnxProviders, ) -> Result { - Ok(Self { + let mut model = Self { ndim, - providers: providers.providers.iter().cloned().collect(), + providers, logp_model: logp_model.as_bytes().into(), - }) + sessions: Arc::new(vec![]), + count: Arc::new(0usize.into()), + }; + for _ in 0..6 { + let session = model.make_logp_session()?; + Arc::get_mut(&mut model.sessions).unwrap().push(session); + } + + let session = model.make_plain_logp_session()?; + + let pos = vec![0f32; ndim]; + let input = Tensor::from_array(([ndim], pos))?; + + session.run(ort::inputs![input]?)?; + + Ok(model) } } -#[derive(Clone)] pub struct OnnxTrace { - trace: MutableFixedSizeListArray>, + trace: FixedSizeListBuilder>, } impl DrawStorage for OnnxTrace { fn append_value(&mut self, point: &[f64]) -> Result<()> { - self.trace.try_push(Some(point.iter().map(|&x| Some(x))))?; + self.trace.values().append_slice(point); + self.trace.append(true); Ok(()) } - fn finalize(mut self) -> Result> { - let field = Field::new("unconstrained_draw", self.trace.data_type().clone(), false); + fn finalize(mut self) -> Result> { + //let data_type = DataType::Struct(fields.into()); + let data: Arc = Arc::new(self.trace.finish()); + let field = Field::new("unconstrained_draw", data.data_type().clone(), false); let fields = vec![field]; - let data_type = DataType::Struct(fields); - let struct_array = StructArray::new(data_type, vec![self.trace.as_box()], None); - Ok(Box::new(struct_array)) + let struct_array = StructArray::new(fields.into(), vec![data], None); + Ok(Arc::new(struct_array)) } - fn inspect(&mut self) -> Result> { - self.clone().finalize() + fn inspect(&self) -> Result> { + let data: Arc = Arc::new(self.trace.finish_cloned()); + let field = Field::new("unconstrained_draw", data.data_type().clone(), false); + let fields = vec![field]; + let struct_array = StructArray::new(fields.into(), vec![data], None); + Ok(Arc::new(struct_array)) } } @@ -96,13 +140,33 @@ impl LogpError for OnnxLogpError { } pub struct OnnxLogpFunc<'model> { - session: InMemorySession<'model>, + //session: &'model InMemorySession<'model>, + input: Tensor, + binding: IoBinding<'model>, + session: &'model Session, ndim: usize, + input_allocator: Allocator, + output_allocator: Allocator, } impl<'model> OnnxLogpFunc<'model> { - fn new(ndim: usize, session: InMemorySession<'model>) -> Result { - Ok(Self { session, ndim }) + //fn new(ndim: usize, session: &'model InMemorySession<'model>) -> Result { + fn new( + ndim: usize, + binding: IoBinding<'model>, + session: &'model Session, + input: Tensor, + input_allocator: Allocator, + output_allocator: Allocator, + ) -> Result { + Ok(Self { + session, + binding, + ndim, + input, + input_allocator, + output_allocator, + }) } } @@ -118,6 +182,7 @@ impl<'model> CpuLogpFunc for OnnxLogpFunc<'model> { position: &[f64], gradient: &mut [f64], ) -> std::result::Result { + /* let position = position.iter().map(|&x| x as f32).collect_vec(); let position = Value::from_array(([position.len()], position)).context("Could not create input")?; @@ -148,7 +213,35 @@ impl<'model> CpuLogpFunc for OnnxLogpFunc<'model> { gradient .iter_mut() .zip(vals.iter()) - .for_each(|(mut out, &val)| *out = val as f64); + .for_each(|(out, &val)| *out = val as f64); + */ + + let (_, input_vals) = self.input.extract_raw_tensor_mut(); + position + .iter() + .zip(input_vals.iter_mut()) + .for_each(|(val, loc)| *loc = *val as _); + + self.binding + .bind_input(&self.session.inputs[0].name, &self.input) + .context("Coud not bind input to logp function")?; + + let outputs = self.binding.run().context("Could not run logp function")?; + let first = &outputs[0]; + let logp: f32 = first + .try_extract_scalar() + .context("First output wnot a scalar")?; + + let grad = &outputs[1]; + let (_, grad): (_, &[f32]) = grad + .try_extract_raw_tensor() + .context("First output wnot a scalar")?; + + gradient + .iter_mut() + .zip(grad.iter()) + .for_each(|(out, &val)| *out = val as f64); + Ok(logp as f64) } } @@ -168,15 +261,59 @@ impl Model for OnnxModel { chain_id: u64, settings: &'model S, ) -> Result> { - let items = MutablePrimitiveArray::new(); - let trace = MutableFixedSizeListArray::new(items, self.ndim); + let items = PrimitiveBuilder::new(); + let trace = FixedSizeListBuilder::new(items, self.ndim.try_into().unwrap()); Ok(OnnxTrace { trace }) } fn math(&self) -> Result> { - let session = self.make_logp_session()?; - Ok(CpuMath::new(OnnxLogpFunc::new(self.ndim, session)?)) + //let session = self.make_logp_session()?; + let count = self.count.fetch_add(1, std::sync::atomic::Ordering::SeqCst); + let count = count % self.sessions.len(); + + let session = &self.sessions[count]; + + let input_allocator = Allocator::new( + session, + MemoryInfo::new( + AllocationDevice::CUDAPinned, + 0, + ort::AllocatorType::Device, + MemoryType::CPUInput, + )?, + )?; + let output_allocator = Allocator::new( + session, + MemoryInfo::new( + AllocationDevice::CUDAPinned, + 0, + ort::AllocatorType::Device, + MemoryType::CPUOutput, + )?, + )?; + + let mut binding = session.create_binding()?; + + let input = Tensor::::new(&input_allocator, [self.ndim])?; + + binding.bind_input(&session.inputs[0].name, &input)?; + + let scalar_shape: [usize; 0] = []; + let logp_output = Tensor::::new(&output_allocator, scalar_shape)?; + let grad_output = Tensor::::new(&output_allocator, [self.ndim])?; + + binding.bind_output(&session.outputs[0].name, logp_output)?; + binding.bind_output(&session.outputs[1].name, grad_output)?; + + Ok(CpuMath::new(OnnxLogpFunc::new( + self.ndim, + binding, + session, + input, + input_allocator, + output_allocator, + )?)) } fn init_position( @@ -192,9 +329,31 @@ impl Model for OnnxModel { } } +#[derive(Debug, Clone)] +enum Provider { + Cpu(CPUExecutionProvider), + Cuda(CUDAExecutionProvider), + TensorRt(TensorRTExecutionProvider), + Tvm(TVMExecutionProvider), + OpenVINO(OpenVINOExecutionProvider), +} + +impl Into for Provider { + fn into(self) -> ExecutionProviderDispatch { + match self { + Self::Cpu(val) => val.build().error_on_failure().into(), + Self::Cuda(val) => val.build().error_on_failure().into(), + Self::TensorRt(val) => val.build().error_on_failure().into(), + Self::Tvm(val) => val.build().error_on_failure().into(), + Self::OpenVINO(val) => val.build().error_on_failure().into(), + } + } +} + +#[derive(Debug, Clone)] #[pyclass] pub struct OnnxProviders { - providers: Vec, + providers: Vec, } #[pymethods] @@ -205,18 +364,33 @@ impl OnnxProviders { } pub fn add_cpu(&mut self) -> Result<()> { - self.providers.push(CPUExecutionProvider::default().into()); + self.providers + .push(Provider::Cpu(CPUExecutionProvider::default())); Ok(()) } pub fn add_cuda(&mut self) -> Result<()> { - self.providers.push(CUDAExecutionProvider::default().into()); + self.providers.push(Provider::Cuda( + CUDAExecutionProvider::default().with_cuda_graph(), + )); + Ok(()) + } + + pub fn add_tvm(&mut self) -> Result<()> { + let provider = TVMExecutionProvider::default(); + self.providers.push(Provider::Tvm(provider)); + Ok(()) + } + + pub fn add_openvino(&mut self) -> Result<()> { + let provider = OpenVINOExecutionProvider::default(); + self.providers.push(Provider::OpenVINO(provider)); Ok(()) } pub fn add_tensorrt(&mut self) -> Result<()> { self.providers - .push(TensorRTExecutionProvider::default().into()); + .push(Provider::TensorRt(TensorRTExecutionProvider::default())); Ok(()) } } diff --git a/src/wrapper.rs b/src/wrapper.rs index 4a6c699..99ce92f 100644 --- a/src/wrapper.rs +++ b/src/wrapper.rs @@ -9,7 +9,6 @@ use crate::ort::OnnxModel; use crate::{ ort::OnnxProviders, - progress::ProgressHandler, progress::{IndicatifHandler, ProgressHandler}, pyfunc::{ExpandDtype, PyModel, PyVariable, TensorShape}, pymc::{ExpandFunc, LogpFunc, PyMcModel}, @@ -577,33 +576,22 @@ impl PySampler { #[cfg(feature = "onnx")] #[staticmethod] fn from_onnx( - settings: PyDiagGradNutsSettings, + settings: PyNutsSettings, cores: usize, model: OnnxModel, - template: String, - rate: u64, - callback: Option>, - ) -> PyResult { - let rate = Duration::from_millis(rate); - let callback = make_callback(template, cores, rate, callback)?; - let sampler = Sampler::new(model, settings.0, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) - } - - #[cfg(feature = "iree")] - #[staticmethod] - fn from_iree( - settings: PyDiagGradNutsSettings, - cores: usize, - model: IreeModel, - template: String, - rate: u64, - callback: Option>, + progress_type: ProgressType, ) -> PyResult { - let rate = Duration::from_millis(rate); - let callback = make_callback(template, cores, rate, callback)?; - let sampler = Sampler::new(model, settings.0, cores, callback)?; - Ok(PySampler(SamplerState::Running(sampler))) + let callback = progress_type.into_callback()?; + match settings.into_settings() { + Settings::LowRank(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + Settings::Diag(settings) => { + let sampler = Sampler::new(model, settings, cores, callback)?; + Ok(PySampler(SamplerState::Running(sampler))) + } + } } fn is_finished(&mut self, py: Python<'_>) -> PyResult { @@ -814,7 +802,6 @@ pub fn _lib(m: &Bound<'_, PyModule>) -> PyResult<()> { m.add_class::()?; #[cfg(feature = "onnx")] m.add_class::()?; - m.add_class::()?; m.add_class::()?; m.add_class::()?; m.add_class::()?;