diff --git a/pyr/Cargo.lock b/pyr/Cargo.lock index 23b3c27..b183b14 100644 --- a/pyr/Cargo.lock +++ b/pyr/Cargo.lock @@ -136,6 +136,17 @@ version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" +[[package]] +name = "bindgen_cuda" +version = "0.1.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1f8489af5b7d17a81bffe37e0f4d6e1e4de87c87329d05447f22c35d95a1227d" +dependencies = [ + "glob", + "num_cpus", + "rayon", +] + [[package]] name = "bitflags" version = "1.3.2" @@ -197,6 +208,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00" dependencies = [ "byteorder", + "candle-kernels", + "cudarc", "gemm", "half", "memmap2", @@ -211,6 +224,15 @@ dependencies = [ "zip", ] +[[package]] +name = "candle-kernels" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2" +dependencies = [ + "bindgen_cuda", +] + [[package]] name = "candle-nn" version = "0.4.1" @@ -291,6 +313,15 @@ version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" +[[package]] +name = "cudarc" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9395df0cab995685664e79cc35ad6302bf08fb9c5d82301875a183affe1278b1" +dependencies = [ + "half", +] + [[package]] name = "deranged" version = "0.3.11" @@ -546,6 +577,12 @@ dependencies = [ "wasi", ] +[[package]] +name = "glob" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d2fabcfbdc87f4758337ca535fb41a6d701b65693ce38287d856d1674551ec9b" + [[package]] name = "half" version = "2.4.1" diff --git a/pyr/Cargo.toml b/pyr/Cargo.toml index 25e8285..2b7cdc0 100644 --- a/pyr/Cargo.toml +++ b/pyr/Cargo.toml @@ -25,4 +25,4 @@ serde = {version = "1.0.198", features = ["derive"]} [features] default = [] -cuda = [] \ No newline at end of file +cuda = ["candle-nn/cuda"] \ No newline at end of file