reduce feature

This commit is contained in:
Eason 2024-06-12 17:55:10 +08:00
parent 1a301178f8
commit 9e641b5f09
4 changed files with 44 additions and 43 deletions

View File

@ -9,10 +9,10 @@
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3, 3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3, 3,0,0,0,0,0,0,0,0,3,3,3,0,0,0,0,0,1,0,3,
3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,2,0,0,0,0,0,0,3,3,3,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3, 3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3, 3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,

View File

@ -10,7 +10,7 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
use rand::{thread_rng, Rng}; use rand::{thread_rng, Rng};
use super::model::{DQNModel, DQNModelConfig}; use super::model::{DQNModel, DQNModelConfig};
const EXPLORE_RATE: f32 = 0.2; const EXPLORE_RATE: f32 = 0.8;
pub struct App<'a> { pub struct App<'a> {
model: DQNModel<Backend>, model: DQNModel<Backend>,
@ -65,15 +65,13 @@ impl<'a> App<'a> {
} }
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) { let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
true => match thread_rng().gen_range(0..ACTION_SPACE as i32) { true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
0 => Action::Forward, 0 => Action::TurnRight,
1 => Action::Backward, 1 => Action::TurnLeft,
2 => Action::TurnRight, 2 => Action::AimRight,
3 => Action::TurnLeft, 3 => Action::AimLeft,
4 => Action::AimRight, 4 => Action::Shoot,
5 => Action::AimLeft, _ => Action::Forward,
6 => Action::Shoot,
_ => unreachable!("Invalid action"),
}, },
false => self.predict_action(state), false => self.predict_action(state),
}; };
@ -85,6 +83,14 @@ impl<'a> App<'a> {
pub fn predict_action(&self, state: &Info) -> Action { pub fn predict_action(&self, state: &Info) -> Action {
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size] let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
let ans = self.model.forward(input); let ans = self.model.forward(input);
ans.argmax(1).into_scalar().try_into().unwrap() match ans.argmax(1).into_scalar(){
0 => Action::TurnRight,
1 => Action::TurnLeft,
2 => Action::AimRight,
3 => Action::AimLeft,
4 => Action::Shoot,
5 => Action::Forward,
_ => unreachable!("Invalid action"),
}
} }
} }

View File

@ -5,8 +5,8 @@ use burn::tensor::{backend::Backend, Tensor};
use crate::ffi::prelude::*; use crate::ffi::prelude::*;
pub const FEATRUE_SPACE: usize = 10; pub const FEATRUE_SPACE: usize = 7;
pub const ACTION_SPACE: usize = 7; pub const ACTION_SPACE: usize = 6;
#[derive(PartialEq, Default)] #[derive(PartialEq, Default)]
struct Polar { struct Polar {
@ -127,20 +127,15 @@ impl<'a> Info<'a> {
let angle = self.player.get_angle(); let angle = self.player.get_angle();
let gun_angle = self.player.get_gun_angle(); let gun_angle = self.player.get_gun_angle();
let feature = [ [
normalize_angle(target.angle - angle).tanh(), normalize_angle(target.angle - angle).tanh(),
normalize_angle(target.angle - angle + PI).tanh(),
normalize_angle(bullet.angle - angle).tanh(),
(target.distance + 1.0).log2(),
(wall.distance - target.distance).tanh(), (wall.distance - target.distance).tanh(),
(bullet.distance + 1.0).log2(), (self.player.power as f32).tanh(),
(wall.distance + 1.0).log2(),
(emeny.distance + 1.0).log2(),
normalize_angle(emeny.angle - gun_angle).tanh(), normalize_angle(emeny.angle - gun_angle).tanh(),
normalize_angle(wall.angle - gun_angle).tanh(), normalize_angle(wall.angle - gun_angle).tanh(),
(self.player.oil - 40.0).tanh(), ]
(self.player.power as f32 - 7.0).tanh(),
];
feature
} }
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> { pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
let feature = self.into_feature(); let feature = self.into_feature();

View File

@ -11,20 +11,20 @@ pub enum Action {
None = 7, None = 7,
} }
impl TryFrom<i32> for Action { // impl TryFrom<i32> for Action {
type Error = &'static str; // type Error = &'static str;
fn try_from(value: i32) -> Result<Self, Self::Error> { // fn try_from(value: i32) -> Result<Self, Self::Error> {
match value { // match value {
0 => Ok(Action::Forward), // 0 => Ok(Action::Forward),
1 => Ok(Action::Backward), // 1 => Ok(Action::Backward),
2 => Ok(Action::TurnRight), // 2 => Ok(Action::TurnRight),
3 => Ok(Action::TurnLeft), // 3 => Ok(Action::TurnLeft),
4 => Ok(Action::AimRight), // 4 => Ok(Action::AimRight),
5 => Ok(Action::AimLeft), // 5 => Ok(Action::AimLeft),
6 => Ok(Action::Shoot), // 6 => Ok(Action::Shoot),
7 => Ok(Action::None), // 7 => Ok(Action::None),
_ => Err("Invalid action"), // _ => Err("Invalid action"),
} // }
} // }
} // }