This commit is contained in:
Eason
2024-06-12 16:09:12 +08:00
commit d8b781011b
156 changed files with 26489 additions and 0 deletions

View File

@ -0,0 +1,90 @@
use std::{
fs::{File, OpenOptions},
io::{BufWriter, Write},
};
use super::dataset::TankItem;
use crate::{ffi::prelude::*, Backend};
use burn::{backend::wgpu::WgpuDevice, module::Module, record::NoStdTrainingRecorder};
use rand::{thread_rng, Rng};
use super::model::{DQNModel, DQNModelConfig};
/// Fraction of `collect_data` calls that pick a random action instead of the model's prediction.
const EXPLORE_RATE: f32 = 0.4;
/// DQN agent state: the loaded model, the device it runs on and, when the
/// `train` feature is enabled, the writer used to log transitions.
pub struct App<'a> {
/// Network used to score actions.
model: DQNModel<Backend>,
/// Device on which input tensors are created.
device: WgpuDevice,
/// State and action recorded by the previous `collect_data` call, if any.
// NOTE(review): `Info<'a>` only borrows its slices; when they come from
// `Info::from_raw` the memory belongs to the foreign caller. Confirm the caller
// keeps those arrays alive and unchanged until the next tick.
last_state_action: Option<(Info<'a>, Action)>,
/// Buffered writer receiving bincode-serialized `TankItem` records.
#[cfg(feature = "train")]
outlet: BufWriter<File>,
}
impl<'a> App<'a> {
    /// Builds the agent from the artifacts stored under `model_path`.
    ///
    /// The weights are read from `{model_path}/model`. With the `train` feature
    /// enabled, `{model_path}/dataset` is additionally opened for writing and
    /// created when it does not exist yet.
    ///
    /// # Panics
    /// Panics when the model record cannot be loaded or the dataset file cannot
    /// be opened.
    pub fn new(model_path: &str) -> Self {
        let device = burn::backend::wgpu::WgpuDevice::default();
        let recorder = NoStdTrainingRecorder::new();
        let model = DQNModelConfig::new()
            .init(&device)
            .load_file(format!("{model_path}/model"), &recorder, &device)
            .unwrap();

        // NOTE(review): the file is opened without `append` or `truncate`, so
        // writes start at offset 0 of an existing file — confirm this is intended.
        #[cfg(feature = "train")]
        let outlet = {
            let file = OpenOptions::new()
                .write(true)
                .create(true)
                .open(format!("{model_path}/dataset"))
                .unwrap();
            BufWriter::new(file)
        };

        Self {
            model,
            device,
            last_state_action: None,
            #[cfg(feature = "train")]
            outlet,
        }
    }

    /// Flushes buffered transition records to the dataset file.
    ///
    /// # Panics
    /// Panics when the underlying write fails.
    #[cfg(feature = "train")]
    pub fn flush(&mut self) {
        self.outlet.flush().unwrap();
    }

    /// Logs the transition that ended in `state` and chooses the next action.
    ///
    /// When a previous state/action pair is pending, its reward is computed
    /// against `state` and the resulting `TankItem` is appended to the dataset
    /// writer. The next action is random with probability `EXPLORE_RATE`,
    /// otherwise it is the model's prediction. The chosen pair is remembered for
    /// the next call.
    ///
    /// # Panics
    /// Panics when serializing the transition fails.
    #[cfg(feature = "train")]
    pub fn collect_data(&mut self, state: &Info<'static>) -> Action {
        if let Some((previous_state, action)) = self.last_state_action.take() {
            // Field order mirrors evaluation order: reward first, then features.
            let transition = TankItem {
                reward: previous_state.get_reward(state, action),
                previous_state: previous_state.into_feature(),
                new_state: state.into_feature(),
                action,
            };
            bincode::serialize_into(&mut self.outlet, &transition).unwrap();
        }

        let mut rng = thread_rng();
        let explore = rng.gen_ratio((4096.0 * EXPLORE_RATE) as u32, 4096);
        let action = if explore {
            // The draw covers 0..15 while only 0..=7 are listed, so `Forward`
            // (and, through 7, `TurnRight`) are over-represented.
            // NOTE(review): presumably a deliberate bias — confirm.
            match rng.gen_range(0..15 as i32) {
                0 => Action::Forward,
                1 => Action::Backward,
                2 => Action::TurnRight,
                3 => Action::TurnLeft,
                4 => Action::AimRight,
                5 => Action::AimLeft,
                6 => Action::Shoot,
                7 => Action::TurnRight,
                _ => Action::Forward,
            }
        } else {
            self.predict_action(state)
        };

        self.last_state_action = Some((state.clone(), action));
        action
    }

    /// Returns the action whose model output is largest for `state`.
    ///
    /// # Panics
    /// Panics when the arg-max index does not convert into an `Action`.
    pub fn predict_action(&self, state: &Info) -> Action {
        // A leading batch dimension is added: [input_size] -> [1, input_size].
        let batch = state.into_feature_tensor(&self.device).unsqueeze();
        let scores = self.model.forward(batch);
        let best = scores.argmax(1).into_scalar();
        best.try_into().unwrap()
    }
}

View File

@ -0,0 +1,127 @@
use std::{
env::{self},
fs::File,
io::BufReader,
};
use crate::ffi::prelude::*;
use crate::ARTIFACT_DIR;
use burn::{
data::{dataloader::batcher::Batcher, dataset::Dataset},
prelude::*,
};
use super::feature::FEATRUE_SPACE;
/// One recorded transition, as serialized to and read back from the dataset file.
#[derive(Clone, Debug, serde::Serialize, serde::Deserialize)]
pub struct TankItem {
/// Feature vector of the state the action was taken in.
pub previous_state: [f32; FEATRUE_SPACE],
/// Feature vector of the state observed afterwards.
pub new_state: [f32; FEATRUE_SPACE],
/// Action taken in `previous_state`.
pub action: Action,
/// Reward assigned to the transition.
pub reward: f32,
}
/// In-memory collection of recorded transitions.
pub struct TankDataset {
/// Items in the order they were loaded or added.
pub dataset: Vec<TankItem>,
}
impl Dataset<TankItem> for TankDataset {
fn get(&self, index: usize) -> Option<TankItem> {
self.dataset.get(index).cloned()
}
fn len(&self) -> usize {
self.dataset.len()
}
}
impl TankDataset {
    /// Creates an empty dataset.
    pub fn new() -> Self {
        Self {
            dataset: Vec::new(),
        }
    }

    /// Loads every record from `{dir}/dataset`, where `dir` is the `MODEL_PATH`
    /// environment variable or `ARTIFACT_DIR` when it is unset.
    ///
    /// A missing or unreadable file yields an empty dataset.
    pub fn load() -> Self {
        let dataset_path = env::var("MODEL_PATH").unwrap_or_else(|_| ARTIFACT_DIR.to_string());
        println!("Loading dataset from: {}", dataset_path);

        let mut dataset = Vec::new();
        if let Ok(file) = File::open(format!("{dataset_path}/dataset")) {
            let mut reader = BufReader::new(file);
            // Records are stored back to back without a count header, so reading
            // stops at the first decode failure (normally end of file).
            while let Ok(item) = bincode::deserialize_from::<_, TankItem>(&mut reader) {
                dataset.push(item);
            }
        }
        TankDataset { dataset }
    }

    /// Appends one item.
    pub fn add(&mut self, item: TankItem) {
        self.dataset.push(item);
    }

    /// Splits the dataset into `(first, rest)`, where `first` holds roughly
    /// `ratio` of the items.
    ///
    /// The cut index is `round(len * ratio)` clamped to `0..=len`, so ratios
    /// outside `[0, 1]` (or NaN) produce an empty half instead of a panic.
    pub fn split(self, ratio: f32) -> (Self, Self) {
        let mut head = self.dataset;
        // The float-to-usize cast saturates (negative/NaN -> 0); `min` guards the
        // upper side, which `split_at` used to reject with a panic.
        let cut = ((head.len() as f32 * ratio).round() as usize).min(head.len());
        // The Vec is owned, so the tail can be moved out without cloning items.
        let tail = head.split_off(cut);
        (TankDataset { dataset: head }, TankDataset { dataset: tail })
    }
}
/// Turns a list of `TankItem`s into one `TankBatch` of tensors.
#[derive(Clone, Debug)]
pub struct TankBatcher<B: Backend> {
/// Device on which the batch tensors are created.
device: B::Device,
}
/// Tensor form of a batch of transitions.
///
/// NOTE(review): the state field names are inverted relative to `TankItem`:
/// `TankBatcher::batch` fills `new_state` from `TankItem::previous_state` and
/// `old_state` from `TankItem::new_state`. Renaming requires changing the
/// batcher and its consumers together.
#[derive(Clone, Debug)]
pub struct TankBatch<B: Backend> {
/// One row per item, built from `TankItem::previous_state`.
pub new_state: Tensor<B, 2>,
/// One row per item, built from `TankItem::new_state`.
pub old_state: Tensor<B, 2>,
/// Action of each item as its integer discriminant.
pub action: Tensor<B, 1, Int>,
/// Reward of each item.
pub reward: Tensor<B, 1>,
}
impl<B: Backend> TankBatcher<B> {
pub fn new(device: B::Device) -> Self {
Self { device }
}
}
impl<B: Backend> Batcher<TankItem, TankBatch<B>> for TankBatcher<B> {
fn batch(&self, items: Vec<TankItem>) -> TankBatch<B> {
let mut new_state: Vec<Tensor<B, 2>> = Vec::new();
let mut old_state: Vec<Tensor<B, 2>> = Vec::new();
for item in items.iter() {
let new_state_tensor = Tensor::<B, 1>::from_floats(item.previous_state, &self.device);
let old_state_tensor = Tensor::<B, 1>::from_floats(item.new_state, &self.device);
new_state.push(new_state_tensor.unsqueeze());
old_state.push(old_state_tensor.unsqueeze());
}
let new_state = Tensor::cat(new_state, 0);
let old_state = Tensor::cat(old_state, 0);
let reward = items
.iter()
.map(|item| Tensor::<B, 1, Float>::from_floats([item.reward], &self.device))
.collect();
let reward = Tensor::cat(reward, 0);
let response = items
.iter()
.map(|item| Tensor::<B, 1, Int>::from_ints([item.action as i32], &self.device))
.collect();
let response = Tensor::cat(response, 0);
TankBatch {
new_state,
old_state,
reward,
action: response,
}
}
}

View File

@ -0,0 +1,192 @@
//! Feature extraction and reward calculation for DQN
use std::f32::consts::PI;
use burn::tensor::{backend::Backend, Tensor};
use crate::ffi::prelude::*;
/// Length of the feature vector produced by `Info::into_feature`.
pub const FEATRUE_SPACE: usize = 10;
/// Number of action indices the model can output.
pub const ACTION_SPACE: usize = 7;
/// Position of a point relative to the player, in polar form.
///
/// Ordering and equality consider `distance` only, so `Iterator::min` yields the
/// nearest point and `==` agrees with `cmp`, as the `Ord` contract requires.
#[derive(Default)]
struct Polar {
    /// Bearing of the point, in radians.
    angle: f32,
    /// Distance measure used for ranking (callers in this file store the
    /// squared Euclidean distance).
    distance: f32,
}

impl Polar {
    /// Returns a copy whose distance is limited to `0.0..=1e6`; the angle is kept.
    pub fn clip(&self) -> Self {
        Polar {
            angle: self.angle,
            distance: self.distance.min(1e6).max(0.0),
        }
    }
}

impl PartialEq for Polar {
    fn eq(&self, other: &Self) -> bool {
        self.cmp(other) == std::cmp::Ordering::Equal
    }
}

impl Eq for Polar {}

impl Ord for Polar {
    /// Total order on `distance`; unlike `partial_cmp().unwrap()` this cannot
    /// panic when a distance is NaN.
    fn cmp(&self, other: &Self) -> std::cmp::Ordering {
        self.distance.total_cmp(&other.distance)
    }
}

impl PartialOrd for Polar {
    fn partial_cmp(&self, other: &Self) -> Option<std::cmp::Ordering> {
        Some(self.cmp(other))
    }
}
/// Wraps `angle` (radians) into the half-open range `[-PI, PI)`.
///
/// Values already inside the range are returned unchanged. Other finite values
/// are wrapped in constant time; the previous add/subtract loop needed
/// `O(|angle|)` iterations and never terminated for infinite input. NaN and
/// infinite inputs yield NaN.
fn normalize_angle(angle: f32) -> f32 {
    const TAU: f32 = 2.0 * PI;
    if (-PI..PI).contains(&angle) {
        return angle;
    }
    let wrapped = (angle + PI).rem_euclid(TAU) - PI;
    // `rem_euclid` may round up to the modulus itself; fold that back so the
    // upper bound stays exclusive.
    if wrapped >= PI {
        wrapped - TAU
    } else {
        wrapped
    }
}
impl Player {
    /// Grid position of this player as `(x, y)`.
    fn to_pos(&self) -> (i32, i32) {
        (self.x, self.y)
    }

    /// Polar position of `(x, y)` relative to this player.
    ///
    /// `distance` is the squared Euclidean distance (no square root is taken);
    /// it is only used for ranking and as a monotonic feature input.
    fn center(&self, x: i32, y: i32) -> Polar {
        // Widen before subtracting and squaring: in `i32` both steps overflow
        // for far-apart coordinates. Results are unchanged whenever the old
        // arithmetic did not overflow.
        let dx = f64::from(x) - f64::from(self.x);
        let dy = f64::from(y) - f64::from(self.y);
        Polar {
            angle: (dy as f32).atan2(dx as f32),
            distance: (dx * dx + dy * dy) as f32,
        }
    }

    /// Polar position of the nearest of `others`.
    ///
    /// Returns `Polar::default()` (angle 0, distance 0) when `others` is empty.
    fn closest(&self, others: impl Iterator<Item = (i32, i32)>) -> Polar {
        others
            .map(|(x, y)| self.center(x, y))
            .min()
            .unwrap_or_default()
    }

    /// Body angle in radians, computed from `180 - angle` read as degrees.
    fn get_angle(&self) -> f32 {
        (180 - self.angle) as f32 / 360.0 * 2.0 * PI
    }

    /// Gun angle in radians, with `gun_angle` read as degrees.
    fn get_gun_angle(&self) -> f32 {
        self.gun_angle as f32 / 360.0 * 2.0 * PI
    }
}
/// Kind of object the player is currently steering towards.
#[derive(Debug)]
enum Target {
/// Nearest entry of `Info::oil_stations`.
Oil,
/// Nearest entry of `Info::bullet_stations`.
Bullet,
/// Nearest entry of `Info::enemies`.
Enemy,
}
impl Target {
    /// Polar position, relative to `info.player`, of the nearest object of this
    /// target kind.
    fn get_pos(&self, info: &Info) -> Polar {
        let player = &info.player;
        match self {
            Target::Oil => player.closest(info.oil_stations.iter().map(Station::to_pos)),
            Target::Bullet => player.closest(info.bullet_stations.iter().map(Station::to_pos)),
            Target::Enemy => player.closest(info.enemies.iter().map(Player::to_pos)),
        }
    }

    /// Reports whether the target counts as reached between `last` and `current`.
    ///
    /// `Oil` and `Bullet` compare the player's `oil` / `power` across the two
    /// snapshots; `Enemy` is never reported as reached.
    // NOTE(review): both comparisons are true when the value *dropped* from
    // `last` to `current` — confirm this is the intended direction.
    fn reach(&self, last: &Info, current: &Info) -> bool {
        let (before, after) = (&last.player, &current.player);
        match self {
            Target::Oil => before.oil > after.oil,
            Target::Bullet => before.power > after.power,
            Target::Enemy => false,
        }
    }
}
impl Station {
fn to_pos(&self) -> (i32, i32) {
(self.x as i32, self.y as i32)
}
}
impl Wall {
fn to_pos(&self) -> (i32, i32) {
(self.x, self.y)
}
}
impl<'a> Info<'a> {
    /// Encodes this snapshot as the fixed-size feature vector fed to the model.
    ///
    /// Layout by index:
    /// 0. `tanh` of the target bearing relative to `get_angle()`
    /// 1. the same bearing shifted by `PI`
    /// 2. `tanh` of the nearest bullet's bearing relative to `get_angle()`
    /// 3. `log2(target distance + 1)` (target distance clipped)
    /// 4. `tanh(wall distance - target distance)`
    /// 5. `log2(bullet distance + 1)`
    /// 6. `tanh` of the nearest enemy's bearing relative to `get_gun_angle()`
    /// 7. `tanh` of the nearest wall's bearing relative to `get_gun_angle()`
    /// 8. `tanh(oil - 40)`
    /// 9. `tanh(power - 7)`
    pub fn into_feature(&self) -> [f32; FEATRUE_SPACE] {
        let player = &self.player;
        let enemy = player.closest(self.enemies.iter().map(Player::to_pos));
        let wall = player.closest(self.walls.iter().map(|w| (w.x, w.y)));
        let bullet = player.closest(self.bullets.iter().map(|b| (b.x, b.y)));
        let target = self.get_target().get_pos(self).clip();
        let body = player.get_angle();
        let gun = player.get_gun_angle();

        [
            normalize_angle(target.angle - body).tanh(),
            normalize_angle(target.angle - body + PI).tanh(),
            normalize_angle(bullet.angle - body).tanh(),
            (target.distance + 1.0).log2(),
            (wall.distance - target.distance).tanh(),
            (bullet.distance + 1.0).log2(),
            normalize_angle(enemy.angle - gun).tanh(),
            normalize_angle(wall.angle - gun).tanh(),
            (player.oil - 40.0).tanh(),
            (player.power as f32 - 7.0).tanh(),
        ]
    }

    /// Same as [`Info::into_feature`], wrapped in a rank-1 tensor on `device`.
    pub fn into_feature_tensor<B: Backend>(&self, device: &B::Device) -> Tensor<B, 1> {
        Tensor::from_floats(self.into_feature(), device)
    }

    /// Chooses what the player should pursue: oil when `oil < 40`, otherwise an
    /// enemy when `power > 7`, otherwise a bullet station.
    fn get_target(&self) -> Target {
        if self.player.oil < 40.0 {
            return Target::Oil;
        }
        if self.player.power > 7 {
            Target::Enemy
        } else {
            Target::Bullet
        }
    }

    /// Reward for having taken `action` in `self` and arrived in `next`.
    ///
    /// Sum of a constant step cost, an action term (penalty for a move that did
    /// not change position, bonus/penalty for shooting depending on `power`), a
    /// target term (bonus when reached, otherwise based on whether the target
    /// got nearer) and a score term.
    pub fn get_reward(&self, next: &Self, action: Action) -> f32 {
        let stuck = (self.player.x, self.player.y) == (next.player.x, next.player.y);

        let step_cost: f32 = -2.3;
        let action_term = match action {
            Action::Forward | Action::Backward if stuck => -8.0,
            Action::Shoot => {
                if next.player.power > 7 {
                    2.0
                } else {
                    -2.0
                }
            }
            _ => 0.0,
        };

        let target = self.get_target();
        let target_term = if target.reach(self, next) {
            15.0
        } else {
            let before = target.get_pos(self);
            let after = target.get_pos(next);
            match before.cmp(&after) {
                std::cmp::Ordering::Less => -5.0,
                std::cmp::Ordering::Equal => 0.0,
                std::cmp::Ordering::Greater => 5.8,
            }
        };

        let gained = next.player.score - self.player.score;
        let score_term = if gained > 2 {
            20.0
        } else if gained > 0 {
            // Too high: the tank may ignore the power station.
            10.0
        } else {
            -1.0
        };

        // Terms are added in the same order as before to keep float results identical.
        ((step_cost + action_term) + target_term) + score_term
    }
}

12
tank-rust/src/dqn/mod.rs Normal file
View File

@ -0,0 +1,12 @@
mod collect;
mod dataset;
mod feature;
mod model;
mod training;
/// Re-exports of the DQN items used by the crate roots (`lib.rs`, `main.rs`).
pub mod prelude {
pub use super::collect::App as DQNApp;
pub use super::dataset::{TankDataset, TankItem};
pub use super::feature::{ACTION_SPACE, FEATRUE_SPACE};
pub use super::training::run as train;
}

122
tank-rust/src/dqn/model.rs Normal file
View File

@ -0,0 +1,122 @@
use super::{
dataset::TankBatch,
feature::{ACTION_SPACE, FEATRUE_SPACE},
};
use burn::{
nn::{
loss::{HuberLoss, HuberLossConfig, Reduction::Mean},
Linear, LinearConfig, Relu,
},
prelude::*,
tensor::backend::AutodiffBackend,
train::{
metric::{Adaptor, LossInput},
TrainOutput, TrainStep, ValidStep,
},
};
/// Result of one `DQNModel::forward_step` call.
pub struct DQNOutput<B: Backend> {
/// Model output for `TankBatch::new_state`, one row per item.
estimated_reward: Tensor<B, 2>,
/// Model output for `TankBatch::old_state`, one row per item.
expected_reward: Tensor<B, 2>,
/// Huber loss reduced with `Mean`.
loss: Tensor<B, 1>,
}
impl<B: Backend> Adaptor<LossInput<B>> for DQNOutput<B> {
    /// Exposes the step's loss to the loss metric.
    fn adapt(&self) -> LossInput<B> {
        let loss = self.loss.clone();
        LossInput::new(loss)
    }
}
/// Fully connected Q-network; built by `DQNModelConfig::init`.
#[derive(Module, Debug)]
pub struct DQNModel<B: Backend> {
/// `FEATRUE_SPACE` -> `hidden_layer_1_size`.
input_layer: Linear<B>,
/// `hidden_layer_1_size` -> `hidden_layer_2_size`.
hidden_layer_1: Linear<B>,
/// `hidden_layer_2_size` -> `hidden_layer_3_size`.
hidden_layer_2: Linear<B>,
/// `hidden_layer_3_size` -> `ACTION_SPACE`.
output_layer: Linear<B>,
/// Non-linearity applied between layers.
activation: Relu,
/// Loss used by `forward_step`.
loss_function: HuberLoss<B>,
/// Factor the next-state value is multiplied by in `forward_step`.
gamma: f32,
}
/// Hyper-parameters from which a `DQNModel` is built.
#[derive(Config)]
pub struct DQNModelConfig {
/// Output width of the input layer.
#[config(default = 64)]
pub hidden_layer_1_size: usize,
/// Output width of the first hidden layer.
#[config(default = 96)]
pub hidden_layer_2_size: usize,
/// Output width of the second hidden layer (input width of the output layer).
#[config(default = 64)]
pub hidden_layer_3_size: usize,
/// Value copied into `DQNModel::gamma`.
#[config(default = 0.99)]
pub gamma: f32,
}
impl DQNModelConfig {
pub fn init<B: Backend>(&self, device: &B::Device) -> DQNModel<B> {
let input_layer = LinearConfig::new(FEATRUE_SPACE, self.hidden_layer_1_size)
.with_bias(true)
.init(device);
let hidden_layer_1 = LinearConfig::new(self.hidden_layer_1_size, self.hidden_layer_2_size)
.with_bias(true)
.init(device);
let hidden_layer_2 = LinearConfig::new(self.hidden_layer_2_size, self.hidden_layer_3_size)
.with_bias(true)
.init(device);
let output_layer = LinearConfig::new(self.hidden_layer_3_size, ACTION_SPACE)
.with_bias(true)
.init(device);
DQNModel {
input_layer,
hidden_layer_1,
hidden_layer_2,
output_layer,
loss_function: HuberLossConfig::new(1.34).init(device),
activation: Relu::new(),
gamma: self.gamma,
}
}
}
impl<B: Backend> DQNModel<B> {
    /// Scores every action for each row of `input`.
    ///
    /// `input` has shape `[batch, FEATRUE_SPACE]`; the result has shape
    /// `[batch, ACTION_SPACE]`.
    ///
    /// The input passes through all four linear layers with the activation in
    /// between. The hidden layers used to be skipped, which only worked while
    /// `hidden_layer_1_size == hidden_layer_3_size`; weights trained through the
    /// old path must be retrained.
    pub fn forward(&self, input: Tensor<B, 2>) -> Tensor<B, 2> {
        let x = input.detach();
        let x = self.activation.forward(self.input_layer.forward(x));
        let x = self.activation.forward(self.hidden_layer_1.forward(x));
        let x = self.activation.forward(self.hidden_layer_2.forward(x));
        self.output_layer.forward(x)
    }

    /// Computes the temporal-difference loss for one batch.
    ///
    /// The prediction is the score of the action actually taken; the target is
    /// `reward + gamma * max(score of the following state)`.
    ///
    /// Note the `TankBatch` naming: `new_state` holds the state the action was
    /// taken in and `old_state` the state that followed.
    pub fn forward_step(&self, item: TankBatch<B>) -> DQNOutput<B> {
        let estimated_reward = self.forward(item.new_state);
        // [batch] -> [batch, 1] so it can index along the action dimension.
        let taken = item.action.unsqueeze_dim(1);
        let predicted = estimated_reward.clone().gather(1, taken);

        // No terminal-state mask: the dataset does not record episode ends.
        let expected_reward = self.forward(item.old_state);
        // `max_dim(1)` keeps the reduced dimension, giving [batch, 1]. The reward
        // must be shaped [batch, 1] as well: a plain `unsqueeze()` adds a
        // *leading* dimension ([1, batch]) and the sum would broadcast every
        // reward against every sample.
        let target = expected_reward
            .clone()
            .max_dim(1)
            .mul_scalar(self.gamma)
            .add(item.reward.unsqueeze_dim(1))
            // The bootstrapped target is treated as a constant.
            .detach();

        let loss = self.loss_function.forward(predicted, target, Mean);
        DQNOutput {
            estimated_reward,
            expected_reward,
            loss,
        }
    }
}
impl<B: AutodiffBackend> TrainStep<TankBatch<B>, DQNOutput<B>> for DQNModel<B> {
    /// Runs a forward step and back-propagates its loss.
    fn step(&self, item: TankBatch<B>) -> TrainOutput<DQNOutput<B>> {
        let output = self.forward_step(item);
        let gradients = output.loss.backward();
        TrainOutput::new(self, gradients, output)
    }
}
impl<B: Backend> ValidStep<TankBatch<B>, DQNOutput<B>> for DQNModel<B> {
    /// Runs a forward step without computing gradients.
    fn step(&self, item: TankBatch<B>) -> DQNOutput<B> {
        let output = self.forward_step(item);
        output
    }
}

View File

@ -0,0 +1,103 @@
use std::{env, fs};
use super::dataset::{TankBatcher, TankDataset};
use super::model::DQNModelConfig;
use crate::ARTIFACT_DIR;
use burn::optim::AdamConfig;
use burn::train::metric::store::{Aggregate, Direction, Split};
use burn::train::{MetricEarlyStoppingStrategy, StoppingCondition};
use burn::{
data::{dataloader::DataLoaderBuilder, dataset::Dataset},
prelude::*,
record::{CompactRecorder, NoStdTrainingRecorder},
tensor::backend::AutodiffBackend,
train::{metric::LossMetric, LearnerBuilder},
};
/// Settings of one training run; saved next to the model by `run`.
#[derive(Config)]
pub struct ExpConfig {
/// Maximum number of epochs.
#[config(default = 16)]
pub num_epochs: usize,
/// Worker count given to both data loaders.
#[config(default = 6)]
pub num_workers: usize,
/// Seed for the backend and for data-loader shuffling.
#[config(default = 47)]
pub seed: u64,
/// Optimizer configuration.
pub optimizer: AdamConfig,
/// Learning rate passed to the learner.
#[config(default = 1.5e-3)]
pub learn_rate: f64,
/// Batch size of both data loaders.
#[config(default = 4096)]
pub batch_size: usize,
}
/// Trains the DQN model on the recorded dataset and stores the result.
///
/// The artifact directory is `MODEL_PATH` when set, otherwise `ARTIFACT_DIR`.
/// Existing weights at `{dir}/model` are used as the starting point. After
/// training, the configuration is written to `{dir}/config.json` and the
/// weights to `{dir}/model`.
///
/// # Panics
/// Panics when existing weights cannot be loaded, or when the configuration or
/// the trained model cannot be saved.
pub fn run<B>(device: B::Device)
where
    B: AutodiffBackend,
{
    let model_path = env::var("MODEL_PATH").unwrap_or_else(|_| ARTIFACT_DIR.to_string());
    let weights_file = format!("{model_path}/model");
    let config = ExpConfig::new(AdamConfig::new());

    // Start from stored weights when present, otherwise from a fresh model.
    let fresh = DQNModelConfig::new().init::<B>(&device);
    let model = if fs::metadata(&weights_file).is_ok() {
        fresh
            .load_file(
                weights_file.clone(),
                &NoStdTrainingRecorder::new(),
                &device,
            )
            .unwrap()
    } else {
        fresh
    };

    B::seed(config.seed);

    // NOTE(review): a ratio of 1.0 puts every item into the training split, so
    // the validation split below is empty.
    let (train_dataset, test_dataset) = TankDataset::load().split(1.0);
    println!("Train Dataset Size: {}", train_dataset.len());
    println!("Test Dataset Size: {}", test_dataset.len());

    let dataloader_train = DataLoaderBuilder::new(TankBatcher::<B>::new(device.clone()))
        .batch_size(config.batch_size)
        .shuffle(config.seed)
        .num_workers(config.num_workers)
        .build(train_dataset);
    let dataloader_test =
        DataLoaderBuilder::new(TankBatcher::<B::InnerBackend>::new(device.clone()))
            .batch_size(config.batch_size)
            .shuffle(config.seed)
            .num_workers(config.num_workers)
            .build(test_dataset);

    // Stop once the mean training loss has not improved for two epochs.
    let early_stop = MetricEarlyStoppingStrategy::new::<LossMetric<B>>(
        Aggregate::Mean,
        Direction::Lowest,
        Split::Train,
        StoppingCondition::NoImprovementSince { n_epochs: 2 },
    );

    let learner = LearnerBuilder::new(&model_path)
        .metric_train_numeric(LossMetric::new())
        .metric_valid_numeric(LossMetric::new())
        .with_file_checkpointer(CompactRecorder::new())
        .early_stopping(early_stop)
        .devices(vec![device.clone()])
        .num_epochs(config.num_epochs)
        .summary()
        .build(model, config.optimizer.init(), config.learn_rate);

    let model_trained = learner.fit(dataloader_train, dataloader_test);

    let config_file = format!("{model_path}/config.json");
    config.save(config_file.as_str()).unwrap();
    model_trained
        .save_file(weights_file, &NoStdTrainingRecorder::new())
        .expect("Failed to save trained model");
}

View File

@ -0,0 +1,30 @@
/// Command returned to the game for one tick.
///
/// The discriminants are the integer codes handed across the FFI boundary
/// (`tick` returns `action as i32`) and accepted by `TryFrom<i32>`.
#[repr(i32)]
#[derive(Debug, Copy, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub enum Action {
Forward = 0,
Backward = 1,
TurnRight = 2,
TurnLeft = 3,
AimRight = 4,
AimLeft = 5,
Shoot = 6,
None = 7,
}
impl TryFrom<i32> for Action {
    type Error = &'static str;

    /// Maps an integer code `0..=7` to its variant.
    ///
    /// # Errors
    /// Returns `"Invalid action"` for any other value.
    fn try_from(value: i32) -> Result<Self, Self::Error> {
        // Indexed by discriminant.
        const BY_CODE: [Action; 8] = [
            Action::Forward,
            Action::Backward,
            Action::TurnRight,
            Action::TurnLeft,
            Action::AimRight,
            Action::AimLeft,
            Action::Shoot,
            Action::None,
        ];
        usize::try_from(value)
            .ok()
            .and_then(|code| BY_CODE.get(code).copied())
            .ok_or("Invalid action")
    }
}

View File

@ -0,0 +1,6 @@
/// Position of one bullet, laid out for the C side.
#[repr(C)]
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct Bullet {
pub x: i32,
pub y: i32,
}

View File

@ -0,0 +1,58 @@
mod bullet;
mod player;
mod station;
mod wall;
pub use bullet::*;
pub use player::*;
pub use station::*;
pub use wall::*;
/// Game state for one tick: an owned copy of the player plus borrowed views of
/// every other entity list.
#[derive(Debug, Clone, Default)]
pub struct Info<'a> {
/// The controlled player (owned copy).
pub player: Player,
pub teammates: &'a [Player],
pub enemies: &'a [Player],
pub bullets: &'a [Bullet],
pub bullet_stations: &'a [Station],
pub oil_stations: &'a [Station],
pub walls: &'a [Wall],
}
/// C-layout game state received over FFI: a pointer to the player followed by
/// one pointer/length pair per entity array. Converted by `Info::from_raw`.
#[repr(C)]
pub struct RawInfo {
player: *const Player,
teammates: *const Player,
teammates_len: u32,
enemies: *const Player,
enemies_len: u32,
bullets: *const Bullet,
bullet_len: u32,
bullet_stations: *const Station,
bullet_stations_len: u32,
oil_stations: *const Station,
oil_stations_len: u32,
walls: *const Wall,
walls_len: u32,
}
impl<'a> Info<'a> {
    /// Builds an `Info` view over the data referenced by `self_`.
    ///
    /// The player is copied; every array is borrowed. An array whose pointer is
    /// null or whose length is 0 becomes an empty slice.
    ///
    /// # Safety
    /// - `self_` and `(*self_).player` must be non-null, aligned and point to
    ///   initialised values.
    /// - Every non-null array pointer must reference at least its `*_len`
    ///   initialised elements.
    /// - The referenced arrays must stay valid and unmodified for all of `'a`.
    pub unsafe fn from_raw(self_: *const RawInfo) -> Self {
        let raw = &*self_;
        Info {
            player: (*raw.player).clone(),
            teammates: Self::raw_slice(raw.teammates, raw.teammates_len),
            enemies: Self::raw_slice(raw.enemies, raw.enemies_len),
            bullets: Self::raw_slice(raw.bullets, raw.bullet_len),
            bullet_stations: Self::raw_slice(raw.bullet_stations, raw.bullet_stations_len),
            oil_stations: Self::raw_slice(raw.oil_stations, raw.oil_stations_len),
            walls: Self::raw_slice(raw.walls, raw.walls_len),
        }
    }

    /// Borrows `len` elements starting at `ptr`, or an empty slice when there is
    /// nothing to borrow.
    ///
    /// `slice::from_raw_parts` requires a non-null pointer even for length 0,
    /// while C callers commonly pass NULL for an empty array.
    ///
    /// # Safety
    /// A non-null `ptr` must reference `len` initialised elements valid for `'a`.
    unsafe fn raw_slice<T>(ptr: *const T, len: u32) -> &'a [T] {
        if ptr.is_null() || len == 0 {
            &[]
        } else {
            std::slice::from_raw_parts(ptr, len as usize)
        }
    }
}

View File

@ -0,0 +1,14 @@
/// State of one tank, laid out for the C side.
#[repr(C)]
#[derive(Debug, Clone, serde::Deserialize, serde::Serialize, Default)]
pub struct Player {
pub x: i32,
pub y: i32,
pub speed: i32,
pub score: i32,
pub power: i32,
pub oil: f32,
pub lives: i32,
// Feature extraction converts `angle` and `gun_angle` to radians as degree values.
pub angle: i32,
pub gun_angle: i32,
pub cooldown: i32,
}

View File

@ -0,0 +1,7 @@
/// Position and `power` value of one station, laid out for the C side.
#[repr(C)]
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize, Default)]
pub struct Station {
pub x: i32,
pub y: i32,
pub power: i32,
}

View File

@ -0,0 +1,7 @@
/// Position and remaining `lives` of one wall, laid out for the C side.
#[repr(C)]
#[derive(Debug, Clone, PartialEq, Eq, serde::Deserialize, serde::Serialize)]
pub struct Wall {
pub x: i32,
pub y: i32,
pub lives: i32,
}

7
tank-rust/src/ffi/mod.rs Normal file
View File

@ -0,0 +1,7 @@
mod action;
mod info;
/// Re-exports every FFI type (`Action`, `Info`, `RawInfo` and the entity structs).
pub mod prelude {
pub use super::action::*;
pub use super::info::*;
}

77
tank-rust/src/fit.rs Normal file
View File

@ -0,0 +1,77 @@
use burn::data::dataset::Dataset;
use crate::dqn::prelude::TankItem;
use crate::ffi::prelude::*;
use rand::Rng;
// fn random_action() -> Action {
// let mut rng = rand::thread_rng();
// match rng.gen_range(0..2) {
// 0 => Action::AimLeft,
// 1 => Action::Forward,
// _ => unreachable!(),
// }
// }
// fn random_item() -> TankItem {
// let mut previous_info=Info::default();
// TankItem {
// previous_state: todo!(),
// new_state: todo!(),
// action: todo!(),
// reward: todo!(),
// }
// }
/// Placeholder dataset of hand-built transitions; its item constructors are not
/// implemented yet.
pub struct FitDataset;
// Unfinished: both constructors end in `todo!()` and panic when called.
impl FitDataset {
/// Get closer to the power station
fn close_power_station() -> TankItem {
// `power_stations` is declared but not used yet.
let mut power_stations = Station::default();
let mut previous_info = Info::default();
let mut new_info = Info::default();
let mut rng = rand::thread_rng();
// Both snapshots share the same randomly drawn power and angle.
previous_info.player.power = rng.gen_range(0..2);
new_info.player.power = previous_info.player.power;
previous_info.player.angle = rng.gen_range(0..360);
new_info.player.angle = previous_info.player.angle;
TankItem {
previous_state: todo!(),
new_state: todo!(),
action: Action::Forward,
reward: todo!(),
}
}
/// Flee from power station if power is high
fn flee_power_station() -> TankItem {
let mut previous_info = Info::default();
TankItem {
previous_state: todo!(),
new_state: todo!(),
action: Action::Backward,
reward: todo!(),
}
}
}
impl Dataset<TankItem> for FitDataset {
/// Unfinished: panics at the first `todo!()` regardless of the index.
fn get(&self, _: usize) -> Option<TankItem> {
let previous_state = todo!();
let new_state = todo!();
let action = Action::AimLeft;
let reward = 0.0;
Some(TankItem {
previous_state,
new_state,
action,
reward,
})
}
/// Always reports a single item.
fn len(&self) -> usize {
1
}
}

44
tank-rust/src/lib.rs Normal file
View File

@ -0,0 +1,44 @@
mod dqn;
mod ffi;
use std::{ffi::OsStr, os::unix::ffi::OsStrExt};
use burn::backend::{wgpu::AutoGraphicsApi, Wgpu};
use dqn::prelude::*;
use ffi::prelude::*;
/// Artifact directory used when `MODEL_PATH` is not set.
static ARTIFACT_DIR: &str = "../output";
/// Backend the exported agent runs on.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
/// Creates the agent from the artifacts in the directory named by the given
/// byte string and returns an owning pointer to it.
///
/// `model_path` points to `len` bytes of UTF-8 (no terminator needed). The
/// returned pointer is never freed by this library.
///
/// # Panics
/// Panics when `model_path` is null, `len` is negative, the bytes are not valid
/// UTF-8, or the agent cannot be constructed. The pointer and length checks
/// turn what used to be undefined behaviour into a defined failure.
#[no_mangle]
pub extern "C" fn init(model_path: *const u8, len: i32) -> *mut DQNApp<'static> {
    assert!(!model_path.is_null(), "init: model_path must not be null");
    let len = usize::try_from(len).expect("init: len must not be negative");

    // SAFETY: the pointer is non-null and the caller guarantees it references
    // `len` readable bytes for the duration of this call.
    let bytes = unsafe { std::slice::from_raw_parts(model_path, len) };
    let model_path = OsStr::from_bytes(bytes)
        .to_str()
        .expect("init: model_path must be valid UTF-8");

    Box::into_raw(Box::new(DQNApp::new(model_path)))
}
/// Processes one game tick and returns the chosen action's integer code.
///
/// With the `train` feature the transition is also logged via `collect_data`;
/// otherwise only the model's prediction is returned.
///
/// `app` must come from `init` and `raw` must satisfy the requirements of
/// `Info::from_raw`.
// NOTE(review): the view is typed `Info<'static>` although it borrows memory
// owned by the caller — confirm those arrays outlive every later use.
#[no_mangle]
pub extern "C" fn tick(app: *mut DQNApp, raw: *mut RawInfo) -> i32 {
    let app = unsafe { &mut *app };
    let info: Info<'static> = unsafe { Info::from_raw(raw) };

    #[cfg(feature = "train")]
    let action = app.collect_data(&info);
    #[cfg(not(feature = "train"))]
    let action = app.predict_action(&info);

    action as i32
}
/// Flushes buffered training data; does nothing without the `train` feature.
///
/// `app` must come from `init`.
#[no_mangle]
pub extern "C" fn flush(app: *mut DQNApp) {
    let app = unsafe { &mut *app };
    #[cfg(feature = "train")]
    app.flush();
}

14
tank-rust/src/main.rs Normal file
View File

@ -0,0 +1,14 @@
mod dqn;
mod ffi;
use burn::backend::{wgpu::AutoGraphicsApi, Autodiff, Wgpu};
/// Artifact directory used when `MODEL_PATH` is not set.
static ARTIFACT_DIR: &str = "../output";
/// Inference backend.
type Backend = Wgpu<AutoGraphicsApi, f32, i32>;
/// `Backend` wrapped for automatic differentiation, used for training.
type AutodiffBackend = Autodiff<Backend>;
/// Entry point of the training binary: trains on the default wgpu device.
pub fn main() {
    use burn::backend::wgpu::WgpuDevice;

    dqn::prelude::train::<AutodiffBackend>(WgpuDevice::default());
}