add cuda support
parent 1a777943ff
commit d3ce103d67
justfile

@@ -2,6 +2,8 @@ test:
     python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3
 build:
     cd pyr && cargo build --release
+build-cuda:
+    cd pyr && cargo build --release --features cuda
 train level:
 run level:
     python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3

@@ -22,3 +22,7 @@ candle-core = "0.4.1"
 rand = "0.8.5"
 toml = "0.8.12"
 serde = {version = "1.0.198", features = ["derive"]}
+
+[features]
+default = []
+cuda = ["candle-core/cuda"]
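Note: for `just build-cuda` (i.e. `cargo build --release --features cuda`) to actually enable candle's GPU kernels, the new feature must forward to the dependency as above; an empty `cuda = []` would only toggle this crate's own `cfg` gate while `Device::new_cuda(0)` kept failing at runtime. A minimal sketch of how the flag is consumed on the Rust side; `select_device` is a hypothetical helper, not part of this commit:

    use candle_core::Device;

    // Compile-time gate: exactly one `let device` survives cfg expansion,
    // mirroring the pattern AIAgent::new uses below.
    fn select_device() -> Device {
        #[cfg(not(feature = "cuda"))]
        let device = Device::Cpu;
        #[cfg(feature = "cuda")]
        let device = Device::new_cuda(0).expect("CUDA device 0 not available");
        device
    }

    fn main() {
        println!("selected device: {:?}", select_device());
    }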

@@ -12,7 +12,7 @@ use crate::CONFIG;
 use super::state::OBSERVATION_SPACE;
 use super::{action::AIAction, huber::huber_loss, state::AIState};

-const DEVICE: Device = Device::Cpu;
+// const DEVICE: Device = Device::Cpu;

 const ACTION_SPACE: usize = 5;

@@ -23,16 +23,22 @@ pub struct AIAgent {
     memory: VecDeque<(Tensor, u32, Tensor, f64)>,
     old_state: Option<AIState>,
     step: usize,
+    device: Device,
     accumulate_rewards: f64,
 }

 impl AIAgent {
     pub async fn new() -> Self {
+        #[cfg(not(feature = "cuda"))]
+        let device = Device::Cpu;
+        #[cfg(feature = "cuda")]
+        let device = Device::new_cuda(0).unwrap();
+
         let mut var_map = VarMap::new();
         if Path::new("model.bin").exists() {
             var_map.load("model.bin").unwrap();
         }
-        let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
+        let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device);
         let model = seq()
             .add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap())
             .add(Activation::LeakyRelu(0.01))
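Aside: the `.unwrap()` on `new_cuda(0)` aborts on a machine built with `--features cuda` but lacking a usable GPU. candle-core also exposes a runtime fallback; a sketch of that alternative (not what this commit does):

    use candle_core::Device;

    fn main() {
        // Picks CUDA device 0 when candle was built with CUDA support and a
        // GPU is present; otherwise falls back to the CPU instead of panicking.
        let device = Device::cuda_if_available(0).unwrap_or(Device::Cpu);
        println!("{device:?}");
    }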

@@ -52,6 +58,7 @@ impl AIAgent {
             memory: VecDeque::new(),
             old_state: None,
             step: 0,
+            device,
             accumulate_rewards: 0.0,
         }
     }

@@ -90,7 +97,7 @@ impl AIAgent {
             true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)),
             _ => self
                 .model
-                .forward(&old_state.into_tensor())
+                .forward(&old_state.into_tensor(&self.device))
                 .unwrap()
                 .squeeze(0)
                 .unwrap()

@@ -108,11 +115,11 @@ impl AIAgent {
             self.old_state
                 .as_ref()
                 .unwrap()
-                .into_tensor()
+                .into_tensor(&self.device)
                 .squeeze(0)
                 .unwrap(),
             action,
-            state.into_tensor().squeeze(0).unwrap(),
+            state.into_tensor(&self.device).squeeze(0).unwrap(),
             reward,
         ));
         self.memory.truncate(CONFIG.replay_size);
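Threading the agent's device into `into_tensor` keeps every tensor in the replay memory on one device, which matters because `Tensor::stack` in the training step rejects mixed devices. A self-contained sketch of that invariant (toy data, not the game's feature vectors):

    use candle_core::{Device, Tensor};

    fn main() -> candle_core::Result<()> {
        let device = Device::Cpu;
        let a = Tensor::new(&[1f32, 2.0], &device)?;
        // `to_device` is a no-op here, but it is the escape hatch for
        // migrating a tensor that was created on a different device.
        let b = Tensor::new(&[3f32, 4.0], &device)?.to_device(&device)?;
        let batch = Tensor::stack(&[a, b], 0)?; // ok: operands share a device
        println!("{:?}", batch.shape()); // [2, 2]
        Ok(())
    }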

@@ -144,7 +151,7 @@ impl AIAgent {
         let states = Tensor::stack(&states, 0).unwrap();

         let actions = batch.iter().map(|e| e.1);
-        let actions = Tensor::from_iter(actions, &DEVICE)
+        let actions = Tensor::from_iter(actions, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();

@@ -153,13 +160,13 @@ impl AIAgent {
         let next_states = Tensor::stack(&next_states, 0).unwrap();

         let rewards = batch.iter().map(|e| e.3 as f32);
-        let rewards = Tensor::from_iter(rewards, &DEVICE)
+        let rewards = Tensor::from_iter(rewards, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();

         let non_final_mask = batch.iter().map(|_| true as u8 as f32);
-        let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)
+        let non_final_mask = Tensor::from_iter(non_final_mask, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();

@@ -181,7 +188,7 @@ impl AIAgent {
         let loss = huber_loss(1.0_f32)(&x, &y);
         log::trace!("loss: {:?}", loss);
         self.optimizer
-            .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
+            .backward_step(&Tensor::new(&[loss], &self.device).unwrap())
             .unwrap();
     }
     pub fn check_point(&mut self) {
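Worth double-checking while touching this line: `Tensor::new(&[loss], ...)` only compiles if the crate's custom `huber_loss` returns a plain float, and a tensor freshly built from a float carries no autograd history, so `backward_step` would see no gradients through the model. A sketch of the graph-preserving form, using `candle_nn`'s built-in `mse` in place of the custom loss (assumption: the loss should stay a `Tensor` end to end):

    use candle_core::{DType, Device, Tensor};
    use candle_nn::{linear, loss::mse, AdamW, Module, Optimizer, VarBuilder, VarMap};

    fn main() -> candle_core::Result<()> {
        let device = Device::Cpu;
        let var_map = VarMap::new();
        let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device);
        let model = linear(2, 1, vb.pp("lin"))?;
        let mut optimizer = AdamW::new_lr(var_map.all_vars(), 1e-3)?;

        let x = Tensor::zeros((1, 2), DType::F32, &device)?;
        let y = Tensor::zeros((1, 1), DType::F32, &device)?;

        // The loss stays a Tensor attached to the graph, so backward_step can
        // differentiate through it; re-wrapping a float would sever that link.
        let loss = mse(&model.forward(&x)?, &y)?;
        optimizer.backward_step(&loss)?;
        Ok(())
    }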

@@ -45,8 +45,8 @@ impl AIState {
             .filter(|x| x.score.is_sign_negative())
             .min_by_key(food_distance(&self.player))
     }
-    pub fn into_tensor(&self) -> Tensor {
-        Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap()
+    pub fn into_tensor(&self, device: &Device) -> Tensor {
+        Tensor::new(&[self.into_feature()], device).unwrap()
     }
     fn into_feature(&self) -> [f32; OBSERVATION_SPACE] {
         let x = self.player.x;
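Minor naming aside while this signature changes anyway: the method borrows `self`, so Rust's `as_`/`to_`/`into_` convention (API guidelines C-CONV) would favor `to_tensor`; the `into_` prefix signals a consuming conversion. A hypothetical sketch with stand-in fields:

    use candle_core::{Device, Tensor};

    // Stand-in for the real AIState; only the conversion is sketched here.
    struct AIState {
        features: [f32; 4],
    }

    impl AIState {
        // Borrowing receiver + `to_` prefix, per the Rust API guidelines.
        fn to_tensor(&self, device: &Device) -> Tensor {
            Tensor::new(&[self.features], device).unwrap()
        }
    }

    fn main() {
        let state = AIState { features: [0.0; 4] };
        println!("{:?}", state.to_tensor(&Device::Cpu).shape()); // [1, 4]
    }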

@@ -1,25 +0,0 @@
-// use candle_core::{DType, Device};
-// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap};
-
-fn main() {
-    // let mut var_map = VarMap::new();
-    // var_map.load("model.bin").unwrap();
-    // let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu);
-    // let model = seq()
-    //     .add(linear(14, 60, vb.pp("linear_in")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(48, 5, vb.pp("linear_out")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01));
-
-    // let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap();
-
-    // let target = Tensor::new(&[0.0], &Device::Cpu).unwrap();
-
-    // self.optimizer
-    //     .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
-    //     .unwrap();
-}