This commit is contained in:
Eason
2024-06-12 16:09:12 +08:00
commit d8b781011b
156 changed files with 26489 additions and 0 deletions

View File

@ -0,0 +1,192 @@
//! Feature extraction and reward calculation for DQN
use std::f32::consts::PI;
use burn::tensor::{backend::Backend, Tensor};
use crate::ffi::prelude::*;
pub const FEATRUE_SPACE: usize = 10;
pub const ACTION_SPACE: usize = 7;
#[derive(PartialEq, Default)]
struct Polar {
angle: f32,
distance: f32,
}
impl Polar {
pub fn clip(&self) -> Self {
Polar {
angle: self.angle,
distance: self.distance.min(1e6).max(0.0),
}
}
}
impl Eq for Polar {}
impl Ord for Polar {
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
self.distance.partial_cmp(&other.distance).unwrap()
}
}
impl PartialOrd for Polar {
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
self.distance.partial_cmp(&other.distance)
}
}
fn normalize_angle(mut angle: f32) -> f32 {
while angle < -PI {
angle += 2.0 * PI;
}
while angle >= PI {
angle -= 2.0 * PI;
}
angle
}
impl Player {
fn to_pos(&self) -> (i32, i32) {
(self.x, self.y)
}
fn center(&self, x: i32, y: i32) -> Polar {
let dx = x - self.x;
let dy = y - self.y;
let angle = (dy as f32).atan2(dx as f32);
let distance = (dx.pow(2) + dy.pow(2)) as f32;
Polar { angle, distance }
}
fn closest(&self, others: impl Iterator<Item = (i32, i32)>) -> Polar {
others
.map(|(x, y)| self.center(x, y))
.min()
.unwrap_or_default()
}
fn get_angle(&self) -> f32 {
(180 - self.angle) as f32 / 360.0 * 2.0 * PI
}
fn get_gun_angle(&self) -> f32 {
self.gun_angle as f32 / 360.0 * 2.0 * PI
}
}
#[derive(Debug)]
enum Target {
Oil,
Bullet,
Enemy,
}
impl Target {
fn get_pos(&self, info: &Info) -> Polar {
match self {
Target::Oil => info
.player
.closest(info.oil_stations.iter().map(Station::to_pos)),
Target::Bullet => info
.player
.closest(info.bullet_stations.iter().map(Station::to_pos)),
Target::Enemy => info.player.closest(info.enemies.iter().map(Player::to_pos)),
}
}
fn reach(&self, last: &Info, current: &Info) -> bool {
match self {
Target::Oil => last.player.oil > current.player.oil,
Target::Bullet => last.player.power > current.player.power,
Target::Enemy => false,
}
}
}
impl Station {
fn to_pos(&self) -> (i32, i32) {
(self.x as i32, self.y as i32)
}
}
impl Wall {
fn to_pos(&self) -> (i32, i32) {
(self.x, self.y)
}
}
impl<'a> Info<'a> {
pub fn into_feature(&self) -> [f32; FEATRUE_SPACE] {
let emeny = self.player.closest(self.enemies.iter().map(Player::to_pos));
let wall = self
.player
.closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
let bullet = self
.player
.closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
let target = self.get_target().get_pos(self).clip();
let angle = self.player.get_angle();
let gun_angle = self.player.get_gun_angle();
let feature = [
normalize_angle(target.angle - angle).tanh(),
normalize_angle(target.angle - angle + PI).tanh(),
normalize_angle(bullet.angle - angle).tanh(),
(target.distance + 1.0).log2(),
(wall.distance - target.distance).tanh(),
(bullet.distance + 1.0).log2(),
normalize_angle(emeny.angle - gun_angle).tanh(),
normalize_angle(wall.angle - gun_angle).tanh(),
(self.player.oil - 40.0).tanh(),
(self.player.power as f32 - 7.0).tanh(),
];
feature
}
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
let feature = self.into_feature();
Tensor::from_floats(feature, device)
}
fn get_target(&self) -> Target {
if self.player.oil < 40.0 {
Target::Oil
} else if self.player.power > 7 {
Target::Enemy
} else {
Target::Bullet
}
}
pub fn get_reward(&self, next: &Self, action: Action) -> f32 {
let same_position = self.player.x == next.player.x && self.player.y == next.player.y;
let mut reward = -2.3;
reward += match action {
Action::Forward | Action::Backward if same_position => -8.0,
Action::Shoot => match next.player.power > 7 {
true => 2.0,
false => -2.0,
},
_ => 0.0,
};
let target = self.get_target();
if target.reach(self, next) {
reward += 15.0;
} else {
let previous_target_position = target.get_pos(self);
let next_target_position = target.get_pos(next);
reward += match previous_target_position.cmp(&next_target_position) {
std::cmp::Ordering::Less => -5.0,
std::cmp::Ordering::Equal => 0.0,
std::cmp::Ordering::Greater => 5.8,
};
}
reward
+ match next.player.score - self.player.score {
x if x > 2 => 20.0,
x if x > 0 => 10.0, // too high, tank my ignore power station
_ => -1.0,
}
}
}