reinit
This commit is contained in:
90
tank-rust/src/dqn/collect.rs
Normal file
90
tank-rust/src/dqn/collect.rs
Normal file
@ -0,0 +1,90 @@
|
||||
use std::{
|
||||
fs::{File, OpenOptions},
|
||||
io::{BufWriter, Write},
|
||||
};
|
||||
|
||||
use super::dataset::TankItem;
|
||||
use crate::{ffi::prelude::*, Backend};
|
||||
use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecorder};
|
||||
use rand::{thread_rng, Rng};
|
||||
|
||||
use super::model::{DQNModel, DQNModelConfig};
|
||||
const EXPLORE_RATE: f32 = 0.4;
|
||||
|
||||
pub struct App<'a> {
|
||||
model: DQNModel<Backend>,
|
||||
device: WgpuDevice,
|
||||
last_state_action: Option<(Info<'a>, Action)>,
|
||||
#[cfg(feature = "train")]
|
||||
outlet: BufWriter<File>,
|
||||
}
|
||||
|
||||
impl<'a> App<'a> {
|
||||
pub fn new(model_path: &str) -> Self {
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
let model = DQNModelConfig::new().init(&device);
|
||||
|
||||
let model = model
|
||||
.load_file(
|
||||
format!("{model_path}/model"),
|
||||
&NoStdTrainingRecorder::new(),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
Self {
|
||||
model,
|
||||
device,
|
||||
last_state_action: None,
|
||||
#[cfg(feature = "train")]
|
||||
outlet: BufWriter::new(
|
||||
OpenOptions::new()
|
||||
.write(true)
|
||||
.create(true)
|
||||
.open(format!("{model_path}/dataset"))
|
||||
.unwrap(),
|
||||
),
|
||||
}
|
||||
}
|
||||
#[cfg(feature = "train")]
|
||||
pub fn flush(&mut self) {
|
||||
self.outlet.flush().unwrap();
|
||||
}
|
||||
#[cfg(feature = "train")]
|
||||
pub fn collect_data(&mut self, state: &Info<'static>) -> Action {
|
||||
if let Some((previous_state, action)) = self.last_state_action.take() {
|
||||
let reward = previous_state.get_reward(state, action);
|
||||
let item = TankItem {
|
||||
previous_state: previous_state.into_feature(),
|
||||
new_state: state.into_feature(),
|
||||
action,
|
||||
reward,
|
||||
};
|
||||
bincode::serialize_into(&mut self.outlet, &item).unwrap();
|
||||
}
|
||||
|
||||
let action = match thread_rng().gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096) {
|
||||
true => match thread_rng().gen_range(0..15 as i32) {
|
||||
0 => Action::Forward,
|
||||
1 => Action::Backward,
|
||||
2 => Action::TurnRight,
|
||||
3 => Action::TurnLeft,
|
||||
4 => Action::AimRight,
|
||||
5 => Action::AimLeft,
|
||||
6 => Action::Shoot,
|
||||
7 => Action::TurnRight,
|
||||
_ => Action::Forward,
|
||||
},
|
||||
false => self.predict_action(state),
|
||||
};
|
||||
|
||||
self.last_state_action = Some((state.clone(), action));
|
||||
|
||||
action
|
||||
}
|
||||
pub fn predict_action(&self, state: &Info) -> Action {
|
||||
let input = state.into_feature_tensor(&self.device).unsqueeze(); // Convert input tensor to shape [1, input_size]
|
||||
let ans = self.model.forward(input);
|
||||
ans.argmax(1).into_scalar().try_into().unwrap()
|
||||
}
|
||||
}
|
127
tank-rust/src/dqn/dataset.rs
Normal file
127
tank-rust/src/dqn/dataset.rs
Normal file
@ -0,0 +1,127 @@
|
||||
use std::{
|
||||
env::{self},
|
||||
fs::File,
|
||||
io::BufReader,
|
||||
};
|
||||
|
||||
use crate::ffi::prelude::*;
|
||||
use crate::ARTIFACT_DIR;
|
||||
use burn::{
|
||||
data::{dataloader::batcher::Batcher, dataset::Dataset},
|
||||
prelude::*,
|
||||
};
|
||||
|
||||
use super::feature::FEATRUE_SPACE;
|
||||
|
||||
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
|
||||
pub struct TankItem {
|
||||
pub previous_state: [f32; FEATRUE_SPACE],
|
||||
pub new_state: [f32; FEATRUE_SPACE],
|
||||
pub action: Action,
|
||||
pub reward: f32,
|
||||
}
|
||||
|
||||
pub struct TankDataset {
|
||||
pub dataset: Vec<TankItem>,
|
||||
}
|
||||
|
||||
impl Dataset<TankItem> for TankDataset {
|
||||
fn get(&self, index: usize) -> Option<TankItem> {
|
||||
self.dataset.get(index).cloned()
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
self.dataset.len()
|
||||
}
|
||||
}
|
||||
|
||||
impl TankDataset {
|
||||
pub fn new() -> Self {
|
||||
Self {
|
||||
dataset: Vec::new(),
|
||||
}
|
||||
}
|
||||
pub fn load() -> Self {
|
||||
let dataset_path = env::var("MODEL_PATH").unwrap_or_else(|_| ARTIFACT_DIR.to_string());
|
||||
let mut dataset = Vec::new();
|
||||
println!("Loading dataset from: {}", dataset_path);
|
||||
if let Ok(reader) = File::open(format!("{dataset_path}/dataset")) {
|
||||
let mut reader = BufReader::new(reader);
|
||||
while let Ok(item) = bincode::deserialize_from::<_, TankItem>(&mut reader) {
|
||||
dataset.push(item);
|
||||
}
|
||||
}
|
||||
TankDataset { dataset }
|
||||
}
|
||||
pub fn add(&mut self, item: TankItem) {
|
||||
self.dataset.push(item);
|
||||
}
|
||||
pub fn split(self, ratio: f32) -> (Self, Self) {
|
||||
let split = (self.dataset.len() as f32 * ratio).round() as usize;
|
||||
let (a, b) = self.dataset.split_at(split);
|
||||
(
|
||||
TankDataset {
|
||||
dataset: a.to_vec(),
|
||||
},
|
||||
TankDataset {
|
||||
dataset: b.to_vec(),
|
||||
},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TankBatcher<B: Backend> {
|
||||
device: B::Device,
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug)]
|
||||
pub struct TankBatch<B: Backend> {
|
||||
pub new_state: Tensor<B, 2>,
|
||||
pub old_state: Tensor<B, 2>,
|
||||
pub action: Tensor<B, 1, Int>,
|
||||
pub reward: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> TankBatcher<B> {
|
||||
pub fn new(device: B::Device) -> Self {
|
||||
Self { device }
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> Batcher<TankItem, TankBatch<B>> for TankBatcher<B> {
|
||||
fn batch(&self, items: Vec<TankItem>) -> TankBatch<B> {
|
||||
let mut new_state: Vec<Tensor<B, 2>> = Vec::new();
|
||||
let mut old_state: Vec<Tensor<B, 2>> = Vec::new();
|
||||
|
||||
for item in items.iter() {
|
||||
let new_state_tensor = Tensor::<B, 1>::from_floats(item.previous_state, &self.device);
|
||||
let old_state_tensor = Tensor::<B, 1>::from_floats(item.new_state, &self.device);
|
||||
|
||||
new_state.push(new_state_tensor.unsqueeze());
|
||||
old_state.push(old_state_tensor.unsqueeze());
|
||||
}
|
||||
|
||||
let new_state = Tensor::cat(new_state, 0);
|
||||
let old_state = Tensor::cat(old_state, 0);
|
||||
|
||||
let reward = items
|
||||
.iter()
|
||||
.map(|item| Tensor::<B, 1, Float>::from_floats([item.reward], &self.device))
|
||||
.collect();
|
||||
let reward = Tensor::cat(reward, 0);
|
||||
|
||||
let response = items
|
||||
.iter()
|
||||
.map(|item| Tensor::<B, 1, Int>::from_ints([item.action as i32], &self.device))
|
||||
.collect();
|
||||
let response = Tensor::cat(response, 0);
|
||||
|
||||
TankBatch {
|
||||
new_state,
|
||||
old_state,
|
||||
reward,
|
||||
action: response,
|
||||
}
|
||||
}
|
||||
}
|
192
tank-rust/src/dqn/feature.rs
Normal file
192
tank-rust/src/dqn/feature.rs
Normal file
@ -0,0 +1,192 @@
|
||||
//! Feature extraction and reward calculation for DQN
|
||||
use std::f32::consts::PI;
|
||||
|
||||
use burn::tensor::{backend::Backend, Tensor};
|
||||
|
||||
use crate::ffi::prelude::*;
|
||||
|
||||
/// Length of the feature vector produced by `Info::into_feature`.
// NOTE: the identifier is a misspelling of "FEATURE_SPACE"; it is kept because other
// modules import it under this name.
pub const FEATRUE_SPACE: usize = 10;

/// Number of actions the network scores (width of its output layer).
pub const ACTION_SPACE: usize = 7;
|
||||
|
||||
/// Position of a point relative to the player, in polar form.
#[derive(PartialEq, Default)]
struct Polar {
    /// Bearing of the point, in radians (as produced by `atan2`).
    angle: f32,
    /// Distance measure used for ranking; `Player::center` stores the *squared*
    /// Euclidean distance here.
    distance: f32,
}

impl Polar {
    /// Returns a copy whose `distance` is limited to `0.0..=1e6`; the angle is kept.
    pub fn clip(&self) -> Self {
        Polar {
            angle: self.angle,
            distance: self.distance.min(1e6).max(0.0),
        }
    }
}

// NOTE(review): the derived `PartialEq` compares both fields while the ordering below
// looks at `distance` only, so `a == b` and `a.cmp(&b) == Equal` can disagree.
impl Eq for Polar {}

impl Ord for Polar {
    /// Orders by `distance` only.
    ///
    /// Ordinary values compare exactly as `f32::partial_cmp` does. If either distance
    /// is NaN the IEEE total order is used as a fallback, so the comparison no longer
    /// panics.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance
            .partial_cmp(&other.distance)
            .unwrap_or_else(|| self.distance.total_cmp(&other.distance))
    }
}

impl PartialOrd for Polar {
    /// Delegates to `Ord::cmp` so both orderings always agree.
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
|
||||
|
||||
/// Wraps `angle` (radians) into the half-open range `[-PI, PI)`.
///
/// Values already inside the range are returned unchanged. Anything else is wrapped in
/// closed form, so the function terminates for every input: the previous
/// add/subtract loops never finished for infinite inputs or for magnitudes so large
/// that `2 * PI` is smaller than half an ulp. Non-finite inputs yield NaN.
fn normalize_angle(angle: f32) -> f32 {
    const FULL_TURN: f32 = 2.0 * PI;

    if (-PI..PI).contains(&angle) {
        return angle;
    }

    let wrapped = (angle + PI).rem_euclid(FULL_TURN) - PI;
    // `rem_euclid` may round up to exactly one full turn; fold that back into range.
    if wrapped >= PI {
        wrapped - FULL_TURN
    } else {
        wrapped
    }
}
|
||||
|
||||
impl Player {
|
||||
fn to_pos(&self) -> (i32, i32) {
|
||||
(self.x, self.y)
|
||||
}
|
||||
fn center(&self, x: i32, y: i32) -> Polar {
|
||||
let dx = x - self.x;
|
||||
let dy = y - self.y;
|
||||
let angle = (dy as f32).atan2(dx as f32);
|
||||
let distance = (dx.pow(2) + dy.pow(2)) as f32;
|
||||
Polar { angle, distance }
|
||||
}
|
||||
fn closest(&self, others: impl Iterator<Item = (i32, i32)>) -> Polar {
|
||||
others
|
||||
.map(|(x, y)| self.center(x, y))
|
||||
.min()
|
||||
.unwrap_or_default()
|
||||
}
|
||||
fn get_angle(&self) -> f32 {
|
||||
(180 - self.angle) as f32 / 360.0 * 2.0 * PI
|
||||
}
|
||||
fn get_gun_angle(&self) -> f32 {
|
||||
self.gun_angle as f32 / 360.0 * 2.0 * PI
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug)]
|
||||
enum Target {
|
||||
Oil,
|
||||
Bullet,
|
||||
Enemy,
|
||||
}
|
||||
|
||||
impl Target {
|
||||
fn get_pos(&self, info: &Info) -> Polar {
|
||||
match self {
|
||||
Target::Oil => info
|
||||
.player
|
||||
.closest(info.oil_stations.iter().map(Station::to_pos)),
|
||||
Target::Bullet => info
|
||||
.player
|
||||
.closest(info.bullet_stations.iter().map(Station::to_pos)),
|
||||
Target::Enemy => info.player.closest(info.enemies.iter().map(Player::to_pos)),
|
||||
}
|
||||
}
|
||||
fn reach(&self, last: &Info, current: &Info) -> bool {
|
||||
match self {
|
||||
Target::Oil => last.player.oil > current.player.oil,
|
||||
Target::Bullet => last.player.power > current.player.power,
|
||||
Target::Enemy => false,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Station {
|
||||
fn to_pos(&self) -> (i32, i32) {
|
||||
(self.x as i32, self.y as i32)
|
||||
}
|
||||
}
|
||||
|
||||
impl Wall {
|
||||
fn to_pos(&self) -> (i32, i32) {
|
||||
(self.x, self.y)
|
||||
}
|
||||
}
|
||||
|
||||
impl<'a> Info<'a> {
|
||||
pub fn into_feature(&self) -> [f32; FEATRUE_SPACE] {
|
||||
let emeny = self.player.closest(self.enemies.iter().map(Player::to_pos));
|
||||
let wall = self
|
||||
.player
|
||||
.closest(self.walls.iter().map(|wall| (wall.x, wall.y)));
|
||||
let bullet = self
|
||||
.player
|
||||
.closest(self.bullets.iter().map(|bullet| (bullet.x, bullet.y)));
|
||||
|
||||
let target = self.get_target().get_pos(self).clip();
|
||||
|
||||
let angle = self.player.get_angle();
|
||||
let gun_angle = self.player.get_gun_angle();
|
||||
|
||||
let feature = [
|
||||
normalize_angle(target.angle - angle).tanh(),
|
||||
normalize_angle(target.angle - angle + PI).tanh(),
|
||||
normalize_angle(bullet.angle - angle).tanh(),
|
||||
(target.distance + 1.0).log2(),
|
||||
(wall.distance - target.distance).tanh(),
|
||||
(bullet.distance + 1.0).log2(),
|
||||
normalize_angle(emeny.angle - gun_angle).tanh(),
|
||||
normalize_angle(wall.angle - gun_angle).tanh(),
|
||||
(self.player.oil - 40.0).tanh(),
|
||||
(self.player.power as f32 - 7.0).tanh(),
|
||||
];
|
||||
|
||||
feature
|
||||
}
|
||||
pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
|
||||
let feature = self.into_feature();
|
||||
Tensor::from_floats(feature, device)
|
||||
}
|
||||
fn get_target(&self) -> Target {
|
||||
if self.player.oil < 40.0 {
|
||||
Target::Oil
|
||||
} else if self.player.power > 7 {
|
||||
Target::Enemy
|
||||
} else {
|
||||
Target::Bullet
|
||||
}
|
||||
}
|
||||
pub fn get_reward(&self, next: &Self, action: Action) -> f32 {
|
||||
let same_position = self.player.x == next.player.x && self.player.y == next.player.y;
|
||||
let mut reward = -2.3;
|
||||
reward += match action {
|
||||
Action::Forward | Action::Backward if same_position => -8.0,
|
||||
Action::Shoot => match next.player.power > 7 {
|
||||
true => 2.0,
|
||||
false => -2.0,
|
||||
},
|
||||
_ => 0.0,
|
||||
};
|
||||
|
||||
let target = self.get_target();
|
||||
|
||||
if target.reach(self, next) {
|
||||
reward += 15.0;
|
||||
} else {
|
||||
let previous_target_position = target.get_pos(self);
|
||||
let next_target_position = target.get_pos(next);
|
||||
|
||||
reward += match previous_target_position.cmp(&next_target_position) {
|
||||
std::cmp::Ordering::Less => -5.0,
|
||||
std::cmp::Ordering::Equal => 0.0,
|
||||
std::cmp::Ordering::Greater => 5.8,
|
||||
};
|
||||
}
|
||||
|
||||
reward
|
||||
+ match next.player.score - self.player.score {
|
||||
x if x > 2 => 20.0,
|
||||
x if x > 0 => 10.0, // too high, tank my ignore power station
|
||||
_ => -1.0,
|
||||
}
|
||||
}
|
||||
}
|
12
tank-rust/src/dqn/mod.rs
Normal file
12
tank-rust/src/dqn/mod.rs
Normal file
@ -0,0 +1,12 @@
|
||||
mod collect;
|
||||
mod dataset;
|
||||
mod feature;
|
||||
mod model;
|
||||
mod training;
|
||||
|
||||
pub mod prelude {
|
||||
pub use super::collect::App as DQNApp;
|
||||
pub use super::dataset::{TankDataset, TankItem};
|
||||
pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
|
||||
pub use super::training::run as train;
|
||||
}
|
122
tank-rust/src/dqn/model.rs
Normal file
122
tank-rust/src/dqn/model.rs
Normal file
@ -0,0 +1,122 @@
|
||||
use super::{
|
||||
dataset::TankBatch,
|
||||
feature::{ACTION_SPACE, FEATRUE_SPACE},
|
||||
};
|
||||
use burn::{
|
||||
nn::{
|
||||
loss::{HuberLoss, HuberLossConfig, Reduction::Mean},
|
||||
Linear, LinearConfig, Relu,
|
||||
},
|
||||
prelude::*,
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{
|
||||
metric::{Adaptor, LossInput},
|
||||
TrainOutput, TrainStep, ValidStep,
|
||||
},
|
||||
};
|
||||
|
||||
pub struct DQNOutput<B: Backend> {
|
||||
estimated_reward: Tensor<B, 2>,
|
||||
expected_reward: Tensor<B, 2>,
|
||||
loss: Tensor<B, 1>,
|
||||
}
|
||||
|
||||
impl<B: Backend> Adaptor<LossInput<B>> for DQNOutput<B> {
|
||||
fn adapt(&self) -> LossInput<B> {
|
||||
LossInput::new(self.loss.clone())
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Module, Debug)]
|
||||
pub struct DQNModel<B: Backend> {
|
||||
input_layer: Linear<B>,
|
||||
hidden_layer_1: Linear<B>,
|
||||
hidden_layer_2: Linear<B>,
|
||||
output_layer: Linear<B>,
|
||||
activation: Relu,
|
||||
loss_function: HuberLoss<B>,
|
||||
gamma: f32,
|
||||
}
|
||||
|
||||
#[derive(Config)]
|
||||
pub struct DQNModelConfig {
|
||||
#[config(default = 64)]
|
||||
pub hidden_layer_1_size: usize,
|
||||
#[config(default = 96)]
|
||||
pub hidden_layer_2_size: usize,
|
||||
#[config(default = 64)]
|
||||
pub hidden_layer_3_size: usize,
|
||||
#[config(default = 0.99)]
|
||||
pub gamma: f32,
|
||||
}
|
||||
|
||||
impl DQNModelConfig {
|
||||
pub fn init<B: Backend>(&self, device: &B::Device) -> DQNModel<B> {
|
||||
let input_layer = LinearConfig::new(FEATRUE_SPACE, self.hidden_layer_1_size)
|
||||
.with_bias(true)
|
||||
.init(device);
|
||||
let hidden_layer_1 = LinearConfig::new(self.hidden_layer_1_size, self.hidden_layer_2_size)
|
||||
.with_bias(true)
|
||||
.init(device);
|
||||
let hidden_layer_2 = LinearConfig::new(self.hidden_layer_2_size, self.hidden_layer_3_size)
|
||||
.with_bias(true)
|
||||
.init(device);
|
||||
let output_layer = LinearConfig::new(self.hidden_layer_3_size, ACTION_SPACE)
|
||||
.with_bias(true)
|
||||
.init(device);
|
||||
|
||||
DQNModel {
|
||||
input_layer,
|
||||
hidden_layer_1,
|
||||
hidden_layer_2,
|
||||
output_layer,
|
||||
loss_function: HuberLossConfig::new(1.34).init(device),
|
||||
activation: Relu::new(),
|
||||
gamma: self.gamma,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> DQNModel<B> {
|
||||
pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
|
||||
let x = input.detach();
|
||||
let x = self.input_layer.forward(x);
|
||||
let x = self.activation.forward(x);
|
||||
self.output_layer.forward(x)
|
||||
}
|
||||
|
||||
pub fn forward_step(&self, item: TankBatch<B>) -> DQNOutput<B> {
|
||||
let estimated_reward = self.forward(item.new_state);
|
||||
|
||||
// FIXME: magic unsqueeze
|
||||
let a = item.action.unsqueeze_dim(1);
|
||||
let x = estimated_reward.clone().gather(1, a);
|
||||
|
||||
// FIXME: what's final mask
|
||||
let expected_reward = self.forward(item.old_state);
|
||||
let y = expected_reward.clone().max_dim(1);
|
||||
let y = y.mul_scalar(self.gamma).add(item.reward.unsqueeze());
|
||||
|
||||
let loss = self.loss_function.forward(x, y, Mean);
|
||||
|
||||
DQNOutput {
|
||||
estimated_reward,
|
||||
expected_reward,
|
||||
loss,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: AutodiffBackend> TrainStep<TankBatch<B>, DQNOutput<B>> for DQNModel<B> {
|
||||
fn step(&self, item: TankBatch<B>) -> TrainOutput<DQNOutput<B>> {
|
||||
let item = self.forward_step(item);
|
||||
|
||||
TrainOutput::new(self, item.loss.backward(), item)
|
||||
}
|
||||
}
|
||||
|
||||
impl<B: Backend> ValidStep<TankBatch<B>, DQNOutput<B>> for DQNModel<B> {
|
||||
fn step(&self, item: TankBatch<B>) -> DQNOutput<B> {
|
||||
self.forward_step(item)
|
||||
}
|
||||
}
|
103
tank-rust/src/dqn/training.rs
Normal file
103
tank-rust/src/dqn/training.rs
Normal file
@ -0,0 +1,103 @@
|
||||
use std::{env, fs};
|
||||
|
||||
use super::dataset::{TankBatcher, TankDataset};
|
||||
use super::model::DQNModelConfig;
|
||||
use crate::ARTIFACT_DIR;
|
||||
use burn::optim::AdamConfig;
|
||||
use burn::train::metric::store::{Aggregate, Direction, Split};
|
||||
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
|
||||
use burn::{
|
||||
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
|
||||
prelude::*,
|
||||
record::{CompactRecorder, NoStdTrainingRecorder},
|
||||
tensor::backend::AutodiffBackend,
|
||||
train::{metric::LossMetric, LearnerBuilder},
|
||||
};
|
||||
#[derive(Config)]
|
||||
pub struct ExpConfig {
|
||||
#[config(default = 16)]
|
||||
pub num_epochs: usize,
|
||||
|
||||
#[config(default = 6)]
|
||||
pub num_workers: usize,
|
||||
|
||||
#[config(default = 47)]
|
||||
pub seed: u64,
|
||||
|
||||
pub optimizer: AdamConfig,
|
||||
|
||||
#[config(default = 1.5e-3)]
|
||||
pub learn_rate: f64,
|
||||
|
||||
#[config(default = 4096)]
|
||||
pub batch_size: usize,
|
||||
}
|
||||
|
||||
pub fn run<B>(device: B::Device)
|
||||
where
|
||||
B: AutodiffBackend,
|
||||
{
|
||||
let model_path = env::var("MODEL_PATH").unwrap_or_else(|_| ARTIFACT_DIR.to_string());
|
||||
|
||||
let optimizer = AdamConfig::new();
|
||||
let config = ExpConfig::new(optimizer);
|
||||
let mut model = DQNModelConfig::new().init(&device);
|
||||
|
||||
if fs::metadata(format!("{model_path}/model")).is_ok() {
|
||||
model = model
|
||||
.load_file(
|
||||
format!("{model_path}/model"),
|
||||
&NoStdTrainingRecorder::new(),
|
||||
&device,
|
||||
)
|
||||
.unwrap();
|
||||
}
|
||||
|
||||
B::seed(config.seed);
|
||||
|
||||
let (train_dataset, test_dataset) = TankDataset::load().split(1.0);
|
||||
|
||||
println!("Train Dataset Size: {}", train_dataset.len());
|
||||
println!("Test Dataset Size: {}", test_dataset.len());
|
||||
|
||||
let batcher_train = TankBatcher::<B>::new(device.clone());
|
||||
|
||||
let batcher_test = TankBatcher::<B::InnerBackend>::new(device.clone());
|
||||
|
||||
let dataloader_train = DataLoaderBuilder::new(batcher_train)
|
||||
.batch_size(config.batch_size)
|
||||
.shuffle(config.seed)
|
||||
.num_workers(config.num_workers)
|
||||
.build(train_dataset);
|
||||
|
||||
let dataloader_test = DataLoaderBuilder::new(batcher_test)
|
||||
.batch_size(config.batch_size)
|
||||
.shuffle(config.seed)
|
||||
.num_workers(config.num_workers)
|
||||
.build(test_dataset);
|
||||
|
||||
let learner = LearnerBuilder::new(&model_path)
|
||||
.metric_train_numeric(LossMetric::new())
|
||||
.metric_valid_numeric(LossMetric::new())
|
||||
.with_file_checkpointer(CompactRecorder::new())
|
||||
.early_stopping(MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
|
||||
Aggregate::Mean,
|
||||
Direction::Lowest,
|
||||
Split::Train,
|
||||
StoppingCondition::NoImprovementSince { n_epochs: 2 },
|
||||
))
|
||||
.devices(vec![device.clone()])
|
||||
.num_epochs(config.num_epochs)
|
||||
.summary()
|
||||
.build(model, config.optimizer.init(), config.learn_rate);
|
||||
|
||||
let model_trained = learner.fit(dataloader_train, dataloader_test);
|
||||
|
||||
config
|
||||
.save(format!("{model_path}/config.json").as_str())
|
||||
.unwrap();
|
||||
|
||||
model_trained
|
||||
.save_file(format!("{model_path}/model"), &NoStdTrainingRecorder::new())
|
||||
.expect("Failed to save trained model");
|
||||
}
|
30
tank-rust/src/ffi/action.rs
Normal file
30
tank-rust/src/ffi/action.rs
Normal file
@ -0,0 +1,30 @@
|
||||
#[repr(i32)]
|
||||
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub enum Action {
|
||||
Forward = 0,
|
||||
Backward = 1,
|
||||
TurnRight = 2,
|
||||
TurnLeft = 3,
|
||||
AimRight = 4,
|
||||
AimLeft = 5,
|
||||
Shoot = 6,
|
||||
None = 7,
|
||||
}
|
||||
|
||||
impl TryFrom<i32> for Action {
|
||||
type Error = &'static str;
|
||||
|
||||
fn try_from(value: i32) -> Result<Self, Self::Error> {
|
||||
match value {
|
||||
0 => Ok(Action::Forward),
|
||||
1 => Ok(Action::Backward),
|
||||
2 => Ok(Action::TurnRight),
|
||||
3 => Ok(Action::TurnLeft),
|
||||
4 => Ok(Action::AimRight),
|
||||
5 => Ok(Action::AimLeft),
|
||||
6 => Ok(Action::Shoot),
|
||||
7 => Ok(Action::None),
|
||||
_ => Err("Invalid action"),
|
||||
}
|
||||
}
|
||||
}
|
6
tank-rust/src/ffi/info/bullet.rs
Normal file
6
tank-rust/src/ffi/info/bullet.rs
Normal file
@ -0,0 +1,6 @@
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct Bullet {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
}
|
58
tank-rust/src/ffi/info/mod.rs
Normal file
58
tank-rust/src/ffi/info/mod.rs
Normal file
@ -0,0 +1,58 @@
|
||||
mod bullet;
|
||||
mod player;
|
||||
mod station;
|
||||
mod wall;
|
||||
|
||||
pub use bullet::*;
|
||||
pub use player::*;
|
||||
pub use station::*;
|
||||
pub use wall::*;
|
||||
|
||||
#[derive(Debug, Clone, Default)]
|
||||
pub struct Info<'a> {
|
||||
pub player: Player,
|
||||
pub teammates: &'a [Player],
|
||||
pub enemies: &'a [Player],
|
||||
pub bullets: &'a [Bullet],
|
||||
pub bullet_stations: &'a [Station],
|
||||
pub oil_stations: &'a [Station],
|
||||
pub walls: &'a [Wall],
|
||||
}
|
||||
|
||||
#[repr(C)]
|
||||
pub struct RawInfo {
|
||||
player: *const Player,
|
||||
teammates: *const Player,
|
||||
teammates_len: u32,
|
||||
enemies: *const Player,
|
||||
enemies_len: u32,
|
||||
bullets: *const Bullet,
|
||||
bullet_len: u32,
|
||||
bullet_stations: *const Station,
|
||||
bullet_stations_len: u32,
|
||||
oil_stations: *const Station,
|
||||
oil_stations_len: u32,
|
||||
walls: *const Wall,
|
||||
walls_len: u32,
|
||||
}
|
||||
|
||||
impl<'a> Info<'a> {
|
||||
pub unsafe fn from_raw(self_: *const RawInfo) -> Self {
|
||||
let raw = &*self_;
|
||||
Info {
|
||||
player: (&*raw.player).clone(),
|
||||
teammates: std::slice::from_raw_parts(raw.teammates, raw.teammates_len as usize),
|
||||
enemies: std::slice::from_raw_parts(raw.enemies, raw.enemies_len as usize),
|
||||
bullets: std::slice::from_raw_parts(raw.bullets, raw.bullet_len as usize),
|
||||
bullet_stations: std::slice::from_raw_parts(
|
||||
raw.bullet_stations,
|
||||
raw.bullet_stations_len as usize,
|
||||
),
|
||||
oil_stations: std::slice::from_raw_parts(
|
||||
raw.oil_stations,
|
||||
raw.oil_stations_len as usize,
|
||||
),
|
||||
walls: std::slice::from_raw_parts(raw.walls, raw.walls_len as usize),
|
||||
}
|
||||
}
|
||||
}
|
14
tank-rust/src/ffi/info/player.rs
Normal file
14
tank-rust/src/ffi/info/player.rs
Normal file
@ -0,0 +1,14 @@
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
|
||||
pub struct Player {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub speed: i32,
|
||||
pub score: i32,
|
||||
pub power: i32,
|
||||
pub oil: f32,
|
||||
pub lives: i32,
|
||||
pub angle: i32,
|
||||
pub gun_angle: i32,
|
||||
pub cooldown: i32,
|
||||
}
|
7
tank-rust/src/ffi/info/station.rs
Normal file
7
tank-rust/src/ffi/info/station.rs
Normal file
@ -0,0 +1,7 @@
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize, Default)]
|
||||
pub struct Station {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub power: i32,
|
||||
}
|
7
tank-rust/src/ffi/info/wall.rs
Normal file
7
tank-rust/src/ffi/info/wall.rs
Normal file
@ -0,0 +1,7 @@
|
||||
#[repr(C)]
|
||||
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
|
||||
pub struct Wall {
|
||||
pub x: i32,
|
||||
pub y: i32,
|
||||
pub lives: i32,
|
||||
}
|
7
tank-rust/src/ffi/mod.rs
Normal file
7
tank-rust/src/ffi/mod.rs
Normal file
@ -0,0 +1,7 @@
|
||||
mod action;
|
||||
mod info;
|
||||
|
||||
pub mod prelude {
|
||||
pub use super::action::*;
|
||||
pub use super::info::*;
|
||||
}
|
77
tank-rust/src/fit.rs
Normal file
77
tank-rust/src/fit.rs
Normal file
@ -0,0 +1,77 @@
|
||||
use burn::data::dataset::Dataset;
|
||||
|
||||
use crate::dqn::prelude::TankItem;
|
||||
use crate::ffi::prelude::*;
|
||||
use rand::Rng;
|
||||
|
||||
// fn random_action() -> Action {
|
||||
// let mut rng = rand::thread_rng();
|
||||
// match rng.gen_range(0..2) {
|
||||
// 0 => Action::AimLeft,
|
||||
// 1 => Action::Forward,
|
||||
// _ => unreachable!(),
|
||||
// }
|
||||
// }
|
||||
|
||||
// fn random_item() -> TankItem {
|
||||
// let mut previous_info=Info::default();
|
||||
// TankItem {
|
||||
// previous_state: todo!(),
|
||||
// new_state: todo!(),
|
||||
// action: todo!(),
|
||||
// reward: todo!(),
|
||||
// }
|
||||
// }
|
||||
|
||||
pub struct FitDataset;
|
||||
|
||||
impl FitDataset {
|
||||
/// Get closer to the power station
|
||||
fn close_power_station() -> TankItem {
|
||||
let mut power_stations = Station::default();
|
||||
|
||||
let mut previous_info = Info::default();
|
||||
let mut new_info = Info::default();
|
||||
let mut rng = rand::thread_rng();
|
||||
previous_info.player.power = rng.gen_range(0..2);
|
||||
new_info.player.power = previous_info.player.power;
|
||||
previous_info.player.angle = rng.gen_range(0..360);
|
||||
new_info.player.angle = previous_info.player.angle;
|
||||
|
||||
TankItem {
|
||||
previous_state: todo!(),
|
||||
new_state: todo!(),
|
||||
action: Action::Forward,
|
||||
reward: todo!(),
|
||||
}
|
||||
}
|
||||
/// Flee from power station if power is high
|
||||
fn flee_power_station() -> TankItem {
|
||||
let mut previous_info = Info::default();
|
||||
TankItem {
|
||||
previous_state: todo!(),
|
||||
new_state: todo!(),
|
||||
action: Action::Backward,
|
||||
reward: todo!(),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Dataset<TankItem> for FitDataset {
|
||||
fn get(&self, _: usize) -> Option<TankItem> {
|
||||
let previous_state = todo!();
|
||||
let new_state = todo!();
|
||||
let action = Action::AimLeft;
|
||||
let reward = 0.0;
|
||||
Some(TankItem {
|
||||
previous_state,
|
||||
new_state,
|
||||
action,
|
||||
reward,
|
||||
})
|
||||
}
|
||||
|
||||
fn len(&self) -> usize {
|
||||
1
|
||||
}
|
||||
}
|
44
tank-rust/src/lib.rs
Normal file
44
tank-rust/src/lib.rs
Normal file
@ -0,0 +1,44 @@
|
||||
mod dqn;
|
||||
mod ffi;
|
||||
use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
|
||||
|
||||
use burn::backend::{wgpu::AutoGraphicsApi, Wgpu};
|
||||
use dqn::prelude::*;
|
||||
use ffi::prelude::*;
|
||||
|
||||
static ARTIFACT_DIR: &str = "../output";
|
||||
|
||||
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> {
|
||||
let model_path =
|
||||
OsStr::from_bytes(unsafe { std::slice::from_raw_parts(model_path, len as usize) })
|
||||
.to_str()
|
||||
.unwrap();
|
||||
let app = DQNApp::new(model_path);
|
||||
|
||||
Box::into_raw(Box::new(app))
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn tick(app: *mut DQNApp, raw: *mut RawInfo) -> i32 {
|
||||
let app = unsafe { &mut *app };
|
||||
let info: Info<'static> = unsafe { Info::from_raw(raw) };
|
||||
cfg_if::cfg_if! {
|
||||
if #[cfg(feature = "train")]{
|
||||
app.collect_data(&info)as i32
|
||||
}else{
|
||||
app.predict_action(&info)as i32
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[no_mangle]
|
||||
pub extern "C" fn flush(app: *mut DQNApp) {
|
||||
let app = unsafe { &mut *app };
|
||||
#[cfg(feature = "train")]
|
||||
{
|
||||
app.flush();
|
||||
}
|
||||
}
|
14
tank-rust/src/main.rs
Normal file
14
tank-rust/src/main.rs
Normal file
@ -0,0 +1,14 @@
|
||||
mod dqn;
|
||||
mod ffi;
|
||||
use burn::backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu};
|
||||
|
||||
static ARTIFACT_DIR: &str = "../output";
|
||||
|
||||
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
|
||||
type AutodiffBackend = Autodiff<Backend>;
|
||||
|
||||
pub fn main() {
|
||||
let device = burn::backend::wgpu::WgpuDevice::default();
|
||||
|
||||
dqn::prelude::train::<AutodiffBackend>(device);
|
||||
}
|
Reference in New Issue
Block a user