add cuda support

This commit is contained in:
Eason 2024-04-25 17:32:04 +08:00
parent 1a777943ff
commit d3ce103d67
5 changed files with 24 additions and 36 deletions

View File

@ -2,6 +2,8 @@ test:
python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3 python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3
build: build:
cd pyr && cargo build --release cd pyr && cargo build --release
build-cuda:
cd pyr && cargo build --release --features cuda
train level: train level:
run level: run level:
python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3 python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3

View File

@ -22,3 +22,7 @@ candle-core = "0.4.1"
rand = "0.8.5" rand = "0.8.5"
toml = "0.8.12" toml = "0.8.12"
serde = {version = "1.0.198", features = ["derive"]} serde = {version = "1.0.198", features = ["derive"]}
[features]
default = []
cuda = ["candle-core/cuda"]

View File

@ -12,7 +12,7 @@ use crate::CONFIG;
use super::state::OBSERVATION_SPACE; use super::state::OBSERVATION_SPACE;
use super::{action::AIAction, huber::huber_loss, state::AIState}; use super::{action::AIAction, huber::huber_loss, state::AIState};
const DEVICE: Device = Device::Cpu; // const DEVICE: Device = Device::Cpu;
const ACTION_SPACE: usize = 5; const ACTION_SPACE: usize = 5;
@ -23,16 +23,22 @@ pub struct AIAgent {
memory: VecDeque<(Tensor, u32, Tensor, f64)>, memory: VecDeque<(Tensor, u32, Tensor, f64)>,
old_state: Option<AIState>, old_state: Option<AIState>,
step: usize, step: usize,
device: Device,
accumulate_rewards: f64, accumulate_rewards: f64,
} }
impl AIAgent { impl AIAgent {
pub async fn new() -> Self { pub async fn new() -> Self {
#[cfg(not(feature = "cuda"))]
let device=Device::Cpu;
#[cfg(feature = "cuda")]
let device=Device::new_cuda(0).unwrap();
let mut var_map = VarMap::new(); let mut var_map = VarMap::new();
if Path::new("model.bin").exists() { if Path::new("model.bin").exists() {
var_map.load("model.bin").unwrap(); var_map.load("model.bin").unwrap();
} }
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE); let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device);
let model = seq() let model = seq()
.add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap()) .add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap())
.add(Activation::LeakyRelu(0.01)) .add(Activation::LeakyRelu(0.01))
@ -52,6 +58,7 @@ impl AIAgent {
memory: VecDeque::new(), memory: VecDeque::new(),
old_state: None, old_state: None,
step: 0, step: 0,
device,
accumulate_rewards: 0.0, accumulate_rewards: 0.0,
} }
} }
@ -90,7 +97,7 @@ impl AIAgent {
true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)), true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)),
_ => self _ => self
.model .model
.forward(&old_state.into_tensor()) .forward(&old_state.into_tensor(&self.device))
.unwrap() .unwrap()
.squeeze(0) .squeeze(0)
.unwrap() .unwrap()
@ -108,11 +115,11 @@ impl AIAgent {
self.old_state self.old_state
.as_ref() .as_ref()
.unwrap() .unwrap()
.into_tensor() .into_tensor(&self.device)
.squeeze(0) .squeeze(0)
.unwrap(), .unwrap(),
action, action,
state.into_tensor().squeeze(0).unwrap(), state.into_tensor(&self.device).squeeze(0).unwrap(),
reward, reward,
)); ));
self.memory.truncate(CONFIG.replay_size); self.memory.truncate(CONFIG.replay_size);
@ -144,7 +151,7 @@ impl AIAgent {
let states = Tensor::stack(&states, 0).unwrap(); let states = Tensor::stack(&states, 0).unwrap();
let actions = batch.iter().map(|e| e.1); let actions = batch.iter().map(|e| e.1);
let actions = Tensor::from_iter(actions, &DEVICE) let actions = Tensor::from_iter(actions, &self.device)
.unwrap() .unwrap()
.unsqueeze(1) .unsqueeze(1)
.unwrap(); .unwrap();
@ -153,13 +160,13 @@ impl AIAgent {
let next_states = Tensor::stack(&next_states, 0).unwrap(); let next_states = Tensor::stack(&next_states, 0).unwrap();
let rewards = batch.iter().map(|e| e.3 as f32); let rewards = batch.iter().map(|e| e.3 as f32);
let rewards = Tensor::from_iter(rewards, &DEVICE) let rewards = Tensor::from_iter(rewards, &self.device)
.unwrap() .unwrap()
.unsqueeze(1) .unsqueeze(1)
.unwrap(); .unwrap();
let non_final_mask = batch.iter().map(|_| true as u8 as f32); let non_final_mask = batch.iter().map(|_| true as u8 as f32);
let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE) let non_final_mask = Tensor::from_iter(non_final_mask, &self.device)
.unwrap() .unwrap()
.unsqueeze(1) .unsqueeze(1)
.unwrap(); .unwrap();
@ -181,7 +188,7 @@ impl AIAgent {
let loss = huber_loss(1.0_f32)(&x, &y); let loss = huber_loss(1.0_f32)(&x, &y);
log::trace!("loss: {:?}", loss); log::trace!("loss: {:?}", loss);
self.optimizer self.optimizer
.backward_step(&Tensor::new(&[loss], &DEVICE).unwrap()) .backward_step(&Tensor::new(&[loss], &self.device).unwrap())
.unwrap(); .unwrap();
} }
pub fn check_point(&mut self) { pub fn check_point(&mut self) {

View File

@ -45,8 +45,8 @@ impl AIState {
.filter(|x| x.score.is_sign_negative()) .filter(|x| x.score.is_sign_negative())
.min_by_key(food_distance(&self.player)) .min_by_key(food_distance(&self.player))
} }
pub fn into_tensor(&self) -> Tensor { pub fn into_tensor(&self, device:&Device) -> Tensor {
Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap() Tensor::new(&[self.into_feature()],device).unwrap()
} }
fn into_feature(&self) -> [f32; OBSERVATION_SPACE] { fn into_feature(&self) -> [f32; OBSERVATION_SPACE] {
let x = self.player.x; let x = self.player.x;

View File

@ -1,25 +0,0 @@
// use candle_core::{DType, Device};
// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap};
fn main() {
// let mut var_map = VarMap::new();
// var_map.load("model.bin").unwrap();
// let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu);
// let model = seq()
// .add(linear(14, 60, vb.pp("linear_in")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(48, 5, vb.pp("linear_out")).unwrap())
// .add(Activation::LeakyRelu(0.01));
// let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap();
// let target = Tensor::new(&[0.0], &Device::Cpu).unwrap();
// self.optimizer
// .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
// .unwrap();
}