forked from easonabc-public/paia-hw5
193 lines
5.3 KiB
Rust
193 lines
5.3 KiB
Rust
//! Feature extraction and reward calculation for DQN
|
|
use std::f32::consts::PI;
|
|
|
|
use burn::tensor::{backend::Backend, Tensor};
|
|
|
|
use crate::ffi::prelude::*;
|
|
|
|
pub const FEATRUE_SPACE: usize = 10;
|
|
pub const ACTION_SPACE: usize = 7;
|
|
|
|
#[derive(PartialEq, Default)]
|
|
struct Polar {
|
|
angle: f32,
|
|
distance: f32,
|
|
}
|
|
|
|
impl Polar {
|
|
pub fn clip(&self) -> Self {
|
|
Polar {
|
|
angle: self.angle,
|
|
distance: self.distance.min(1e6).max(0.0),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Eq for Polar {}
|
|
|
|
impl Ord for Polar {
|
|
fn cmp(&self, other: &Self) -> std::cmp::Ordering {
|
|
self.distance.partial_cmp(&other.distance).unwrap()
|
|
}
|
|
}
|
|
|
|
impl PartialOrd for Polar {
|
|
fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
|
|
self.distance.partial_cmp(&other.distance)
|
|
}
|
|
}
|
|
|
|
fn normalize_angle(mut angle: f32) -> f32 {
|
|
while angle < -PI {
|
|
angle += 2.0 * PI;
|
|
}
|
|
while angle >= PI {
|
|
angle -= 2.0 * PI;
|
|
}
|
|
angle
|
|
}
|
|
|
|
impl Player {
|
|
fn to_pos(&self) -> (i32, i32) {
|
|
(self.x, self.y)
|
|
}
|
|
fn center(&self, x: i32, y: i32) -> Polar {
|
|
let dx = x - self.x;
|
|
let dy = y - self.y;
|
|
let angle = (dy as f32).atan2(dx as f32);
|
|
let distance = (dx.pow(2) + dy.pow(2)) as f32;
|
|
Polar { angle, distance }
|
|
}
|
|
fn closest(&self, others: impl Iterator<Item = (i32, i32)>) -> Polar {
|
|
others
|
|
.map(|(x, y)| self.center(x, y))
|
|
.min()
|
|
.unwrap_or_default()
|
|
}
|
|
fn get_angle(&self) -> f32 {
|
|
(180 - self.angle) as f32 / 360.0 * 2.0 * PI
|
|
}
|
|
fn get_gun_angle(&self) -> f32 {
|
|
self.gun_angle as f32 / 360.0 * 2.0 * PI
|
|
}
|
|
}
|
|
|
|
#[derive(Debug)]
|
|
enum Target {
|
|
Oil,
|
|
Bullet,
|
|
Enemy,
|
|
}
|
|
|
|
impl Target {
|
|
fn get_pos(&self, info: &Info) -> Polar {
|
|
match self {
|
|
Target::Oil => info
|
|
.player
|
|
.closest(info.oil_stations.iter().map(Station::to_pos)),
|
|
Target::Bullet => info
|
|
.player
|
|
.closest(info.bullet_stations.iter().map(Station::to_pos)),
|
|
Target::Enemy => info.player.closest(info.enemies.iter().map(Player::to_pos)),
|
|
}
|
|
}
|
|
fn reach(&self, last: &Info, current: &Info) -> bool {
|
|
match self {
|
|
Target::Oil => last.player.oil > current.player.oil,
|
|
Target::Bullet => last.player.power > current.player.power,
|
|
Target::Enemy => false,
|
|
}
|
|
}
|
|
}
|
|
|
|
impl Station {
|
|
fn to_pos(&self) -> (i32, i32) {
|
|
(self.x as i32, self.y as i32)
|
|
}
|
|
}
|
|
|
|
impl Wall {
|
|
fn to_pos(&self) -> (i32, i32) {
|
|
(self.x, self.y)
|
|
}
|
|
}
|
|
|
|
impl<'a> Info<'a> {
|
|
pub fn into_feature(&self) -> [f32; FEATRUE_SPACE] {
|
|
let emeny = self.player.closest(self.enemies.iter().map(Player::to_pos));
|
|
let wall = self
|
|
.player
|
|
.closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
|
|
let bullet = self
|
|
.player
|
|
.closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
|
|
|
|
let target = self.get_target().get_pos(self).clip();
|
|
|
|
let angle = self.player.get_angle();
|
|
let gun_angle = self.player.get_gun_angle();
|
|
|
|
let feature = [
|
|
normalize_angle(target.angle - angle).tanh(),
|
|
normalize_angle(target.angle - angle + PI).tanh(),
|
|
normalize_angle(bullet.angle - angle).tanh(),
|
|
(target.distance + 1.0).log2(),
|
|
(wall.distance - target.distance).tanh(),
|
|
(bullet.distance + 1.0).log2(),
|
|
normalize_angle(emeny.angle - gun_angle).tanh(),
|
|
normalize_angle(wall.angle - gun_angle).tanh(),
|
|
(self.player.oil - 40.0).tanh(),
|
|
(self.player.power as f32 - 7.0).tanh(),
|
|
];
|
|
|
|
feature
|
|
}
|
|
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
|
let feature = self.into_feature();
|
|
Tensor::from_floats(feature, device)
|
|
}
|
|
fn get_target(&self) -> Target {
|
|
if self.player.oil < 40.0 {
|
|
Target::Oil
|
|
} else if self.player.power > 7 {
|
|
Target::Enemy
|
|
} else {
|
|
Target::Bullet
|
|
}
|
|
}
|
|
pub fn get_reward(&self, next: &Self, action: Action) -> f32 {
|
|
let same_position = self.player.x == next.player.x && self.player.y == next.player.y;
|
|
let mut reward = -2.3;
|
|
reward += match action {
|
|
Action::Forward | Action::Backward if same_position => -8.0,
|
|
Action::Shoot => match next.player.power > 7 {
|
|
true => 2.0,
|
|
false => -2.0,
|
|
},
|
|
_ => 0.0,
|
|
};
|
|
|
|
let target = self.get_target();
|
|
|
|
if target.reach(self, next) {
|
|
reward += 15.0;
|
|
} else {
|
|
let previous_target_position = target.get_pos(self);
|
|
let next_target_position = target.get_pos(next);
|
|
|
|
reward += match previous_target_position.cmp(&next_target_position) {
|
|
std::cmp::Ordering::Less => -5.0,
|
|
std::cmp::Ordering::Equal => 0.0,
|
|
std::cmp::Ordering::Greater => 5.8,
|
|
};
|
|
}
|
|
|
|
reward
|
|
+ match next.player.score - self.player.score {
|
|
x if x > 2 => 20.0,
|
|
x if x > 0 => 10.0, // too high, tank my ignore power station
|
|
_ => -1.0,
|
|
}
|
|
}
|
|
}
|