fix cuda support

This commit is contained in:
Eason 2024-04-26 09:21:43 +08:00
parent e2a59c827a
commit b4840596cb
3 changed files with 34 additions and 15 deletions

40
pyr/Cargo.lock generated
View File

@ -2,6 +2,15 @@
# It is not intended for manual editing. # It is not intended for manual editing.
version = 3 version = 3
[[package]]
name = "arbitrary"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
dependencies = [
"derive_arbitrary",
]
[[package]] [[package]]
name = "async-channel" name = "async-channel"
version = "2.2.1" version = "2.2.1"
@ -203,9 +212,8 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]] [[package]]
name = "candle-core" name = "candle-core"
version = "0.4.1" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
dependencies = [ dependencies = [
"byteorder", "byteorder",
"candle-kernels", "candle-kernels",
@ -226,18 +234,16 @@ dependencies = [
[[package]] [[package]]
name = "candle-kernels" name = "candle-kernels"
version = "0.4.1" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
dependencies = [ dependencies = [
"bindgen_cuda", "bindgen_cuda",
] ]
[[package]] [[package]]
name = "candle-nn" name = "candle-nn"
version = "0.4.1" version = "0.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
checksum = "66a27533c8edfc915a6459f9850641ef523a829fa1a181c670766c1f752d873a"
dependencies = [ dependencies = [
"candle-core", "candle-core",
"half", "half",
@ -331,6 +337,17 @@ dependencies = [
"powerfmt", "powerfmt",
] ]
[[package]]
name = "derive_arbitrary"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]] [[package]]
name = "dyn-stack" name = "dyn-stack"
version = "0.10.0" version = "0.10.0"
@ -1406,10 +1423,11 @@ dependencies = [
[[package]] [[package]]
name = "zip" name = "zip"
version = "0.6.6" version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index" source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" checksum = "f2655979068a1f8fa91cb9e8e5b9d3ee54d18e0ddc358f2f4a395afc0929a84b"
dependencies = [ dependencies = [
"arbitrary",
"byteorder", "byteorder",
"crc32fast", "crc32fast",
"crossbeam-utils", "crossbeam-utils",

View File

@ -17,12 +17,13 @@ smol = "2.0.0"
log = "0.4.21" log = "0.4.21"
simple_logger = "4.3.3" simple_logger = "4.3.3"
lazy_static = "1.4.0" lazy_static = "1.4.0"
candle-nn = "0.4.1"
candle-core = "0.4.1"
rand = "0.8.5" rand = "0.8.5"
toml = "0.8.12" toml = "0.8.12"
serde = {version = "1.0.198", features = ["derive"]} serde = {version = "1.0.198", features = ["derive"]}
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
[features] [features]
default = [] default = []
cuda = ["candle-nn/cuda"] cuda = ["candle-nn/cuda"]
cudnn = ["cuda", "candle-core/cudnn"]

View File

@ -29,7 +29,7 @@ impl Default for Config {
replay_size: 250, replay_size: 250,
learning_rate: 0.04, learning_rate: 0.04,
gamma: 0.99, gamma: 0.99,
train: true, train: false,
} }
} }
} }