add cuda support

This commit is contained in:
parent 1a777943ff
commit d3ce103d67

justfile | 2 ++
@@ -2,6 +2,8 @@ test:
     python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3
 build:
     cd pyr && cargo build --release
+build-cuda:
+    cd pyr && cargo build --release --features cuda
 train level:
 run level:
     python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3
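For context on what the new `build-cuda` recipe toggles: the `cuda` cargo feature only changes compile-time configuration, so a binary can report which backend it was built for. A minimal sketch (this `main` is hypothetical and not part of the commit):

// Hypothetical sanity check, not part of this commit: report which
// backend the binary was compiled for via the `cuda` cargo feature.
fn main() {
    #[cfg(feature = "cuda")]
    println!("CUDA build (produced by `just build-cuda`)");

    #[cfg(not(feature = "cuda"))]
    println!("CPU-only build (produced by `just build`)");
}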
@@ -22,3 +22,7 @@ candle-core = "0.4.1"
 rand = "0.8.5"
 toml = "0.8.12"
 serde = {version = "1.0.198", features = ["derive"]}
+
+[features]
+default = []
+cuda = ["candle-core/cuda"] # forward to candle so the CUDA backend is compiled in
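Forwarding the feature to candle-core matters because `Device::new_cuda` fails at runtime when candle was compiled without CUDA support, which the `.unwrap()` below would turn into a panic. A hedged sketch of a runtime probe using `candle_core::utils::cuda_is_available`:

// Sketch, assuming candle-core 0.4 as pinned above: probe whether the
// CUDA backend was compiled in before committing to a device.
use candle_core::{utils, Device};

fn main() -> candle_core::Result<()> {
    let device = if utils::cuda_is_available() {
        Device::new_cuda(0)? // first CUDA ordinal
    } else {
        Device::Cpu
    };
    println!("using device: {:?}", device);
    Ok(())
}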
@@ -12,7 +12,7 @@ use crate::CONFIG;
 use super::state::OBSERVATION_SPACE;
 use super::{action::AIAction, huber::huber_loss, state::AIState};
 
-const DEVICE: Device = Device::Cpu;
+// const DEVICE: Device = Device::Cpu;
 
 const ACTION_SPACE: usize = 5;
 
@@ -23,16 +23,22 @@ pub struct AIAgent {
     memory: VecDeque<(Tensor, u32, Tensor, f64)>,
     old_state: Option<AIState>,
     step: usize,
+    device: Device,
     accumulate_rewards: f64,
 }
 
 impl AIAgent {
     pub async fn new() -> Self {
+        #[cfg(not(feature = "cuda"))]
+        let device = Device::Cpu;
+        #[cfg(feature = "cuda")]
+        let device = Device::new_cuda(0).unwrap();
+
         let mut var_map = VarMap::new();
         if Path::new("model.bin").exists() {
             var_map.load("model.bin").unwrap();
         }
-        let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
+        let vb = VarBuilder::from_varmap(&var_map, DType::F32, &device);
         let model = seq()
             .add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap())
             .add(Activation::LeakyRelu(0.01))
@@ -52,6 +58,7 @@ impl AIAgent {
             memory: VecDeque::new(),
             old_state: None,
             step: 0,
+            device,
             accumulate_rewards: 0.0,
         }
     }
@@ -90,7 +97,7 @@ impl AIAgent {
             true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)),
             _ => self
                 .model
-                .forward(&old_state.into_tensor())
+                .forward(&old_state.into_tensor(&self.device))
                 .unwrap()
                 .squeeze(0)
                 .unwrap()
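The match above is the epsilon-greedy policy: during training a random action may be sampled, otherwise the argmax of the forwarded Q-values is taken. A standalone sketch of the same pattern (the function name and signature are hypothetical, using the rand 0.8 API already in the dependencies):

// Hedged sketch of epsilon-greedy action selection: explore with
// probability `epsilon`, otherwise exploit the best Q-value.
use rand::{thread_rng, Rng};

fn pick_action(q_values: &[f32], epsilon: f64) -> u32 {
    let mut rng = thread_rng();
    if rng.gen_bool(epsilon) {
        // explore: uniform random action
        rng.gen_range(0..q_values.len() as u32)
    } else {
        // exploit: index of the largest Q-value (panics on an empty slice)
        q_values
            .iter()
            .enumerate()
            .max_by(|a, b| a.1.total_cmp(b.1))
            .map(|(i, _)| i as u32)
            .unwrap()
    }
}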
@@ -108,11 +115,11 @@ impl AIAgent {
             self.old_state
                 .as_ref()
                 .unwrap()
-                .into_tensor()
+                .into_tensor(&self.device)
                 .squeeze(0)
                 .unwrap(),
             action,
-            state.into_tensor().squeeze(0).unwrap(),
+            state.into_tensor(&self.device).squeeze(0).unwrap(),
             reward,
         ));
         self.memory.truncate(CONFIG.replay_size);
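The surrounding code maintains a bounded replay buffer: a transition tuple is pushed and the deque is truncated to `CONFIG.replay_size`. A minimal sketch of that pattern, assuming new transitions are pushed to the front so `truncate` drops the oldest (`Transition` and `REPLAY_SIZE` are stand-ins for the tuple type and config value):

use std::collections::VecDeque;

// Minimal sketch of a bounded replay buffer; types and cap are assumptions.
type Transition = (Vec<f32>, u32, Vec<f32>, f64);
const REPLAY_SIZE: usize = 10_000;

fn push(memory: &mut VecDeque<Transition>, t: Transition) {
    // Newest-first ordering: truncate keeps the front of the deque,
    // so the oldest transitions fall off once the cap is reached.
    memory.push_front(t);
    memory.truncate(REPLAY_SIZE);
}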
@@ -144,7 +151,7 @@ impl AIAgent {
         let states = Tensor::stack(&states, 0).unwrap();
 
         let actions = batch.iter().map(|e| e.1);
-        let actions = Tensor::from_iter(actions, &DEVICE)
+        let actions = Tensor::from_iter(actions, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();
@@ -153,13 +160,13 @@ impl AIAgent {
         let next_states = Tensor::stack(&next_states, 0).unwrap();
 
         let rewards = batch.iter().map(|e| e.3 as f32);
-        let rewards = Tensor::from_iter(rewards, &DEVICE)
+        let rewards = Tensor::from_iter(rewards, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();
 
         let non_final_mask = batch.iter().map(|_| true as u8 as f32);
-        let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)
+        let non_final_mask = Tensor::from_iter(non_final_mask, &self.device)
             .unwrap()
             .unsqueeze(1)
             .unwrap();
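Every `Tensor::from_iter` call now takes `&self.device` because candle operations require their operands to live on the same device; mixing a CPU tensor with a CUDA tensor yields a device-mismatch error. A small self-contained sketch (the values are illustrative):

// Sketch: both operands are created on the same device, which is the
// invariant the &self.device changes above preserve.
use candle_core::{Device, Tensor};

fn main() -> candle_core::Result<()> {
    let device = Device::cuda_if_available(0)?; // falls back to CPU
    let a = Tensor::from_iter([1.0f32, 2.0, 3.0], &device)?;
    let b = Tensor::from_iter([4.0f32, 5.0, 6.0], &device)?;
    let c = (a * b)?; // would fail if a and b lived on different devices
    println!("{c}");
    Ok(())
}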
@@ -181,7 +188,7 @@ impl AIAgent {
         let loss = huber_loss(1.0_f32)(&x, &y);
         log::trace!("loss: {:?}", loss);
         self.optimizer
-            .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
+            .backward_step(&Tensor::new(&[loss], &self.device).unwrap())
             .unwrap();
     }
     pub fn check_point(&mut self) {
@@ -45,8 +45,8 @@ impl AIState {
             .filter(|x| x.score.is_sign_negative())
             .min_by_key(food_distance(&self.player))
     }
-    pub fn into_tensor(&self) -> Tensor {
-        Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap()
+    pub fn into_tensor(&self, device: &Device) -> Tensor {
+        Tensor::new(&[self.into_feature()], device).unwrap()
     }
     fn into_feature(&self) -> [f32; OBSERVATION_SPACE] {
         let x = self.player.x;
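Threading the device through `into_tensor` is one option; an alternative would be to build the feature row on the CPU and move it afterwards with `Tensor::to_device`. A sketch, assuming a feature width of 14 (the `linear_in` width used in the deleted scratch file below), not what the commit actually does:

use candle_core::{Device, Tensor};

// Alternative sketch: construct on the CPU, then copy to the target
// device. The fixed width 14 is an assumption taken from the deleted
// scratch file below, which wired OBSERVATION_SPACE-like input as 14.
fn features_to_device(features: [f32; 14], device: &Device) -> candle_core::Result<Tensor> {
    Tensor::new(&[features], &Device::Cpu)?.to_device(device)
}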
@@ -1,25 +0,0 @@
-// use candle_core::{DType, Device};
-// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap};
-
-fn main() {
-    // let mut var_map = VarMap::new();
-    // var_map.load("model.bin").unwrap();
-    // let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu);
-    // let model = seq()
-    //     .add(linear(14, 60, vb.pp("linear_in")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01))
-    //     .add(linear(48, 5, vb.pp("linear_out")).unwrap())
-    //     .add(Activation::LeakyRelu(0.01));
-
-    // let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap();
-
-    // let target = Tensor::new(&[0.0], &Device::Cpu).unwrap();
-
-    // self.optimizer
-    //     .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
-    //     .unwrap();
-}