From 8468e298182962360d7c44bd1a9838454feb1612 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Thu, 13 Jun 2024 00:01:59 +0800 Subject: [PATCH] add epsilon --- tank-rust/src/dqn/collect.rs | 18 +++++++++++++++--- train.sh | 3 ++- 2 files changed, 17 insertions(+), 4 deletions(-) diff --git a/tank-rust/src/dqn/collect.rs b/tank-rust/src/dqn/collect.rs index c6517ab..e1074b3 100644 --- a/tank-rust/src/dqn/collect.rs +++ b/tank-rust/src/dqn/collect.rs @@ -10,14 +10,13 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor use rand::{thread_rng, Rng}; use super::model::{DQNModel, DQNModelConfig}; -const EXPLORE_RATE: f32 = 0.8; - pub struct App<'a> { model: DQNModel, device: WgpuDevice, last_state_action: Option<(Info<'a>, Action)>, #[cfg(feature = "train")] outlet: BufWriter, + explore_rate: f32, } impl<'a> App<'a> { @@ -33,6 +32,17 @@ impl<'a> App<'a> { &device, ) .unwrap(); + #[cfg(feature = "train")] + let explore_rate = std::env::var("EPSILON") + .map(|x| { + let n: usize = x.parse().ok()?; + Some(1.0 / (n as f32 + 2.0).log2() - 0.03) + }) + .into_iter() + .flatten() + .next() + .unwrap_or(0.4); + Self { model, device, @@ -45,6 +55,8 @@ impl<'a> App<'a> { .open(format!("{model_path}/dataset")) .unwrap(), ), + #[cfg(feature = "train")] + explore_rate, } } #[cfg(feature = "train")] @@ -64,7 +76,7 @@ impl<'a> App<'a> { bincode::serialize_into(&mut self.outlet, &item).unwrap(); } - let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) { + let action = match thread_rng().gen_ratio((4096.0 * self.explore_rate) as u32, 4096) { true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) { 0 => Action::TurnRight, 1 => Action::TurnLeft, diff --git a/train.sh b/train.sh index 9d75305..62d0c29 100755 --- a/train.sh +++ b/train.sh @@ -9,7 +9,8 @@ for i in $(seq 1 $1); do echo "epoch $i" cd TankMan - timeout 240 python -m mlgame -f 3000 \ + export EPSILON=$i + timeout 200 python -m mlgame -f 3000 \ -i ../ml/collect1.py -i ../ml/collect2.py \ . --green_team_num 1 --blue_team_num 1 --is_manual 1 \ --frame_limit 1000 > /dev/null