From 5a0f7d066c1a584c005e71cf7cd208c56287cd8d Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Wed, 12 Jun 2024 20:53:11 +0800 Subject: [PATCH] add NaN check --- tank-rust/src/dqn/collect.rs | 4 +- tank-rust/src/dqn/feature.rs | 6 +- tank-rust/src/dqn/mod.rs | 4 +- tank-rust/src/fit.rs | 160 +++++++++++++++++++---------------- 4 files changed, 96 insertions(+), 78 deletions(-) diff --git a/tank-rust/src/dqn/collect.rs b/tank-rust/src/dqn/collect.rs index 65eaf9a..c6517ab 100644 --- a/tank-rust/src/dqn/collect.rs +++ b/tank-rust/src/dqn/collect.rs @@ -65,7 +65,7 @@ impl<'a> App<'a> { } let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) { - true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) { + true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) { 0 => Action::TurnRight, 1 => Action::TurnLeft, 2 => Action::AimRight, @@ -83,7 +83,7 @@ impl<'a> App<'a> { pub fn predict_action(&self, state: &Info) -> Action { let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size] let ans = self.model.forward(input); - match ans.argmax(1).into_scalar(){ + match ans.argmax(1).into_scalar() { 0 => Action::TurnRight, 1 => Action::TurnLeft, 2 => Action::AimRight, diff --git a/tank-rust/src/dqn/feature.rs b/tank-rust/src/dqn/feature.rs index ea80027..78d611d 100644 --- a/tank-rust/src/dqn/feature.rs +++ b/tank-rust/src/dqn/feature.rs @@ -118,9 +118,6 @@ impl<'a> Info<'a> { let wall = self .player .closest(self.walls.iter().map(|wall| (wall.x, wall.y))); - let bullet = self - .player - .closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y))); let target = self.get_target().get_pos(self).clip(); @@ -139,6 +136,9 @@ impl<'a> Info<'a> { } pub fn into_feature_tensor(&self, device: &B::Device) -> Tensor { let feature = self.into_feature(); + for feature in feature.iter() { + assert!(!feature.is_nan()); + } Tensor::from_floats(feature, device) } fn get_target(&self) -> Target { diff --git a/tank-rust/src/dqn/mod.rs b/tank-rust/src/dqn/mod.rs index 3050f00..9159de9 100644 --- a/tank-rust/src/dqn/mod.rs +++ b/tank-rust/src/dqn/mod.rs @@ -6,7 +6,7 @@ mod training; pub mod prelude { pub use super::collect::App as DQNApp; - pub use super::dataset::{TankDataset, TankItem}; + pub use super::dataset::{TankBatcher, TankDataset, TankItem}; pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE}; - pub use super::training::run as train; + pub use super::training::{run as train, ExpConfig}; } diff --git a/tank-rust/src/fit.rs b/tank-rust/src/fit.rs index 404c6da..e28a390 100644 --- a/tank-rust/src/fit.rs +++ b/tank-rust/src/fit.rs @@ -1,77 +1,95 @@ -use burn::data::dataset::Dataset; +use burn::{ + data::dataloader::DataLoaderBuilder, + optim::{AdamConfig, SgdConfig}, + record::{CompactRecorder, NoStdTrainingRecorder}, + tensor::backend::AutodiffBackend, + train::{ + metric::{ + store::{Aggregate, Direction, Split}, + LossMetric, + }, + LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition, + }, +}; -use crate::dqn::prelude::TankItem; -use crate::ffi::prelude::*; -use rand::Rng; +use crate::dqn::prelude::{ExpConfig, TankBatcher, TankDataset}; -// fn random_action() -> Action { -// let mut rng = rand::thread_rng(); -// match rng.gen_range(0..2) { -// 0 => Action::AimLeft, -// 1 => Action::Forward, -// _ => unreachable!(), -// } -// } +pub fn run(device: B::Device) { + // let d = [ + // feature[0], + // -feature[0], + // shoot_target_angle*0.7*feature[2], + // -shoot_target_angle*0.7*feature[2], + // 8.0 * feature[2] / shoot_target_distance / shoot_target_angle, + // feature[2]*shoot_target_distance*0.3-feature[2], + // ]; -// fn random_item() -> TankItem { -// let mut previous_info=Info::default(); -// TankItem { -// previous_state: todo!(), -// new_state: todo!(), -// action: todo!(), -// reward: todo!(), -// } -// } + let optimizer = AdamConfig::new(); + let config = ExpConfig::new(optimizer); + let mut model = DQNModelConfig::new().init(&device); -pub struct FitDataset; - -impl FitDataset { - /// Get closer to the power station - fn close_power_station() -> TankItem { - let mut power_stations = Station::default(); - - let mut previous_info = Info::default(); - let mut new_info = Info::default(); - let mut rng = rand::thread_rng(); - previous_info.player.power = rng.gen_range(0..2); - new_info.player.power = previous_info.player.power; - previous_info.player.angle = rng.gen_range(0..360); - new_info.player.angle = previous_info.player.angle; - - TankItem { - previous_state: todo!(), - new_state: todo!(), - action: Action::Forward, - reward: todo!(), - } - } - /// Flee from power station if power is high - fn flee_power_station() -> TankItem { - let mut previous_info = Info::default(); - TankItem { - previous_state: todo!(), - new_state: todo!(), - action: Action::Backward, - reward: todo!(), - } - } -} - -impl Dataset for FitDataset { - fn get(&self, _: usize) -> Option { - let previous_state = todo!(); - let new_state = todo!(); - let action = Action::AimLeft; - let reward = 0.0; - Some(TankItem { - previous_state, - new_state, - action, - reward, - }) - } - - fn len(&self) -> usize { - 1 + if fs::metadata(format!("{model_path}/model")).is_ok() { + model = model + .load_file( + format!("{model_path}/model"), + &NoStdTrainingRecorder::new(), + &device, + ) + .unwrap(); } + + // Define train/test datasets and dataloaders + + let train_dataset = TankDataset::train(); + let test_dataset = TankDataset::test(); + + println!("Train Dataset Size: {}", train_dataset.len()); + println!("Test Dataset Size: {}", test_dataset.len()); + + let batcher_train = TankBatcher::::new(device.clone()); + + let batcher_test = TankBatcher::::new(device.clone()); + + // Since dataset size is small, we do full batch gradient descent and set batch size equivalent to size of dataset + + let dataloader_train = DataLoaderBuilder::new(batcher_train) + .batch_size(train_dataset.len()) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(train_dataset); + + let dataloader_test = DataLoaderBuilder::new(batcher_test) + .batch_size(test_dataset.len()) + .shuffle(config.seed) + .num_workers(config.num_workers) + .build(test_dataset); + + // Model + let learner = LearnerBuilder::new(ARTIFACT_DIR) + .metric_train_numeric(LossMetric::new()) + .metric_valid_numeric(LossMetric::new()) + .with_file_checkpointer(CompactRecorder::new()) + .early_stopping(MetricEarlyStoppingStrategy::new::>( + Aggregate::Mean, + Direction::Lowest, + Split::Valid, + StoppingCondition::NoImprovementSince { n_epochs: 1 }, + )) + .devices(vec![device.clone()]) + .num_epochs(config.num_epochs) + .summary() + .build(model, config.optimizer.init(), 5e-3); + + let model_trained = learner.fit(dataloader_train, dataloader_test); + + config + .save(format!("{ARTIFACT_DIR}/config.json").as_str()) + .unwrap(); + + model_trained + .save_file( + format!("{ARTIFACT_DIR}/model"), + &NoStdTrainingRecorder::new(), + ) + .expect("Failed to save trained model"); }