From 1a301178f81b7701fac579ef35fcb4188051360f Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:28:15 +0800 Subject: [PATCH] use action space --- tank-rust/src/dqn/collect.rs | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tank-rust/src/dqn/collect.rs b/tank-rust/src/dqn/collect.rs index 27a3455..36be431 100644 --- a/tank-rust/src/dqn/collect.rs +++ b/tank-rust/src/dqn/collect.rs @@ -4,12 +4,13 @@ use std::{ }; use super::dataset::TankItem; +use super::feature::ACTION_SPACE; use crate::{ffi::prelude::*, Backend}; use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecorder}; use rand::{thread_rng, Rng}; use super::model::{DQNModel, DQNModelConfig}; -const EXPLORE_RATE: f32 = 0.4; +const EXPLORE_RATE: f32 = 0.2; pub struct App<'a> { model: DQNModel, @@ -64,7 +65,7 @@ impl<'a> App<'a> { } let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) { - true => match thread_rng().gen_range(0..15 as i32) { + true => match thread_rng().gen_range(0..ACTION_SPACE as i32) { 0 => Action::Forward, 1 => Action::Backward, 2 => Action::TurnRight, @@ -72,8 +73,7 @@ impl<'a> App<'a> { 4 => Action::AimRight, 5 => Action::AimLeft, 6 => Action::Shoot, - 7 => Action::TurnRight, - _ => Action::Forward, + _ => unreachable!("Invalid action"), }, false => self.predict_action(state), };