reduce feature
This commit is contained in:
parent
1a301178f8
commit
9e641b5f09
|
@ -9,10 +9,10 @@
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3,
|
3,0,0,0,0,5,0,0,0,0,0,0,0,0,0,4,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,1,0,3,
|
3,0,0,0,0,0,0,0,0,3,3,3,0,0,0,0,0,1,0,3,
|
||||||
3,0,2,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,2,0,0,0,0,0,0,3,3,3,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,3,0,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3,
|
3,0,0,0,4,0,0,0,0,0,0,0,0,0,5,0,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||||
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
3,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,3,
|
||||||
|
|
|
@ -10,7 +10,7 @@ use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecor
|
||||||
use rand::{thread_rng, Rng};
|
use rand::{thread_rng, Rng};
|
||||||
|
|
||||||
use super::model::{DQNModel, DQNModelConfig};
|
use super::model::{DQNModel, DQNModelConfig};
|
||||||
const EXPLORE_RATE: f32 = 0.2;
|
const EXPLORE_RATE: f32 = 0.8;
|
||||||
|
|
||||||
pub struct App<'a> {
|
pub struct App<'a> {
|
||||||
model: DQNModel<Backend>,
|
model: DQNModel<Backend>,
|
||||||
|
@ -65,15 +65,13 @@ impl<'a> App<'a> {
|
||||||
}
|
}
|
||||||
|
|
||||||
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
||||||
true => match thread_rng().gen_range(0..ACTION_SPACE as i32) {
|
true => match thread_rng().gen_range(0..(ACTION_SPACE+2) as i32) {
|
||||||
0 => Action::Forward,
|
0 => Action::TurnRight,
|
||||||
1 => Action::Backward,
|
1 => Action::TurnLeft,
|
||||||
2 => Action::TurnRight,
|
2 => Action::AimRight,
|
||||||
3 => Action::TurnLeft,
|
3 => Action::AimLeft,
|
||||||
4 => Action::AimRight,
|
4 => Action::Shoot,
|
||||||
5 => Action::AimLeft,
|
_ => Action::Forward,
|
||||||
6 => Action::Shoot,
|
|
||||||
_ => unreachable!("Invalid action"),
|
|
||||||
},
|
},
|
||||||
false => self.predict_action(state),
|
false => self.predict_action(state),
|
||||||
};
|
};
|
||||||
|
@ -85,6 +83,14 @@ impl<'a> App<'a> {
|
||||||
pub fn predict_action(&self, state: &Info) -> Action {
|
pub fn predict_action(&self, state: &Info) -> Action {
|
||||||
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
|
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
|
||||||
let ans = self.model.forward(input);
|
let ans = self.model.forward(input);
|
||||||
ans.argmax(1).into_scalar().try_into().unwrap()
|
match ans.argmax(1).into_scalar(){
|
||||||
|
0 => Action::TurnRight,
|
||||||
|
1 => Action::TurnLeft,
|
||||||
|
2 => Action::AimRight,
|
||||||
|
3 => Action::AimLeft,
|
||||||
|
4 => Action::Shoot,
|
||||||
|
5 => Action::Forward,
|
||||||
|
_ => unreachable!("Invalid action"),
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
|
@ -5,8 +5,8 @@ use burn::tensor::{backend::Backend, Tensor};
|
||||||
|
|
||||||
use crate::ffi::prelude::*;
|
use crate::ffi::prelude::*;
|
||||||
|
|
||||||
pub const FEATRUE_SPACE: usize = 10;
|
pub const FEATRUE_SPACE: usize = 7;
|
||||||
pub const ACTION_SPACE: usize = 7;
|
pub const ACTION_SPACE: usize = 6;
|
||||||
|
|
||||||
#[derive(PartialEq, Default)]
|
#[derive(PartialEq, Default)]
|
||||||
struct Polar {
|
struct Polar {
|
||||||
|
@ -127,20 +127,15 @@ impl<'a> Info<'a> {
|
||||||
let angle = self.player.get_angle();
|
let angle = self.player.get_angle();
|
||||||
let gun_angle = self.player.get_gun_angle();
|
let gun_angle = self.player.get_gun_angle();
|
||||||
|
|
||||||
let feature = [
|
[
|
||||||
normalize_angle(target.angle - angle).tanh(),
|
normalize_angle(target.angle - angle).tanh(),
|
||||||
normalize_angle(target.angle - angle + PI).tanh(),
|
|
||||||
normalize_angle(bullet.angle - angle).tanh(),
|
|
||||||
(target.distance + 1.0).log2(),
|
|
||||||
(wall.distance - target.distance).tanh(),
|
(wall.distance - target.distance).tanh(),
|
||||||
(bullet.distance + 1.0).log2(),
|
(self.player.power as f32).tanh(),
|
||||||
|
(wall.distance + 1.0).log2(),
|
||||||
|
(emeny.distance + 1.0).log2(),
|
||||||
normalize_angle(emeny.angle - gun_angle).tanh(),
|
normalize_angle(emeny.angle - gun_angle).tanh(),
|
||||||
normalize_angle(wall.angle - gun_angle).tanh(),
|
normalize_angle(wall.angle - gun_angle).tanh(),
|
||||||
(self.player.oil - 40.0).tanh(),
|
]
|
||||||
(self.player.power as f32 - 7.0).tanh(),
|
|
||||||
];
|
|
||||||
|
|
||||||
feature
|
|
||||||
}
|
}
|
||||||
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
||||||
let feature = self.into_feature();
|
let feature = self.into_feature();
|
||||||
|
|
|
@ -11,20 +11,20 @@ pub enum Action {
|
||||||
None = 7,
|
None = 7,
|
||||||
}
|
}
|
||||||
|
|
||||||
impl TryFrom<i32> for Action {
|
// impl TryFrom<i32> for Action {
|
||||||
type Error = &'static str;
|
// type Error = &'static str;
|
||||||
|
|
||||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
// fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||||
match value {
|
// match value {
|
||||||
0 => Ok(Action::Forward),
|
// 0 => Ok(Action::Forward),
|
||||||
1 => Ok(Action::Backward),
|
// 1 => Ok(Action::Backward),
|
||||||
2 => Ok(Action::TurnRight),
|
// 2 => Ok(Action::TurnRight),
|
||||||
3 => Ok(Action::TurnLeft),
|
// 3 => Ok(Action::TurnLeft),
|
||||||
4 => Ok(Action::AimRight),
|
// 4 => Ok(Action::AimRight),
|
||||||
5 => Ok(Action::AimLeft),
|
// 5 => Ok(Action::AimLeft),
|
||||||
6 => Ok(Action::Shoot),
|
// 6 => Ok(Action::Shoot),
|
||||||
7 => Ok(Action::None),
|
// 7 => Ok(Action::None),
|
||||||
_ => Err("Invalid action"),
|
// _ => Err("Invalid action"),
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
}
|
// }
|
||||||
|
|
Loading…
Reference in New Issue