add epsilon

This commit is contained in:
Eason 2024-06-13 00:01:59 +08:00
parent abcd24240f
commit 8468e29818
2 changed files with 17 additions and 4 deletions

View File

@ -10,14 +10,13 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
use rand::{thread_rng, Rng};
use super::model::{DQNModel, DQNModelConfig};
const EXPLORE_RATE: f32 = 0.8;
/// Agent state: the DQN model, the wgpu device handle, the most recent
/// `(Info, Action)` pair, and the exploration rate used when picking actions.
pub struct App<'a> {
/// The DQN model instance, parameterised over the project's `Backend` type.
model: DQNModel<Backend>,
/// Handle to the wgpu device.
device: WgpuDevice,
/// Most recent `(Info, Action)` pair, if any.
/// NOTE(review): presumably the previous step's observation and the action
/// chosen for it — confirm against the code that reads this field.
last_state_action: Option<(Info<'a>, Action)>,
/// Buffered file writer that exists only when the `train` feature is enabled;
/// items are serialized into it with `bincode::serialize_into`.
#[cfg(feature = "train")]
outlet: BufWriter<File>,
/// Probability of taking the random-action branch during action selection.
/// It is multiplied by 4096 and cast to `u32` to form the numerator passed to
/// `gen_ratio(_, 4096)`.
/// NOTE(review): this field is declared unconditionally, but the constructor
/// only binds and initializes `explore_rate` under `feature = "train"` —
/// confirm that a build without that feature still compiles.
explore_rate: f32,
}
impl<'a> App<'a> {
@ -33,6 +32,17 @@ impl<'a> App<'a> {
&device,
)
.unwrap();
// Derive the exploration rate from the `EPSILON` environment variable.
// `EPSILON` is parsed as an unsigned integer `n` and mapped to
// `1 / log2(n + 2) - 0.03`, so `n = 0` gives 0.97 and the rate decreases as
// `n` grows. If the variable is unset, is not valid Unicode, or does not parse
// as `usize`, the rate falls back to 0.4.
// NOTE(review): the accompanying training script appears to export the epoch
// index as `EPSILON` — confirm that is the intended schedule input.
#[cfg(feature = "train")]
let explore_rate = std::env::var("EPSILON")
.map(|x| {
// `?` on the `Option` makes the closure return `None` when parsing fails.
let n: usize = x.parse().ok()?;
Some(1.0 / (n as f32 + 2.0).log2() - 0.03)
})
// `Result<Option<f32>, _>` -> iterator over the `Ok` payload -> flatten the
// inner `Option` -> take the single value, if there is one.
.into_iter()
.flatten()
.next()
.unwrap_or(0.4);
Self {
model,
device,
@ -45,6 +55,8 @@ impl<'a> App<'a> {
.open(format!("{model_path}/dataset"))
.unwrap(),
),
#[cfg(feature = "train")]
explore_rate,
}
}
#[cfg(feature = "train")]
@ -64,7 +76,7 @@ impl<'a> App<'a> {
bincode::serialize_into(&mut self.outlet, &item).unwrap();
}
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
let action = match thread_rng().gen_ratio((4096.0 * self.explore_rate) as u32, 4096) {
true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
0 => Action::TurnRight,
1 => Action::TurnLeft,

View File

@ -9,7 +9,8 @@ for i in $(seq 1 $1);
do
echo "epoch $i"
cd TankMan
timeout 240 python -m mlgame -f 3000 \
export EPSILON=$i
timeout 200 python -m mlgame -f 3000 \
-i ../ml/collect1.py -i ../ml/collect2.py \
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \
--frame_limit 1000 > /dev/null