reduce feature
This commit is contained in:
parent
1a301178f8
commit
9e641b5f09
|
@ -9,10 +9,10 @@
|
|||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3,
|
||||
3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,3,3,3,0,0,0,0,0,1,0,3,
|
||||
3,0,2,0,0,0,0,0,0,3,3,3,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||
|
|
|
@ -10,7 +10,7 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
|
|||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::model::{DQNModel, DQNModelConfig};
|
||||
const EXPLORE_RATE: f32 = 0.2;
|
||||
const EXPLORE_RATE: f32 = 0.8;
|
||||
|
||||
pub struct App<'a> {
|
||||
model: DQNModel<Backend>,
|
||||
|
@ -65,15 +65,13 @@ impl<'a> App<'a> {
|
|||
}
|
||||
|
||||
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
||||
true => match thread_rng().gen_range(0..ACTION_SPACE as i32) {
|
||||
0 => Action::Forward,
|
||||
1 => Action::Backward,
|
||||
2 => Action::TurnRight,
|
||||
3 => Action::TurnLeft,
|
||||
4 => Action::AimRight,
|
||||
5 => Action::AimLeft,
|
||||
6 => Action::Shoot,
|
||||
_ => unreachable!("Invalid action"),
|
||||
true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
|
||||
0 => Action::TurnRight,
|
||||
1 => Action::TurnLeft,
|
||||
2 => Action::AimRight,
|
||||
3 => Action::AimLeft,
|
||||
4 => Action::Shoot,
|
||||
_ => Action::Forward,
|
||||
},
|
||||
false => self.predict_action(state),
|
||||
};
|
||||
|
@ -85,6 +83,14 @@ impl<'a> App<'a> {
|
|||
pub fn predict_action(&self, state: &Info) -> Action {
|
||||
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
|
||||
let ans = self.model.forward(input);
|
||||
ans.argmax(1).into_scalar().try_into().unwrap()
|
||||
match ans.argmax(1).into_scalar(){
|
||||
0 => Action::TurnRight,
|
||||
1 => Action::TurnLeft,
|
||||
2 => Action::AimRight,
|
||||
3 => Action::AimLeft,
|
||||
4 => Action::Shoot,
|
||||
5 => Action::Forward,
|
||||
_ => unreachable!("Invalid action"),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
|
|
@ -5,8 +5,8 @@ use burn::tensor::{backend::Backend, Tensor};
|
|||
|
||||
use crate::ffi::prelude::*;
|
||||
|
||||
pub const FEATRUE_SPACE: usize = 10;
|
||||
pub const ACTION_SPACE: usize = 7;
|
||||
pub const FEATRUE_SPACE: usize = 7;
|
||||
pub const ACTION_SPACE: usize = 6;
|
||||
|
||||
#[derive(PartialEq, Default)]
|
||||
struct Polar {
|
||||
|
@ -127,20 +127,15 @@ impl<'a> Info<'a> {
|
|||
let angle = self.player.get_angle();
|
||||
let gun_angle = self.player.get_gun_angle();
|
||||
|
||||
let feature = [
|
||||
[
|
||||
normalize_angle(target.angle - angle).tanh(),
|
||||
normalize_angle(target.angle - angle + PI).tanh(),
|
||||
normalize_angle(bullet.angle - angle).tanh(),
|
||||
(target.distance + 1.0).log2(),
|
||||
(wall.distance - target.distance).tanh(),
|
||||
(bullet.distance + 1.0).log2(),
|
||||
(self.player.power as f32).tanh(),
|
||||
(wall.distance + 1.0).log2(),
|
||||
(emeny.distance + 1.0).log2(),
|
||||
normalize_angle(emeny.angle - gun_angle).tanh(),
|
||||
normalize_angle(wall.angle - gun_angle).tanh(),
|
||||
(self.player.oil - 40.0).tanh(),
|
||||
(self.player.power as f32 - 7.0).tanh(),
|
||||
];
|
||||
|
||||
feature
|
||||
]
|
||||
}
|
||||
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
||||
let feature = self.into_feature();
|
||||
|
|
|
@ -11,20 +11,20 @@ pub enum Action {
|
|||
None = 7,
|
||||
}
|
||||
|
||||
impl TryFrom<i32> for Action {
|
||||
type Error = &'static str;
|
||||
// impl TryFrom<i32> for Action {
|
||||
// type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0 => Ok(Action::Forward),
|
||||
1 => Ok(Action::Backward),
|
||||
2 => Ok(Action::TurnRight),
|
||||
3 => Ok(Action::TurnLeft),
|
||||
4 => Ok(Action::AimRight),
|
||||
5 => Ok(Action::AimLeft),
|
||||
6 => Ok(Action::Shoot),
|
||||
7 => Ok(Action::None),
|
||||
_ => Err("Invalid action"),
|
||||
}
|
||||
}
|
||||
}
|
||||
// fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
// match value {
|
||||
// 0 => Ok(Action::Forward),
|
||||
// 1 => Ok(Action::Backward),
|
||||
// 2 => Ok(Action::TurnRight),
|
||||
// 3 => Ok(Action::TurnLeft),
|
||||
// 4 => Ok(Action::AimRight),
|
||||
// 5 => Ok(Action::AimLeft),
|
||||
// 6 => Ok(Action::Shoot),
|
||||
// 7 => Ok(Action::None),
|
||||
// _ => Err("Invalid action"),
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
|
Loading…
Reference in New Issue