Add CUDA support

Eason 2024-04-25 17:32:04 +08:00
parent 1a777943ff
commit d3ce103d67
5 changed files with 24 additions and 36 deletions

View File

@@ -2,6 +2,8 @@ test:
python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3
build:
cd pyr && cargo build --release
build-cuda:
cd pyr && cargo build --release --features cuda
train level:
run level:
python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3

View File

@@ -22,3 +22,7 @@ candle-core = "0.4.1"
rand = "0.8.5"
toml = "0.8.12"
serde = {version = "1.0.198", features = ["derive"]}
[features]
default = []
cuda = ["candle-core/cuda"] # forward to candle-core's CUDA backend
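For reference, a minimal sketch of how a Cargo feature like this is consumed from Rust code; `backend_name` is an illustrative helper, not part of the commit:

fn backend_name() -> &'static str {
    // cfg! expands to a compile-time boolean constant, so the branch is folded
    // at build time (both arms must still compile, unlike with the #[cfg]
    // attribute this commit uses in agent.rs below).
    if cfg!(feature = "cuda") { "cuda" } else { "cpu" }
}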

View File

@@ -12,7 +12,7 @@ use crate::CONFIG;
use super::state::OBSERVATION_SPACE;
use super::{action::AIAction, huber::huber_loss, state::AIState};
const DEVICE: Device = Device::Cpu;
// const DEVICE: Device = Device::Cpu;
const ACTION_SPACE: usize = 5;
@@ -23,16 +23,22 @@ pub struct AIAgent {
memory: VecDeque<(Tensor, u32, Tensor, f64)>,
old_state: Option<AIState>,
step: usize,
device: Device,
accumulate_rewards: f64,
}
impl AIAgent {
pub async fn new() -> Self {
#[cfg(not(feature = "cuda"))]
let device = Device::Cpu;
#[cfg(feature = "cuda")]
let device = Device::new_cuda(0).unwrap();
let mut var_map = VarMap::new();
if Path::new("model.bin").exists() {
var_map.load("model.bin").unwrap();
}
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device);
let model = seq()
.add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap())
.add(Activation::LeakyRelu(0.01))
@@ -52,6 +58,7 @@ impl AIAgent {
memory: VecDeque::new(),
old_state: None,
step: 0,
device,
accumulate_rewards: 0.0,
}
}
@@ -90,7 +97,7 @@ impl AIAgent {
true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)),
_ => self
.model
.forward(&old_state.into_tensor())
.forward(&old_state.into_tensor(&self.device))
.unwrap()
.squeeze(0)
.unwrap()
@@ -108,11 +115,11 @@ impl AIAgent {
self.old_state
.as_ref()
.unwrap()
.into_tensor()
.into_tensor(&self.device)
.squeeze(0)
.unwrap(),
action,
state.into_tensor().squeeze(0).unwrap(),
state.into_tensor(&self.device).squeeze(0).unwrap(),
reward,
));
self.memory.truncate(CONFIG.replay_size);
@@ -144,7 +151,7 @@ impl AIAgent {
let states = Tensor::stack(&states, 0).unwrap();
let actions = batch.iter().map(|e| e.1);
let actions = Tensor::from_iter(actions, &DEVICE)
let actions = Tensor::from_iter(actions, &self.device)
.unwrap()
.unsqueeze(1)
.unwrap();
@@ -153,13 +160,13 @@ impl AIAgent {
let next_states = Tensor::stack(&next_states, 0).unwrap();
let rewards = batch.iter().map(|e| e.3 as f32);
let rewards = Tensor::from_iter(rewards, &DEVICE)
let rewards = Tensor::from_iter(rewards, &self.device)
.unwrap()
.unsqueeze(1)
.unwrap();
let non_final_mask = batch.iter().map(|_| true as u8 as f32);
let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)
let non_final_mask = Tensor::from_iter(non_final_mask, &self.device)
.unwrap()
.unsqueeze(1)
.unwrap();
@@ -181,7 +188,7 @@ impl AIAgent {
let loss = huber_loss(1.0_f32)(&x, &y);
log::trace!("loss: {:?}", loss);
self.optimizer
.backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
.backward_step(&Tensor::new(&[loss], &self.device).unwrap())
.unwrap();
}
pub fn check_point(&mut self) {
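One caveat in the hunk above: `Device::new_cuda(0).unwrap()` panics on a machine built with the cuda feature but lacking a usable GPU. A hedged alternative, a sketch rather than what the commit does, degrades to CPU instead (`select_device` is an illustrative name):

use candle_core::Device;

// Illustrative helper (not in the commit): prefer GPU 0, fall back to CPU.
fn select_device() -> Device {
    #[cfg(feature = "cuda")]
    {
        match Device::new_cuda(0) {
            Ok(d) => return d,
            Err(e) => log::warn!("CUDA unavailable ({e}), falling back to CPU"),
        }
    }
    Device::Cpu
}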

View File

@@ -45,8 +45,8 @@ impl AIState {
.filter(|x| x.score.is_sign_negative())
.min_by_key(food_distance(&self.player))
}
pub fn into_tensor(&self) -> Tensor {
Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap()
pub fn into_tensor(&self, device: &Device) -> Tensor {
Tensor::new(&[self.into_feature()], device).unwrap()
}
fn into_feature(&self) -> [f32; OBSERVATION_SPACE] {
let x = self.player.x;
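A brief usage sketch of the new signature; `observe` is a hypothetical call site, not part of the commit. Note the commit itself stores the device once on AIAgent and threads it through, which is cheaper than recreating it per call:

use candle_core::{Device, Tensor};

// Illustrative only: build the observation tensor on whichever device is available.
fn observe(state: &AIState) -> Tensor {
    let device = Device::new_cuda(0).unwrap_or(Device::Cpu);
    state.into_tensor(&device) // shape: [1, OBSERVATION_SPACE]
}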

View File

@@ -1,25 +0,0 @@
// use candle_core::{DType, Device};
// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap};
fn main() {
// let mut var_map = VarMap::new();
// var_map.load("model.bin").unwrap();
// let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu);
// let model = seq()
// .add(linear(14, 60, vb.pp("linear_in")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
// .add(Activation::LeakyRelu(0.01))
// .add(linear(48, 5, vb.pp("linear_out")).unwrap())
// .add(Activation::LeakyRelu(0.01));
// let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap();
// let target = Tensor::new(&[0.0], &Device::Cpu).unwrap();
// self.optimizer
// .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
// .unwrap();
}