From 9e641b5f09cdf27d541e4a83c5a339acdb7cd3d3 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Wed, 12 Jun 2024 17:55:10 +0800 Subject: [PATCH] reduce feature --- TankMan/asset/maps/map_1_v_1.tmx | 8 ++++---- tank-rust/src/dqn/collect.rs | 28 +++++++++++++++++----------- tank-rust/src/dqn/feature.rs | 19 +++++++------------ tank-rust/src/ffi/action.rs | 32 ++++++++++++++++---------------- 4 files changed, 44 insertions(+), 43 deletions(-) diff --git a/TankMan/asset/maps/map_1_v_1.tmx b/TankMan/asset/maps/map_1_v_1.tmx index 0bf0636..d89caf2 100644 --- a/TankMan/asset/maps/map_1_v_1.tmx +++ b/TankMan/asset/maps/map_1_v_1.tmx @@ -9,10 +9,10 @@ 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3, -3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, -3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3, -3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, -3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, +3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3, +3,0,0,0,0,0,0,0,0,3,3,3,0,0,0,0,0,1,0,3, +3,0,2,0,0,0,0,0,0,3,3,3,0,0,0,0,0,0,0,3, +3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3, 3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, diff --git a/tank-rust/src/dqn/collect.rs b/tank-rust/src/dqn/collect.rs index 36be431..65eaf9a 100644 --- a/tank-rust/src/dqn/collect.rs +++ b/tank-rust/src/dqn/collect.rs @@ -10,7 +10,7 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor use rand::{thread_rng, Rng}; use super::model::{DQNModel, DQNModelConfig}; -const EXPLORE_RATE: f32 = 0.2; +const EXPLORE_RATE: f32 = 0.8; pub struct App<'a> { model: DQNModel, @@ -65,15 +65,13 @@ impl<'a> App<'a> { } let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) { - true => match thread_rng().gen_range(0..ACTION_SPACE as i32) { - 0 => Action::Forward, - 1 => Action::Backward, - 2 => Action::TurnRight, - 3 => Action::TurnLeft, - 4 => Action::AimRight, - 5 => Action::AimLeft, - 6 => Action::Shoot, - _ => unreachable!("Invalid action"), + true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) { + 0 => Action::TurnRight, + 1 => Action::TurnLeft, + 2 => Action::AimRight, + 3 => Action::AimLeft, + 4 => Action::Shoot, + _ => Action::Forward, }, false => self.predict_action(state), }; @@ -85,6 +83,14 @@ impl<'a> App<'a> { pub fn predict_action(&self, state: &Info) -> Action { let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size] let ans = self.model.forward(input); - ans.argmax(1).into_scalar().try_into().unwrap() + match ans.argmax(1).into_scalar(){ + 0 => Action::TurnRight, + 1 => Action::TurnLeft, + 2 => Action::AimRight, + 3 => Action::AimLeft, + 4 => Action::Shoot, + 5 => Action::Forward, + _ => unreachable!("Invalid action"), + } } } diff --git a/tank-rust/src/dqn/feature.rs b/tank-rust/src/dqn/feature.rs index 30dff81..ea80027 100644 --- a/tank-rust/src/dqn/feature.rs +++ b/tank-rust/src/dqn/feature.rs @@ -5,8 +5,8 @@ use burn::tensor::{backend::Backend, Tensor}; use crate::ffi::prelude::*; -pub const FEATRUE_SPACE: usize = 10; -pub const ACTION_SPACE: usize = 7; +pub const FEATRUE_SPACE: usize = 7; +pub const ACTION_SPACE: usize = 6; #[derive(PartialEq, Default)] struct Polar { @@ -127,20 +127,15 @@ impl<'a> Info<'a> { let angle = self.player.get_angle(); let gun_angle = self.player.get_gun_angle(); - let feature = [ + [ normalize_angle(target.angle - angle).tanh(), - normalize_angle(target.angle - angle + PI).tanh(), - normalize_angle(bullet.angle - angle).tanh(), - (target.distance + 1.0).log2(), (wall.distance - target.distance).tanh(), - (bullet.distance + 1.0).log2(), + (self.player.power as f32).tanh(), + (wall.distance + 1.0).log2(), + (emeny.distance + 1.0).log2(), normalize_angle(emeny.angle - gun_angle).tanh(), normalize_angle(wall.angle - gun_angle).tanh(), - (self.player.oil - 40.0).tanh(), - (self.player.power as f32 - 7.0).tanh(), - ]; - - feature + ] } pub fn into_feature_tensor(&self, device: &B::Device) -> Tensor { let feature = self.into_feature(); diff --git a/tank-rust/src/ffi/action.rs b/tank-rust/src/ffi/action.rs index 045d331..d3a5b17 100644 --- a/tank-rust/src/ffi/action.rs +++ b/tank-rust/src/ffi/action.rs @@ -11,20 +11,20 @@ pub enum Action { None = 7, } -impl TryFrom for Action { - type Error = &'static str; +// impl TryFrom for Action { +// type Error = &'static str; - fn try_from(value: i32) -> Result { - match value { - 0 => Ok(Action::Forward), - 1 => Ok(Action::Backward), - 2 => Ok(Action::TurnRight), - 3 => Ok(Action::TurnLeft), - 4 => Ok(Action::AimRight), - 5 => Ok(Action::AimLeft), - 6 => Ok(Action::Shoot), - 7 => Ok(Action::None), - _ => Err("Invalid action"), - } - } -} +// fn try_from(value: i32) -> Result { +// match value { +// 0 => Ok(Action::Forward), +// 1 => Ok(Action::Backward), +// 2 => Ok(Action::TurnRight), +// 3 => Ok(Action::TurnLeft), +// 4 => Ok(Action::AimRight), +// 5 => Ok(Action::AimLeft), +// 6 => Ok(Action::Shoot), +// 7 => Ok(Action::None), +// _ => Err("Invalid action"), +// } +// } +// }