forked from easonabc-public/paia-hw5
Compare commits
12 Commits
Author | SHA1 | Date |
---|---|---|
|
b6873265c9 | |
|
951cd00ed8 | |
|
7c11d9d75c | |
|
8468e29818 | |
|
abcd24240f | |
|
336594dfa4 | |
|
4aec307596 | |
|
6708a3a677 | |
|
4558abf160 | |
|
d42bdb758c | |
|
d9b2a372f1 | |
|
5a0f7d066c |
|
@ -10,14 +10,14 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
|
|
||||||
use super::model::{DQNModel, DQNModelConfig};
|
use super::model::{DQNModel, DQNModelConfig};
|
||||||
const EXPLORE_RATE: f32 = 0.8;
|
|
||||||
|
|
||||||
pub struct App<'a> {
|
pub struct App<'a> {
|
||||||
model: DQNModel<Backend>,
|
model: DQNModel<Backend>,
|
||||||
device: WgpuDevice,
|
device: WgpuDevice,
|
||||||
last_state_action: Option<(Info<'a>, Action)>,
|
last_state_action: Option<(Info<'a>, Action)>,
|
||||||
#[cfg(feature = "train")]
|
#[cfg(feature = "train")]
|
||||||
outlet: BufWriter<File>,
|
outlet: BufWriter<File>,
|
||||||
|
#[cfg(feature = "train")]
|
||||||
|
explore_rate: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> App<'a> {
|
impl<'a> App<'a> {
|
||||||
|
@ -33,6 +33,17 @@ impl<'a> App<'a> {
|
||||||
&device,
|
&device,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
#[cfg(feature = "train")]
|
||||||
|
let explore_rate = std::env::var("EPSILON")
|
||||||
|
.map(|x| {
|
||||||
|
let n: usize = x.parse().ok()?;
|
||||||
|
Some(1.0 / (n as f32 + 2.0).log2() - 0.03)
|
||||||
|
})
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(0.4);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
|
@ -45,6 +56,8 @@ impl<'a> App<'a> {
|
||||||
.open(format!("{model_path}/dataset"))
|
.open(format!("{model_path}/dataset"))
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
),
|
),
|
||||||
|
#[cfg(feature = "train")]
|
||||||
|
explore_rate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(feature = "train")]
|
#[cfg(feature = "train")]
|
||||||
|
@ -64,7 +77,7 @@ impl<'a> App<'a> {
|
||||||
bincode::serialize_into(&mut self.outlet, &item).unwrap();
|
bincode::serialize_into(&mut self.outlet, &item).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
let action = match thread_rng().gen_ratio((4096.0 * self.explore_rate) as u32, 4096) {
|
||||||
true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
|
true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
|
||||||
0 => Action::TurnRight,
|
0 => Action::TurnRight,
|
||||||
1 => Action::TurnLeft,
|
1 => Action::TurnLeft,
|
||||||
|
|
|
@ -18,7 +18,7 @@ impl Polar {
|
||||||
pub fn clip(&self) -> Self {
|
pub fn clip(&self) -> Self {
|
||||||
Polar {
|
Polar {
|
||||||
angle: self.angle,
|
angle: self.angle,
|
||||||
distance: self.distance.min(1e6).max(0.0),
|
distance: self.distance.max(0.0).min(1e3),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -118,9 +118,6 @@ impl<'a> Info<'a> {
|
||||||
let wall = self
|
let wall = self
|
||||||
.player
|
.player
|
||||||
.closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
|
.closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
|
||||||
let bullet = self
|
|
||||||
.player
|
|
||||||
.closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
|
|
||||||
|
|
||||||
let target = self.get_target().get_pos(self).clip();
|
let target = self.get_target().get_pos(self).clip();
|
||||||
|
|
||||||
|
@ -131,7 +128,7 @@ impl<'a> Info<'a> {
|
||||||
normalize_angle(target.angle - angle).tanh(),
|
normalize_angle(target.angle - angle).tanh(),
|
||||||
(wall.distance - target.distance).tanh(),
|
(wall.distance - target.distance).tanh(),
|
||||||
(self.player.power as f32).tanh(),
|
(self.player.power as f32).tanh(),
|
||||||
(wall.distance + 1.0).log2(),
|
(wall.clip().distance + 1.0).log2(),
|
||||||
(emeny.distance + 1.0).log2(),
|
(emeny.distance + 1.0).log2(),
|
||||||
normalize_angle(emeny.angle - gun_angle).tanh(),
|
normalize_angle(emeny.angle - gun_angle).tanh(),
|
||||||
normalize_angle(wall.angle - gun_angle).tanh(),
|
normalize_angle(wall.angle - gun_angle).tanh(),
|
||||||
|
@ -139,6 +136,7 @@ impl<'a> Info<'a> {
|
||||||
}
|
}
|
||||||
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
||||||
let feature = self.into_feature();
|
let feature = self.into_feature();
|
||||||
|
|
||||||
Tensor::from_floats(feature, device)
|
Tensor::from_floats(feature, device)
|
||||||
}
|
}
|
||||||
fn get_target(&self) -> Target {
|
fn get_target(&self) -> Target {
|
||||||
|
@ -179,7 +177,7 @@ impl<'a> Info<'a> {
|
||||||
|
|
||||||
reward
|
reward
|
||||||
+ match next.player.score - self.player.score {
|
+ match next.player.score - self.player.score {
|
||||||
x if x > 2 => 20.0,
|
x if x > 2 => 20.0, // bypass emeny
|
||||||
x if x > 0 => 10.0, // too high, tank my ignore power station
|
x if x > 0 => 10.0, // too high, tank my ignore power station
|
||||||
_ => -1.0,
|
_ => -1.0,
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,7 +6,7 @@ mod training;
|
||||||
|
|
||||||
pub mod prelude {
|
pub mod prelude {
|
||||||
pub use super::collect::App as DQNApp;
|
pub use super::collect::App as DQNApp;
|
||||||
pub use super::dataset::{TankDataset, TankItem};
|
pub use super::dataset::{TankBatcher, TankDataset, TankItem};
|
||||||
pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
|
pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
|
||||||
pub use super::training::run as train;
|
pub use super::training::{run as train, ExpConfig};
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,77 +1,95 @@
|
||||||
use burn::data::dataset::Dataset;
|
use burn::{
|
||||||
|
data::dataloader::DataLoaderBuilder,
|
||||||
|
optim::{AdamConfig, SgdConfig},
|
||||||
|
record::{CompactRecorder, NoStdTrainingRecorder},
|
||||||
|
tensor::backend::AutodiffBackend,
|
||||||
|
train::{
|
||||||
|
metric::{
|
||||||
|
store::{Aggregate, Direction, Split},
|
||||||
|
LossMetric,
|
||||||
|
},
|
||||||
|
LearnerBuilder, MetricEarlyStoppingStrategy, StoppingCondition,
|
||||||
|
},
|
||||||
|
};
|
||||||
|
|
||||||
use crate::dqn::prelude::TankItem;
|
use crate::dqn::prelude::{ExpConfig, TankBatcher, TankDataset};
|
||||||
use crate::ffi::prelude::*;
|
|
||||||
use rand::Rng;
|
|
||||||
|
|
||||||
// fn random_action() -> Action {
|
pub fn run<B: AutodiffBackend>(device: B::Device) {
|
||||||
// let mut rng = rand::thread_rng();
|
// let d = [
|
||||||
// match rng.gen_range(0..2) {
|
// feature[0],
|
||||||
// 0 => Action::AimLeft,
|
// -feature[0],
|
||||||
// 1 => Action::Forward,
|
// shoot_target_angle*0.7*feature[2],
|
||||||
// _ => unreachable!(),
|
// -shoot_target_angle*0.7*feature[2],
|
||||||
// }
|
// 8.0 * feature[2] / shoot_target_distance / shoot_target_angle,
|
||||||
// }
|
// feature[2]*shoot_target_distance*0.3-feature[2],
|
||||||
|
// ];
|
||||||
|
|
||||||
// fn random_item() -> TankItem {
|
let optimizer = AdamConfig::new();
|
||||||
// let mut previous_info=Info::default();
|
let config = ExpConfig::new(optimizer);
|
||||||
// TankItem {
|
let mut model = DQNModelConfig::new().init(&device);
|
||||||
// previous_state: todo!(),
|
|
||||||
// new_state: todo!(),
|
|
||||||
// action: todo!(),
|
|
||||||
// reward: todo!(),
|
|
||||||
// }
|
|
||||||
// }
|
|
||||||
|
|
||||||
pub struct FitDataset;
|
if fs::metadata(format!("{model_path}/model")).is_ok() {
|
||||||
|
model = model
|
||||||
impl FitDataset {
|
.load_file(
|
||||||
/// Get closer to the power station
|
format!("{model_path}/model"),
|
||||||
fn close_power_station() -> TankItem {
|
&NoStdTrainingRecorder::new(),
|
||||||
let mut power_stations = Station::default();
|
&device,
|
||||||
|
)
|
||||||
let mut previous_info = Info::default();
|
.unwrap();
|
||||||
let mut new_info = Info::default();
|
|
||||||
let mut rng = rand::thread_rng();
|
|
||||||
previous_info.player.power = rng.gen_range(0..2);
|
|
||||||
new_info.player.power = previous_info.player.power;
|
|
||||||
previous_info.player.angle = rng.gen_range(0..360);
|
|
||||||
new_info.player.angle = previous_info.player.angle;
|
|
||||||
|
|
||||||
TankItem {
|
|
||||||
previous_state: todo!(),
|
|
||||||
new_state: todo!(),
|
|
||||||
action: Action::Forward,
|
|
||||||
reward: todo!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
/// Flee from power station if power is high
|
|
||||||
fn flee_power_station() -> TankItem {
|
|
||||||
let mut previous_info = Info::default();
|
|
||||||
TankItem {
|
|
||||||
previous_state: todo!(),
|
|
||||||
new_state: todo!(),
|
|
||||||
action: Action::Backward,
|
|
||||||
reward: todo!(),
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
impl Dataset<TankItem> for FitDataset {
|
// Define train/test datasets and dataloaders
|
||||||
fn get(&self, _: usize) -> Option<TankItem> {
|
|
||||||
let previous_state = todo!();
|
|
||||||
let new_state = todo!();
|
|
||||||
let action = Action::AimLeft;
|
|
||||||
let reward = 0.0;
|
|
||||||
Some(TankItem {
|
|
||||||
previous_state,
|
|
||||||
new_state,
|
|
||||||
action,
|
|
||||||
reward,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
fn len(&self) -> usize {
|
let train_dataset = TankDataset::train();
|
||||||
1
|
let test_dataset = TankDataset::test();
|
||||||
}
|
|
||||||
|
println!("Train Dataset Size: {}", train_dataset.len());
|
||||||
|
println!("Test Dataset Size: {}", test_dataset.len());
|
||||||
|
|
||||||
|
let batcher_train = TankBatcher::<B>::new(device.clone());
|
||||||
|
|
||||||
|
let batcher_test = TankBatcher::<B::InnerBackend>::new(device.clone());
|
||||||
|
|
||||||
|
// Since dataset size is small, we do full batch gradient descent and set batch size equivalent to size of dataset
|
||||||
|
|
||||||
|
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||||
|
.batch_size(train_dataset.len())
|
||||||
|
.shuffle(config.seed)
|
||||||
|
.num_workers(config.num_workers)
|
||||||
|
.build(train_dataset);
|
||||||
|
|
||||||
|
let dataloader_test = DataLoaderBuilder::new(batcher_test)
|
||||||
|
.batch_size(test_dataset.len())
|
||||||
|
.shuffle(config.seed)
|
||||||
|
.num_workers(config.num_workers)
|
||||||
|
.build(test_dataset);
|
||||||
|
|
||||||
|
// Model
|
||||||
|
let learner = LearnerBuilder::new(ARTIFACT_DIR)
|
||||||
|
.metric_train_numeric(LossMetric::new())
|
||||||
|
.metric_valid_numeric(LossMetric::new())
|
||||||
|
.with_file_checkpointer(CompactRecorder::new())
|
||||||
|
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
|
||||||
|
Aggregate::Mean,
|
||||||
|
Direction::Lowest,
|
||||||
|
Split::Valid,
|
||||||
|
StoppingCondition::NoImprovementSince { n_epochs: 1 },
|
||||||
|
))
|
||||||
|
.devices(vec![device.clone()])
|
||||||
|
.num_epochs(config.num_epochs)
|
||||||
|
.summary()
|
||||||
|
.build(model, config.optimizer.init(), 5e-3);
|
||||||
|
|
||||||
|
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||||
|
|
||||||
|
config
|
||||||
|
.save(format!("{ARTIFACT_DIR}/config.json").as_str())
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
model_trained
|
||||||
|
.save_file(
|
||||||
|
format!("{ARTIFACT_DIR}/model"),
|
||||||
|
&NoStdTrainingRecorder::new(),
|
||||||
|
)
|
||||||
|
.expect("Failed to save trained model");
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
mod dqn;
|
mod dqn;
|
||||||
mod ffi;
|
mod ffi;
|
||||||
use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
|
use std::ffi::OsString;
|
||||||
|
|
||||||
use burn::backend::{wgpu::AutoGraphicsApi, Wgpu};
|
use burn::backend::{wgpu::AutoGraphicsApi, Wgpu};
|
||||||
use dqn::prelude::*;
|
use dqn::prelude::*;
|
||||||
|
@ -12,11 +12,9 @@ type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||||
|
|
||||||
#[no_mangle]
|
#[no_mangle]
|
||||||
pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> {
|
pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> {
|
||||||
let model_path =
|
let model_path = unsafe { std::slice::from_raw_parts(model_path, len as usize) };
|
||||||
OsStr::from_bytes(unsafe { std::slice::from_raw_parts(model_path, len as usize) })
|
let model_path = unsafe { OsString::from_encoded_bytes_unchecked(model_path.to_vec()) };
|
||||||
.to_str()
|
let app = DQNApp::new(model_path.to_str().unwrap());
|
||||||
.unwrap();
|
|
||||||
let app = DQNApp::new(model_path);
|
|
||||||
|
|
||||||
Box::into_raw(Box::new(app))
|
Box::into_raw(Box::new(app))
|
||||||
}
|
}
|
||||||
|
|
3
train.sh
3
train.sh
|
@ -9,7 +9,8 @@ for i in $(seq 1 $1);
|
||||||
do
|
do
|
||||||
echo "epoch $i"
|
echo "epoch $i"
|
||||||
cd TankMan
|
cd TankMan
|
||||||
timeout 240 python -m mlgame -f 3000 \
|
export EPSILON=$i
|
||||||
|
timeout 200 python -m mlgame -f 3000 \
|
||||||
-i ../ml/collect1.py -i ../ml/collect2.py \
|
-i ../ml/collect1.py -i ../ml/collect2.py \
|
||||||
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \
|
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \
|
||||||
--frame_limit 1000 > /dev/null
|
--frame_limit 1000 > /dev/null
|
||||||
|
|
Loading…
Reference in New Issue