reduce feature

This commit is contained in:
Eason 2024-06-12 17:55:10 +08:00
parent 1a301178f8
commit 9e641b5f09
4 changed files with 44 additions and 43 deletions

View File

@ -9,10 +9,10 @@
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3,
3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,3,3,3,0,0,0,0,0,1,0,3,
3,0,2,0,0,0,0,0,0,3,3,3,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,

View File

@ -10,7 +10,7 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
use rand::{thread_rng, Rng};
use super::model::{DQNModel, DQNModelConfig};
const EXPLORE_RATE: f32 = 0.2;
const EXPLORE_RATE: f32 = 0.8;
pub struct App<'a> {
model: DQNModel<Backend>,
@ -65,15 +65,13 @@ impl<'a> App<'a> {
}
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
true => match thread_rng().gen_range(0..ACTION_SPACE as i32) {
0 => Action::Forward,
1 => Action::Backward,
2 => Action::TurnRight,
3 => Action::TurnLeft,
4 => Action::AimRight,
5 => Action::AimLeft,
6 => Action::Shoot,
_ => unreachable!("Invalid action"),
true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
0 => Action::TurnRight,
1 => Action::TurnLeft,
2 => Action::AimRight,
3 => Action::AimLeft,
4 => Action::Shoot,
_ => Action::Forward,
},
false => self.predict_action(state),
};
@ -85,6 +83,14 @@ impl<'a> App<'a> {
pub fn predict_action(&self, state: &Info) -> Action {
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
let ans = self.model.forward(input);
ans.argmax(1).into_scalar().try_into().unwrap()
match ans.argmax(1).into_scalar(){
0 => Action::TurnRight,
1 => Action::TurnLeft,
2 => Action::AimRight,
3 => Action::AimLeft,
4 => Action::Shoot,
5 => Action::Forward,
_ => unreachable!("Invalid action"),
}
}
}

View File

@ -5,8 +5,8 @@ use burn::tensor::{backend::Backend, Tensor};
use crate::ffi::prelude::*;
pub const FEATRUE_SPACE: usize = 10;
pub const ACTION_SPACE: usize = 7;
pub const FEATRUE_SPACE: usize = 7;
pub const ACTION_SPACE: usize = 6;
#[derive(PartialEq, Default)]
struct Polar {
@ -127,20 +127,15 @@ impl<'a> Info<'a> {
let angle = self.player.get_angle();
let gun_angle = self.player.get_gun_angle();
let feature = [
[
normalize_angle(target.angle - angle).tanh(),
normalize_angle(target.angle - angle + PI).tanh(),
normalize_angle(bullet.angle - angle).tanh(),
(target.distance + 1.0).log2(),
(wall.distance - target.distance).tanh(),
(bullet.distance + 1.0).log2(),
(self.player.power as f32).tanh(),
(wall.distance + 1.0).log2(),
(emeny.distance + 1.0).log2(),
normalize_angle(emeny.angle - gun_angle).tanh(),
normalize_angle(wall.angle - gun_angle).tanh(),
(self.player.oil - 40.0).tanh(),
(self.player.power as f32 - 7.0).tanh(),
];
feature
]
}
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
let feature = self.into_feature();

View File

@ -11,20 +11,20 @@ pub enum Action {
None = 7,
}
impl TryFrom<i32> for Action {
type Error = &'static str;
// impl TryFrom<i32> for Action {
// type Error = &'static str;
fn try_from(value: i32) -> Result<Self, Self::Error> {
match value {
0 => Ok(Action::Forward),
1 => Ok(Action::Backward),
2 => Ok(Action::TurnRight),
3 => Ok(Action::TurnLeft),
4 => Ok(Action::AimRight),
5 => Ok(Action::AimLeft),
6 => Ok(Action::Shoot),
7 => Ok(Action::None),
_ => Err("Invalid action"),
}
}
}
// fn try_from(value: i32) -> Result<Self, Self::Error> {
// match value {
// 0 => Ok(Action::Forward),
// 1 => Ok(Action::Backward),
// 2 => Ok(Action::TurnRight),
// 3 => Ok(Action::TurnLeft),
// 4 => Ok(Action::AimRight),
// 5 => Ok(Action::AimLeft),
// 6 => Ok(Action::Shoot),
// 7 => Ok(Action::None),
// _ => Err("Invalid action"),
// }
// }
// }