add NaN check

Eason 2024-06-12 20:53:11 +08:00
parent 9e641b5f09
commit 5a0f7d066c
4 changed files with 96 additions and 78 deletions

View File

@@ -65,7 +65,7 @@ impl<'a> App<'a> {
         }
         let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
-            true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
+            true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
                 0 => Action::TurnRight,
                 1 => Action::TurnLeft,
                 2 => Action::AimRight,
@@ -83,7 +83,7 @@ impl<'a> App<'a> {
     pub fn predict_action(&self, state: &Info) -> Action {
         let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
         let ans = self.model.forward(input);
-        match ans.argmax(1).into_scalar(){
+        match ans.argmax(1).into_scalar() {
             0 => Action::TurnRight,
             1 => Action::TurnLeft,
             2 => Action::AimRight,
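
Context for the first hunk: `gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096)` is true with probability roughly EXPLORE_RATE, so the agent takes a random action that often and otherwise falls through to the model's prediction. A minimal epsilon-greedy sketch of that pattern; the EXPLORE_RATE and ACTION_SPACE values and the index-to-action mapping past index 2 are assumptions, and only the variant names visible in this diff are reused.

use rand::{thread_rng, Rng};

const EXPLORE_RATE: f64 = 0.1; // assumed value; the real constant lives elsewhere in the crate
const ACTION_SPACE: i32 = 4;   // assumed size of the model's action head

#[derive(Debug)]
enum Action {
    TurnRight,
    TurnLeft,
    AimRight,
    AimLeft,
    Forward,
    Backward,
}

// Epsilon-greedy selection: explore with probability EXPLORE_RATE, otherwise exploit.
fn select_action(predict: impl Fn() -> Action) -> Action {
    // gen_ratio(n, d) is true with probability n/d; 4096 is the resolution used in the hunk above.
    if thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
        match thread_rng().gen_range(0..(ACTION_SPACE + 2)) {
            0 => Action::TurnRight,
            1 => Action::TurnLeft,
            2 => Action::AimRight,
            3 => Action::AimLeft,
            4 => Action::Forward,
            _ => Action::Backward,
        }
    } else {
        predict()
    }
}

fn main() {
    let action = select_action(|| Action::Forward); // stand-in for the DQN forward pass
    println!("{action:?}");
}

Expressing the threshold as an integer ratio out of 4096 avoids a float comparison while still giving a resolution of 1/4096 on the exploration probability.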

View File

@@ -118,9 +118,6 @@ impl<'a> Info<'a> {
         let wall = self
             .player
             .closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
-        let bullet = self
-            .player
-            .closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
         let target = self.get_target().get_pos(self).clip();
@@ -139,6 +136,9 @@ impl<'a> Info<'a> {
     }
     pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
         let feature = self.into_feature();
+        for feature in feature.iter() {
+            assert!(!feature.is_nan());
+        }
         Tensor::from_floats(feature, device)
     }
     fn get_target(&self) -> Target {
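
This is the NaN check the commit message refers to: every feature is asserted to be non-NaN before `Tensor::from_floats` sees it, so a bad state panics where it is produced instead of poisoning training later. A small sketch of the same guard written as a fallible check instead of an `assert!`; the `check_features` helper and the sample values are illustrative, not part of the repository.

/// Return the index of the first NaN entry, if any, so the caller can log which feature broke.
fn check_features(features: &[f32]) -> Result<(), usize> {
    match features.iter().position(|f| f.is_nan()) {
        Some(i) => Err(i),
        None => Ok(()),
    }
}

fn main() {
    let features = [0.5_f32, 1.0, f32::NAN, -2.0]; // illustrative feature vector
    match check_features(&features) {
        Ok(()) => println!("all features are NaN-free"),
        Err(i) => eprintln!("feature {i} is NaN; refusing to build the input tensor"),
    }
}

A `debug_assert!` variant would compile the check out of release builds; the committed `assert!` keeps it in all builds, which also guards release runs.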

View File

@@ -6,7 +6,7 @@ mod training;
 pub mod prelude {
     pub use super::collect::App as DQNApp;
-    pub use super::dataset::{TankDataset, TankItem};
+    pub use super::dataset::{TankBatcher, TankDataset, TankItem};
     pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
-    pub use super::training::run as train;
+    pub use super::training::{run as train, ExpConfig};
 }

View File

@@ -1,77 +1,95 @@
-use burn::data::dataset::Dataset;
-use crate::dqn::prelude::TankItem;
-use crate::ffi::prelude::*;
-use rand::Rng;
-// fn random_action() -> Action {
-//     let mut rng = rand::thread_rng();
-//     match rng.gen_range(0..2) {
-//         0 => Action::AimLeft,
-//         1 => Action::Forward,
-//         _ => unreachable!(),
-//     }
-// }
-// fn random_item() -> TankItem {
-//     let mut previous_info=Info::default();
-//     TankItem {
-//         previous_state: todo!(),
-//         new_state: todo!(),
-//         action: todo!(),
-//         reward: todo!(),
-//     }
-// }
-pub struct FitDataset;
-
-impl FitDataset {
-    /// Get closer to the power station
-    fn close_power_station() -> TankItem {
-        let mut power_stations = Station::default();
-        let mut previous_info = Info::default();
-        let mut new_info = Info::default();
-        let mut rng = rand::thread_rng();
-        previous_info.player.power = rng.gen_range(0..2);
-        new_info.player.power = previous_info.player.power;
-        previous_info.player.angle = rng.gen_range(0..360);
-        new_info.player.angle = previous_info.player.angle;
-        TankItem {
-            previous_state: todo!(),
-            new_state: todo!(),
-            action: Action::Forward,
-            reward: todo!(),
-        }
-    }
-    /// Flee from power station if power is high
-    fn flee_power_station() -> TankItem {
-        let mut previous_info = Info::default();
-        TankItem {
-            previous_state: todo!(),
-            new_state: todo!(),
-            action: Action::Backward,
-            reward: todo!(),
-        }
-    }
-}
-impl Dataset<TankItem> for FitDataset {
-    fn get(&self, _: usize) -> Option<TankItem> {
-        let previous_state = todo!();
-        let new_state = todo!();
-        let action = Action::AimLeft;
-        let reward = 0.0;
-        Some(TankItem {
-            previous_state,
-            new_state,
-            action,
-            reward,
-        })
-    }
-    fn len(&self) -> usize {
-        1
-    }
-}
+use burn::{
+    data::dataloader::DataLoaderBuilder,
+    optim::{AdamConfig, SgdConfig},
+    record::{CompactRecorder, NoStdTrainingRecorder},
+    tensor::backend::AutodiffBackend,
+    train::{
+        metric::{
+            store::{Aggregate, Direction, Split},
+            LossMetric,
+        },
+        LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
+    },
+};
+use crate::dqn::prelude::{ExpConfig, TankBatcher, TankDataset};
+
+pub fn run<B: AutodiffBackend>(device: B::Device) {
+    // let d = [
+    //     feature[0],
+    //     -feature[0],
+    //     shoot_target_angle*0.7*feature[2],
+    //     -shoot_target_angle*0.7*feature[2],
+    //     8.0 * feature[2] / shoot_target_distance / shoot_target_angle,
+    //     feature[2]*shoot_target_distance*0.3-feature[2],
+    // ];
+    let optimizer = AdamConfig::new();
+    let config = ExpConfig::new(optimizer);
+    let mut model = DQNModelConfig::new().init(&device);
+
+    if fs::metadata(format!("{model_path}/model")).is_ok() {
+        model = model
+            .load_file(
+                format!("{model_path}/model"),
+                &NoStdTrainingRecorder::new(),
+                &device,
+            )
+            .unwrap();
+    }
+
+    // Define train/test datasets and dataloaders
+    let train_dataset = TankDataset::train();
+    let test_dataset = TankDataset::test();
+
+    println!("Train Dataset Size: {}", train_dataset.len());
+    println!("Test Dataset Size: {}", test_dataset.len());
+
+    let batcher_train = TankBatcher::<B>::new(device.clone());
+    let batcher_test = TankBatcher::<B::InnerBackend>::new(device.clone());
+
+    // Since dataset size is small, we do full batch gradient descent and set batch size equivalent to size of dataset
+    let dataloader_train = DataLoaderBuilder::new(batcher_train)
+        .batch_size(train_dataset.len())
+        .shuffle(config.seed)
+        .num_workers(config.num_workers)
+        .build(train_dataset);
+
+    let dataloader_test = DataLoaderBuilder::new(batcher_test)
+        .batch_size(test_dataset.len())
+        .shuffle(config.seed)
+        .num_workers(config.num_workers)
+        .build(test_dataset);
+
+    // Model
+    let learner = LearnerBuilder::new(ARTIFACT_DIR)
+        .metric_train_numeric(LossMetric::new())
+        .metric_valid_numeric(LossMetric::new())
+        .with_file_checkpointer(CompactRecorder::new())
+        .early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
+            Aggregate::Mean,
+            Direction::Lowest,
+            Split::Valid,
+            StoppingCondition::NoImprovementSince { n_epochs: 1 },
+        ))
+        .devices(vec![device.clone()])
+        .num_epochs(config.num_epochs)
+        .summary()
+        .build(model, config.optimizer.init(), 5e-3);
+
+    let model_trained = learner.fit(dataloader_train, dataloader_test);
+
+    config
+        .save(format!("{ARTIFACT_DIR}/config.json").as_str())
+        .unwrap();
+
+    model_trained
+        .save_file(
+            format!("{ARTIFACT_DIR}/model"),
+            &NoStdTrainingRecorder::new(),
+        )
+        .expect("Failed to save trained model");
+}
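
With `run` now re-exported as `train` (and `ExpConfig` alongside it) from the dqn prelude, a caller only has to pick an autodiff backend and a device. A rough usage sketch, assuming burn's ndarray backend is enabled in this crate; the backend choice and the default device are assumptions, not shown in this commit.

use burn::backend::{Autodiff, NdArray};

// `train` is the re-exported `run<B: AutodiffBackend>(device: B::Device)` defined above.
use crate::dqn::prelude::train;

fn main() {
    // NdArray's CPU device implements Default, so no explicit device id is needed here.
    let device = Default::default();
    train::<Autodiff<NdArray>>(device);
}

Any other backend wrapped in `Autodiff` (for example a GPU backend) would slot into the same call, since `run` is generic over `AutodiffBackend`.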