Compare commits

..

No commits in common. "main" and "main" have entirely different histories.
main ... main

7 changed files with 90 additions and 119 deletions

View File

@ -1 +0,0 @@
**請助教看`tank-rust/dqn/`下的程式**

View File

@ -10,14 +10,14 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use super::model::{DQNModel, DQNModelConfig}; use super::model::{DQNModel, DQNModelConfig};
const EXPLORE_RATE: f32 = 0.8;
pub struct App<'a> { pub struct App<'a> {
model: DQNModel<Backend>, model: DQNModel<Backend>,
device: WgpuDevice, device: WgpuDevice,
last_state_action: Option<(Info<'a>, Action)>, last_state_action: Option<(Info<'a>, Action)>,
#[cfg(feature = "train")] #[cfg(feature = "train")]
outlet: BufWriter<File>, outlet: BufWriter<File>,
#[cfg(feature = "train")]
explore_rate: f32,
} }
impl<'a> App<'a> { impl<'a> App<'a> {
@ -33,17 +33,6 @@ impl<'a> App<'a> {
&device, &device,
) )
.unwrap(); .unwrap();
#[cfg(feature = "train")]
let explore_rate = std::env::var("EPSILON")
.map(|x| {
let n: usize = x.parse().ok()?;
Some(1.0 / (n as f32 + 2.0).log2() - 0.03)
})
.into_iter()
.flatten()
.next()
.unwrap_or(0.4);
Self { Self {
model, model,
device, device,
@ -56,8 +45,6 @@ impl<'a> App<'a> {
.open(format!("{model_path}/dataset")) .open(format!("{model_path}/dataset"))
.unwrap(), .unwrap(),
), ),
#[cfg(feature = "train")]
explore_rate,
} }
} }
#[cfg(feature = "train")] #[cfg(feature = "train")]
@ -77,7 +64,7 @@ impl<'a> App<'a> {
bincode::serialize_into(&mut self.outlet, &item).unwrap(); bincode::serialize_into(&mut self.outlet, &item).unwrap();
} }
let action = match thread_rng().gen_ratio((4096.0 * self.explore_rate) as u32, 4096) { let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) { true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
0 => Action::TurnRight, 0 => Action::TurnRight,
1 => Action::TurnLeft, 1 => Action::TurnLeft,

View File

@ -18,7 +18,7 @@ impl Polar {
pub fn clip(&self) -> Self { pub fn clip(&self) -> Self {
Polar { Polar {
angle: self.angle, angle: self.angle,
distance: self.distance.max(0.0).min(1e3), distance: self.distance.min(1e6).max(0.0),
} }
} }
} }
@ -118,6 +118,9 @@ impl<'a> Info<'a> {
let wall = self let wall = self
.player .player
.closest(self.walls.iter().map(|wall| (wall.x, wall.y))); .closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
let bullet = self
.player
.closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
let target = self.get_target().get_pos(self).clip(); let target = self.get_target().get_pos(self).clip();
@ -128,7 +131,7 @@ impl<'a> Info<'a> {
normalize_angle(target.angle - angle).tanh(), normalize_angle(target.angle - angle).tanh(),
(wall.distance - target.distance).tanh(), (wall.distance - target.distance).tanh(),
(self.player.power as f32).tanh(), (self.player.power as f32).tanh(),
(wall.clip().distance + 1.0).log2(), (wall.distance + 1.0).log2(),
(emeny.distance + 1.0).log2(), (emeny.distance + 1.0).log2(),
normalize_angle(emeny.angle - gun_angle).tanh(), normalize_angle(emeny.angle - gun_angle).tanh(),
normalize_angle(wall.angle - gun_angle).tanh(), normalize_angle(wall.angle - gun_angle).tanh(),
@ -136,7 +139,6 @@ impl<'a> Info<'a> {
} }
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> { pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
let feature = self.into_feature(); let feature = self.into_feature();
Tensor::from_floats(feature, device) Tensor::from_floats(feature, device)
} }
fn get_target(&self) -> Target { fn get_target(&self) -> Target {
@ -177,7 +179,7 @@ impl<'a> Info<'a> {
reward reward
+ match next.player.score - self.player.score { + match next.player.score - self.player.score {
x if x > 2 => 20.0, // bypass emeny x if x > 2 => 20.0,
x if x > 0 => 10.0, // too high, tank my ignore power station x if x > 0 => 10.0, // too high, tank my ignore power station
_ => -1.0, _ => -1.0,
} }

View File

@ -6,7 +6,7 @@ mod training;
pub mod prelude { pub mod prelude {
pub use super::collect::App as DQNApp; pub use super::collect::App as DQNApp;
pub use super::dataset::{TankBatcher, TankDataset, TankItem}; pub use super::dataset::{TankDataset, TankItem};
pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE}; pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
pub use super::training::{run as train, ExpConfig}; pub use super::training::run as train;
} }

View File

@ -1,95 +1,77 @@
use burn::{ use burn::data::dataset::Dataset;
data::dataloader::DataLoaderBuilder,
optim::{AdamConfig, SgdConfig},
record::{CompactRecorder, NoStdTrainingRecorder},
tensor::backend::AutodiffBackend,
train::{
metric::{
store::{Aggregate, Direction, Split},
LossMetric,
},
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
},
};
use crate::dqn::prelude::{ExpConfig, TankBatcher, TankDataset}; use crate::dqn::prelude::TankItem;
use crate::ffi::prelude::*;
use rand::Rng;
pub fn run<B: AutodiffBackend>(device: B::Device) { // fn random_action() -> Action {
// let d = [ // let mut rng = rand::thread_rng();
// feature[0], // match rng.gen_range(0..2) {
// -feature[0], // 0 => Action::AimLeft,
// shoot_target_angle*0.7*feature[2], // 1 => Action::Forward,
// -shoot_target_angle*0.7*feature[2], // _ => unreachable!(),
// 8.0 * feature[2] / shoot_target_distance / shoot_target_angle, // }
// feature[2]*shoot_target_distance*0.3-feature[2], // }
// ];
let optimizer = AdamConfig::new(); // fn random_item() -> TankItem {
let config = ExpConfig::new(optimizer); // let mut previous_info=Info::default();
let mut model = DQNModelConfig::new().init(&device); // TankItem {
// previous_state: todo!(),
// new_state: todo!(),
// action: todo!(),
// reward: todo!(),
// }
// }
if fs::metadata(format!("{model_path}/model")).is_ok() { pub struct FitDataset;
model = model
.load_file( impl FitDataset {
format!("{model_path}/model"), /// Get closer to the power station
&NoStdTrainingRecorder::new(), fn close_power_station() -> TankItem {
&device, let mut power_stations = Station::default();
)
.unwrap(); let mut previous_info = Info::default();
let mut new_info = Info::default();
let mut rng = rand::thread_rng();
previous_info.player.power = rng.gen_range(0..2);
new_info.player.power = previous_info.player.power;
previous_info.player.angle = rng.gen_range(0..360);
new_info.player.angle = previous_info.player.angle;
TankItem {
previous_state: todo!(),
new_state: todo!(),
action: Action::Forward,
reward: todo!(),
}
}
/// Flee from power station if power is high
fn flee_power_station() -> TankItem {
let mut previous_info = Info::default();
TankItem {
previous_state: todo!(),
new_state: todo!(),
action: Action::Backward,
reward: todo!(),
}
}
} }
// Define train/test datasets and dataloaders impl Dataset<TankItem> for FitDataset {
fn get(&self, _: usize) -> Option<TankItem> {
let train_dataset = TankDataset::train(); let previous_state = todo!();
let test_dataset = TankDataset::test(); let new_state = todo!();
let action = Action::AimLeft;
println!("Train Dataset Size: {}", train_dataset.len()); let reward = 0.0;
println!("Test Dataset Size: {}", test_dataset.len()); Some(TankItem {
previous_state,
let batcher_train = TankBatcher::<B>::new(device.clone()); new_state,
action,
let batcher_test = TankBatcher::<B::InnerBackend>::new(device.clone()); reward,
})
// Since dataset size is small, we do full batch gradient descent and set batch size equivalent to size of dataset }
let dataloader_train = DataLoaderBuilder::new(batcher_train) fn len(&self) -> usize {
.batch_size(train_dataset.len()) 1
.shuffle(config.seed) }
.num_workers(config.num_workers)
.build(train_dataset);
let dataloader_test = DataLoaderBuilder::new(batcher_test)
.batch_size(test_dataset.len())
.shuffle(config.seed)
.num_workers(config.num_workers)
.build(test_dataset);
// Model
let learner = LearnerBuilder::new(ARTIFACT_DIR)
.metric_train_numeric(LossMetric::new())
.metric_valid_numeric(LossMetric::new())
.with_file_checkpointer(CompactRecorder::new())
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
Aggregate::Mean,
Direction::Lowest,
Split::Valid,
StoppingCondition::NoImprovementSince { n_epochs: 1 },
))
.devices(vec![device.clone()])
.num_epochs(config.num_epochs)
.summary()
.build(model, config.optimizer.init(), 5e-3);
let model_trained = learner.fit(dataloader_train, dataloader_test);
config
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
.unwrap();
model_trained
.save_file(
format!("{ARTIFACT_DIR}/model"),
&NoStdTrainingRecorder::new(),
)
.expect("Failed to save trained model");
} }

View File

@ -1,6 +1,6 @@
mod dqn; mod dqn;
mod ffi; mod ffi;
use std::ffi::OsString; use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
use burn::backend::{wgpu::AutoGraphicsApi, Wgpu}; use burn::backend::{wgpu::AutoGraphicsApi, Wgpu};
use dqn::prelude::*; use dqn::prelude::*;
@ -12,9 +12,11 @@ type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
#[no_mangle] #[no_mangle]
pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> { pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> {
let model_path = unsafe { std::slice::from_raw_parts(model_path, len as usize) }; let model_path =
let model_path = unsafe { OsString::from_encoded_bytes_unchecked(model_path.to_vec()) }; OsStr::from_bytes(unsafe { std::slice::from_raw_parts(model_path, len as usize) })
let app = DQNApp::new(model_path.to_str().unwrap()); .to_str()
.unwrap();
let app = DQNApp::new(model_path);
Box::into_raw(Box::new(app)) Box::into_raw(Box::new(app))
} }

View File

@ -9,8 +9,7 @@ for i in $(seq 1 $1);
do do
echo "epoch $i" echo "epoch $i"
cd TankMan cd TankMan
export EPSILON=$i timeout 240 python -m mlgame -f 3000 \
timeout 200 python -m mlgame -f 3000 \
-i ../ml/collect1.py -i ../ml/collect2.py \ -i ../ml/collect1.py -i ../ml/collect2.py \
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \ . --green_team_num 1 --blue_team_num 1 --is_manual 1 \
--frame_limit 1000 > /dev/null --frame_limit 1000 > /dev/null