From b4840596cb3a2e2a713edac82907c040ed985c21 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Fri, 26 Apr 2024 09:21:43 +0800 Subject: [PATCH] fix cuda support --- pyr/Cargo.lock | 40 +++++++++++++++++++++++++++++----------- pyr/Cargo.toml | 7 ++++--- pyr/src/data/config.rs | 2 +- 3 files changed, 34 insertions(+), 15 deletions(-) diff --git a/pyr/Cargo.lock b/pyr/Cargo.lock index b183b14..bb89f2c 100644 --- a/pyr/Cargo.lock +++ b/pyr/Cargo.lock @@ -2,6 +2,15 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110" +dependencies = [ + "derive_arbitrary", +] + [[package]] name = "async-channel" version = "2.2.1" @@ -203,9 +212,8 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" [[package]] name = "candle-core" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00" +version = "0.5.0" +source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749" dependencies = [ "byteorder", "candle-kernels", @@ -226,18 +234,16 @@ dependencies = [ [[package]] name = "candle-kernels" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2" +version = "0.5.0" +source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749" dependencies = [ "bindgen_cuda", ] [[package]] name = "candle-nn" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "66a27533c8edfc915a6459f9850641ef523a829fa1a181c670766c1f752d873a" +version = "0.5.0" +source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749" dependencies = [ "candle-core", "half", @@ -331,6 +337,17 @@ dependencies = [ "powerfmt", ] +[[package]] +name = "derive_arbitrary" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + [[package]] name = "dyn-stack" version = "0.10.0" @@ -1406,10 +1423,11 @@ dependencies = [ [[package]] name = "zip" -version = "0.6.6" +version = "1.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +checksum = "f2655979068a1f8fa91cb9e8e5b9d3ee54d18e0ddc358f2f4a395afc0929a84b" dependencies = [ + "arbitrary", "byteorder", "crc32fast", "crossbeam-utils", diff --git a/pyr/Cargo.toml b/pyr/Cargo.toml index 2b7cdc0..683e5e1 100644 --- a/pyr/Cargo.toml +++ b/pyr/Cargo.toml @@ -17,12 +17,13 @@ smol = "2.0.0" log = "0.4.21" simple_logger = "4.3.3" lazy_static = "1.4.0" -candle-nn = "0.4.1" -candle-core = "0.4.1" rand = "0.8.5" toml = "0.8.12" serde = {version = "1.0.198", features = ["derive"]} +candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" } +candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" } [features] default = [] -cuda = ["candle-nn/cuda"] \ No newline at end of file +cuda = ["candle-nn/cuda"] +cudnn = ["cuda", "candle-core/cudnn"] \ No newline at end of file diff --git a/pyr/src/data/config.rs b/pyr/src/data/config.rs index 5f49982..0b741bd 100644 --- a/pyr/src/data/config.rs +++ b/pyr/src/data/config.rs @@ -29,7 +29,7 @@ impl Default for Config { replay_size: 250, learning_rate: 0.04, gamma: 0.99, - train: true, + train: false, } } }