fix cuda support

This commit is contained in:
Eason 2024-04-26 09:21:43 +08:00
parent e2a59c827a
commit b4840596cb
3 changed files with 34 additions and 15 deletions

40
pyr/Cargo.lock generated
View File

@ -2,6 +2,15 @@
# It is not intended for manual editing.
version = 3
[[package]]
name = "arbitrary"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d5a26814d8dcb93b0e5a0ff3c6d80a8843bafb21b39e8e18a6f05471870e110"
dependencies = [
"derive_arbitrary",
]
[[package]]
name = "async-channel"
version = "2.2.1"
@ -203,9 +212,8 @@ checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "candle-core"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00"
version = "0.5.0"
source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
dependencies = [
"byteorder",
"candle-kernels",
@ -226,18 +234,16 @@ dependencies = [
[[package]]
name = "candle-kernels"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5845911a44164ebb73b56a0e23793ba1b583bad102af7400fe4768babc5815b2"
version = "0.5.0"
source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
dependencies = [
"bindgen_cuda",
]
[[package]]
name = "candle-nn"
version = "0.4.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66a27533c8edfc915a6459f9850641ef523a829fa1a181c670766c1f752d873a"
version = "0.5.0"
source = "git+https://github.com/huggingface/candle.git#cfab6e761696c18b1ce5d3a339ab57ef191ca749"
dependencies = [
"candle-core",
"half",
@ -331,6 +337,17 @@ dependencies = [
"powerfmt",
]
[[package]]
name = "derive_arbitrary"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "67e77553c4162a157adbf834ebae5b415acbecbeafc7a74b0e886657506a7611"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "dyn-stack"
version = "0.10.0"
@ -1406,10 +1423,11 @@ dependencies = [
[[package]]
name = "zip"
version = "0.6.6"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261"
checksum = "f2655979068a1f8fa91cb9e8e5b9d3ee54d18e0ddc358f2f4a395afc0929a84b"
dependencies = [
"arbitrary",
"byteorder",
"crc32fast",
"crossbeam-utils",

View File

@ -17,12 +17,13 @@ smol = "2.0.0"
log = "0.4.21"
simple_logger = "4.3.3"
lazy_static = "1.4.0"
candle-nn = "0.4.1"
candle-core = "0.4.1"
rand = "0.8.5"
toml = "0.8.12"
serde = {version = "1.0.198", features = ["derive"]}
candle-core = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
candle-nn = { git = "https://github.com/huggingface/candle.git", version = "0.5.0" }
[features]
default = []
cuda = ["candle-nn/cuda"]
cudnn = ["cuda", "candle-core/cudnn"]

View File

@ -29,7 +29,7 @@ impl Default for Config {
replay_size: 250,
learning_rate: 0.04,
gamma: 0.99,
train: true,
train: false,
}
}
}