From d3ce103d67fc7e5c9c16a5d7b1128c4b5f47b869 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Thu, 25 Apr 2024 17:32:04 +0800 Subject: [PATCH] add cuda support --- justfile | 2 ++ pyr/Cargo.toml | 4 ++++ pyr/src/app/agent.rs | 25 ++++++++++++++++--------- pyr/src/app/state.rs | 4 ++-- pyr/src/main.rs | 25 ------------------------- 5 files changed, 24 insertions(+), 36 deletions(-) delete mode 100644 pyr/src/main.rs diff --git a/justfile b/justfile index 3809fed..492218d 100644 --- a/justfile +++ b/justfile @@ -2,6 +2,8 @@ test: python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3 build: cd pyr && cargo build --release +build-cuda: + cd pyr && cargo build --release --features cuda train level: run level: python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3 diff --git a/pyr/Cargo.toml b/pyr/Cargo.toml index 218f251..6ac0301 100644 --- a/pyr/Cargo.toml +++ b/pyr/Cargo.toml @@ -22,3 +22,7 @@ candle-core = "0.4.1" rand = "0.8.5" toml = "0.8.12" serde = {version = "1.0.198", features = ["derive"]} + +[features] +default = [] +cuda = ["candle-core/cuda"] \ No newline at end of file diff --git a/pyr/src/app/agent.rs b/pyr/src/app/agent.rs index 8c0c88f..b28efb0 100644 --- a/pyr/src/app/agent.rs +++ b/pyr/src/app/agent.rs @@ -12,7 +12,7 @@ use crate::CONFIG; use super::state::OBSERVATION_SPACE; use super::{action::AIAction, huber::huber_loss, state::AIState}; -const DEVICE: Device = Device::Cpu; +// const DEVICE: Device = Device::Cpu; const ACTION_SPACE: usize = 5; @@ -23,16 +23,22 @@ pub struct AIAgent { memory: VecDeque<(Tensor, u32, Tensor, f64)>, old_state: Option<AIState>, step: usize, + device: Device, accumulate_rewards: f64, } impl AIAgent { pub async fn new() -> Self { + #[cfg(not(feature = "cuda"))] + let device=Device::Cpu; + #[cfg(feature = "cuda")] + let device=Device::new_cuda(0).unwrap(); + let mut var_map = 
VarMap::new(); if Path::new("model.bin").exists() { var_map.load("model.bin").unwrap(); } - let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE); + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device); let model = seq() .add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap()) .add(Activation::LeakyRelu(0.01)) @@ -52,6 +58,7 @@ impl AIAgent { memory: VecDeque::new(), old_state: None, step: 0, + device, accumulate_rewards: 0.0, } } @@ -90,7 +97,7 @@ impl AIAgent { true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)), _ => self .model - .forward(&old_state.into_tensor()) + .forward(&old_state.into_tensor(&self.device)) .unwrap() .squeeze(0) .unwrap() @@ -108,11 +115,11 @@ impl AIAgent { self.old_state .as_ref() .unwrap() - .into_tensor() + .into_tensor(&self.device) .squeeze(0) .unwrap(), action, - state.into_tensor().squeeze(0).unwrap(), + state.into_tensor(&self.device).squeeze(0).unwrap(), reward, )); self.memory.truncate(CONFIG.replay_size); @@ -144,7 +151,7 @@ impl AIAgent { let states = Tensor::stack(&states, 0).unwrap(); let actions = batch.iter().map(|e| e.1); - let actions = Tensor::from_iter(actions, &DEVICE) + let actions = Tensor::from_iter(actions, &self.device) .unwrap() .unsqueeze(1) .unwrap(); @@ -153,13 +160,13 @@ impl AIAgent { let next_states = Tensor::stack(&next_states, 0).unwrap(); let rewards = batch.iter().map(|e| e.3 as f32); - let rewards = Tensor::from_iter(rewards, &DEVICE) + let rewards = Tensor::from_iter(rewards, &self.device) .unwrap() .unsqueeze(1) .unwrap(); let non_final_mask = batch.iter().map(|_| true as u8 as f32); - let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE) + let non_final_mask = Tensor::from_iter(non_final_mask, &self.device) .unwrap() .unsqueeze(1) .unwrap(); @@ -181,7 +188,7 @@ impl AIAgent { let loss = huber_loss(1.0_f32)(&x, &y); log::trace!("loss: {:?}", loss); self.optimizer - .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap()) + 
.backward_step(&Tensor::new(&[loss], &self.device).unwrap()) .unwrap(); } pub fn check_point(&mut self) { diff --git a/pyr/src/app/state.rs b/pyr/src/app/state.rs index 5cf98d4..1a87c38 100644 --- a/pyr/src/app/state.rs +++ b/pyr/src/app/state.rs @@ -45,8 +45,8 @@ impl AIState { .filter(|x| x.score.is_sign_negative()) .min_by_key(food_distance(&self.player)) } - pub fn into_tensor(&self) -> Tensor { - Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap() + pub fn into_tensor(&self, device:&Device) -> Tensor { + Tensor::new(&[self.into_feature()],device).unwrap() } fn into_feature(&self) -> [f32; OBSERVATION_SPACE] { let x = self.player.x; diff --git a/pyr/src/main.rs b/pyr/src/main.rs deleted file mode 100644 index e6a37c9..0000000 --- a/pyr/src/main.rs +++ /dev/null @@ -1,25 +0,0 @@ -// use candle_core::{DType, Device}; -// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap}; - -fn main() { - // let mut var_map = VarMap::new(); - // var_map.load("model.bin").unwrap(); - // let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu); - // let model = seq() - // .add(linear(14, 60, vb.pp("linear_in")).unwrap()) - // .add(Activation::LeakyRelu(0.01)) - // .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap()) - // .add(Activation::LeakyRelu(0.01)) - // .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap()) - // .add(Activation::LeakyRelu(0.01)) - // .add(linear(48, 5, vb.pp("linear_out")).unwrap()) - // .add(Activation::LeakyRelu(0.01)); - - // let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap(); - - // let target = Tensor::new(&[0.0], &Device::Cpu).unwrap(); - - // self.optimizer - // .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap()) - // .unwrap(); -}