add epsilon
This commit is contained in:
parent
abcd24240f
commit
8468e29818
|
@ -10,14 +10,13 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
|
|
||||||
use super::model::{DQNModel, DQNModelConfig};
|
use super::model::{DQNModel, DQNModelConfig};
|
||||||
const EXPLORE_RATE: f32 = 0.8;
|
|
||||||
|
|
||||||
pub struct App<'a> {
|
pub struct App<'a> {
|
||||||
model: DQNModel<Backend>,
|
model: DQNModel<Backend>,
|
||||||
device: WgpuDevice,
|
device: WgpuDevice,
|
||||||
last_state_action: Option<(Info<'a>, Action)>,
|
last_state_action: Option<(Info<'a>, Action)>,
|
||||||
#[cfg(feature = "train")]
|
#[cfg(feature = "train")]
|
||||||
outlet: BufWriter<File>,
|
outlet: BufWriter<File>,
|
||||||
|
explore_rate: f32,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl<'a> App<'a> {
|
impl<'a> App<'a> {
|
||||||
|
@ -33,6 +32,17 @@ impl<'a> App<'a> {
|
||||||
&device,
|
&device,
|
||||||
)
|
)
|
||||||
.unwrap();
|
.unwrap();
|
||||||
|
#[cfg(feature = "train")]
|
||||||
|
let explore_rate = std::env::var("EPSILON")
|
||||||
|
.map(|x| {
|
||||||
|
let n: usize = x.parse().ok()?;
|
||||||
|
Some(1.0 / (n as f32 + 2.0).log2() - 0.03)
|
||||||
|
})
|
||||||
|
.into_iter()
|
||||||
|
.flatten()
|
||||||
|
.next()
|
||||||
|
.unwrap_or(0.4);
|
||||||
|
|
||||||
Self {
|
Self {
|
||||||
model,
|
model,
|
||||||
device,
|
device,
|
||||||
|
@ -45,6 +55,8 @@ impl<'a> App<'a> {
|
||||||
.open(format!("{model_path}/dataset"))
|
.open(format!("{model_path}/dataset"))
|
||||||
.unwrap(),
|
.unwrap(),
|
||||||
),
|
),
|
||||||
|
#[cfg(feature = "train")]
|
||||||
|
explore_rate,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
#[cfg(feature = "train")]
|
#[cfg(feature = "train")]
|
||||||
|
@ -64,7 +76,7 @@ impl<'a> App<'a> {
|
||||||
bincode::serialize_into(&mut self.outlet, &item).unwrap();
|
bincode::serialize_into(&mut self.outlet, &item).unwrap();
|
||||||
}
|
}
|
||||||
|
|
||||||
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
let action = match thread_rng().gen_ratio((4096.0 * self.explore_rate) as u32, 4096) {
|
||||||
true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
|
true => match thread_rng().gen_range(0..(ACTION_SPACE + 2) as i32) {
|
||||||
0 => Action::TurnRight,
|
0 => Action::TurnRight,
|
||||||
1 => Action::TurnLeft,
|
1 => Action::TurnLeft,
|
||||||
|
|
3
train.sh
3
train.sh
|
@ -9,7 +9,8 @@ for i in $(seq 1 $1);
|
||||||
do
|
do
|
||||||
echo "epoch $i"
|
echo "epoch $i"
|
||||||
cd TankMan
|
cd TankMan
|
||||||
timeout 240 python -m mlgame -f 3000 \
|
export EPSILON=$i
|
||||||
|
timeout 200 python -m mlgame -f 3000 \
|
||||||
-i ../ml/collect1.py -i ../ml/collect2.py \
|
-i ../ml/collect1.py -i ../ml/collect2.py \
|
||||||
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \
|
. --green_team_num 1 --blue_team_num 1 --is_manual 1 \
|
||||||
--frame_limit 1000 > /dev/null
|
--frame_limit 1000 > /dev/null
|
||||||
|
|
Loading…
Reference in New Issue