transfer to deploy server
This commit is contained in:
parent
96d458597e
commit
1a777943ff
|
@ -1,3 +1,7 @@
|
||||||
|
# model
|
||||||
|
/*.save
|
||||||
|
/*.bin
|
||||||
|
|
||||||
*.env
|
*.env
|
||||||
*.swp
|
*.swp
|
||||||
*.pyproj
|
*.pyproj
|
||||||
|
|
|
@ -0,0 +1,7 @@
|
||||||
|
exploration_rate = 1224
|
||||||
|
update_frequency = 60
|
||||||
|
batch_size = 48
|
||||||
|
replay_size = 300
|
||||||
|
learning_rate = 0.01
|
||||||
|
gamma = 0.97
|
||||||
|
train = false
|
|
@ -0,0 +1,9 @@
|
||||||
|
test:
|
||||||
|
python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3
|
||||||
|
build:
|
||||||
|
cd pyr && cargo build --release
|
||||||
|
train level:
|
||||||
|
run level:
|
||||||
|
python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3
|
||||||
|
clean:
|
||||||
|
rm -r model.bin
|
File diff suppressed because it is too large
Load Diff
|
@ -0,0 +1,24 @@
|
||||||
|
[package]
|
||||||
|
name = "pyr"
|
||||||
|
version = "0.1.0"
|
||||||
|
edition = "2021"
|
||||||
|
|
||||||
|
[lib]
|
||||||
|
name = "pyr"
|
||||||
|
crate-type = ["cdylib"]
|
||||||
|
|
||||||
|
[profile.release]
|
||||||
|
strip = true
|
||||||
|
opt-level = 2
|
||||||
|
# lto = true
|
||||||
|
|
||||||
|
[dependencies]
|
||||||
|
smol = "2.0.0"
|
||||||
|
log = "0.4.21"
|
||||||
|
simple_logger = "4.3.3"
|
||||||
|
lazy_static = "1.4.0"
|
||||||
|
candle-nn = "0.4.1"
|
||||||
|
candle-core = "0.4.1"
|
||||||
|
rand = "0.8.5"
|
||||||
|
toml = "0.8.12"
|
||||||
|
serde = {version = "1.0.198", features = ["derive"]}
|
|
@ -0,0 +1,22 @@
|
||||||
|
use crate::Direction;
|
||||||
|
|
||||||
|
#[derive(PartialEq, Eq, Hash, Clone)]
|
||||||
|
pub enum AIAction {
|
||||||
|
Up,
|
||||||
|
Down,
|
||||||
|
Left,
|
||||||
|
Right,
|
||||||
|
None,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<AIAction> for Direction {
|
||||||
|
fn from(value: AIAction) -> Self {
|
||||||
|
match value {
|
||||||
|
AIAction::Up => Direction::Up,
|
||||||
|
AIAction::Down => Direction::Down,
|
||||||
|
AIAction::Left => Direction::Left,
|
||||||
|
AIAction::Right => Direction::Right,
|
||||||
|
AIAction::None => Direction::None,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,202 @@
|
||||||
|
use std::collections::VecDeque;
|
||||||
|
use std::path::Path;
|
||||||
|
|
||||||
|
use rand::distributions::Uniform;
|
||||||
|
use rand::{thread_rng, Rng};
|
||||||
|
|
||||||
|
use candle_core::{DType, Device, Module, Tensor};
|
||||||
|
use candle_nn::{linear, seq, Activation, AdamW, Optimizer, Sequential, VarBuilder, VarMap};
|
||||||
|
|
||||||
|
use crate::CONFIG;
|
||||||
|
|
||||||
|
use super::state::OBSERVATION_SPACE;
|
||||||
|
use super::{action::AIAction, huber::huber_loss, state::AIState};
|
||||||
|
|
||||||
|
const DEVICE: Device = Device::Cpu;
|
||||||
|
|
||||||
|
const ACTION_SPACE: usize = 5;
|
||||||
|
|
||||||
|
pub struct AIAgent {
|
||||||
|
var_map: VarMap,
|
||||||
|
model: Sequential,
|
||||||
|
optimizer: AdamW,
|
||||||
|
memory: VecDeque<(Tensor, u32, Tensor, f64)>,
|
||||||
|
old_state: Option<AIState>,
|
||||||
|
step: usize,
|
||||||
|
accumulate_rewards: f64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl AIAgent {
|
||||||
|
pub async fn new() -> Self {
|
||||||
|
let mut var_map = VarMap::new();
|
||||||
|
if Path::new("model.bin").exists() {
|
||||||
|
var_map.load("model.bin").unwrap();
|
||||||
|
}
|
||||||
|
let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE);
|
||||||
|
let model = seq()
|
||||||
|
.add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap())
|
||||||
|
.add(Activation::LeakyRelu(0.01))
|
||||||
|
.add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
|
||||||
|
.add(Activation::LeakyRelu(0.01))
|
||||||
|
.add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
|
||||||
|
.add(Activation::LeakyRelu(0.01))
|
||||||
|
.add(linear(48, ACTION_SPACE, vb.pp("linear_out")).unwrap())
|
||||||
|
.add(Activation::LeakyRelu(0.01));
|
||||||
|
|
||||||
|
let optimizer = AdamW::new_lr(var_map.all_vars(), CONFIG.learning_rate).unwrap();
|
||||||
|
|
||||||
|
Self {
|
||||||
|
var_map,
|
||||||
|
model,
|
||||||
|
optimizer,
|
||||||
|
memory: VecDeque::new(),
|
||||||
|
old_state: None,
|
||||||
|
step: 0,
|
||||||
|
accumulate_rewards: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn get_reward(&self, new_state: &AIState) -> f64 {
|
||||||
|
let old_state = self.old_state.as_ref().unwrap();
|
||||||
|
let new_positive_distance = new_state
|
||||||
|
.get_postivie_food()
|
||||||
|
.map(|food| food.x + food.y)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
let old_positive_distance = old_state
|
||||||
|
.get_postivie_food()
|
||||||
|
.map(|food| food.x + food.y)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
let new_negative_distance = new_state
|
||||||
|
.get_negative_food()
|
||||||
|
.map(|food| food.x + food.y)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
let old_negative_distance = old_state
|
||||||
|
.get_negative_food()
|
||||||
|
.map(|food| food.x + food.y)
|
||||||
|
.unwrap_or(0.0);
|
||||||
|
|
||||||
|
return (old_positive_distance - new_positive_distance) as f64
|
||||||
|
+ (new_negative_distance - old_negative_distance) as f64
|
||||||
|
+ 100.0*(new_state.player.score - old_state.player.score) as f64;
|
||||||
|
}
|
||||||
|
pub fn tick(&mut self, state: AIState) -> AIAction {
|
||||||
|
self.step += 1;
|
||||||
|
if self.old_state.is_none() {
|
||||||
|
self.old_state = Some(state);
|
||||||
|
return AIAction::None;
|
||||||
|
}
|
||||||
|
let old_state = self.old_state.as_ref().unwrap();
|
||||||
|
|
||||||
|
let action: u32 = match thread_rng().gen_ratio(CONFIG.exploration_rate, 4096) {
|
||||||
|
true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)),
|
||||||
|
_ => self
|
||||||
|
.model
|
||||||
|
.forward(&old_state.into_tensor())
|
||||||
|
.unwrap()
|
||||||
|
.squeeze(0)
|
||||||
|
.unwrap()
|
||||||
|
.argmax(0)
|
||||||
|
.unwrap()
|
||||||
|
.to_scalar()
|
||||||
|
.unwrap(),
|
||||||
|
};
|
||||||
|
|
||||||
|
if CONFIG.train {
|
||||||
|
let reward = self.get_reward(&state);
|
||||||
|
self.accumulate_rewards += reward;
|
||||||
|
|
||||||
|
self.memory.push_front((
|
||||||
|
self.old_state
|
||||||
|
.as_ref()
|
||||||
|
.unwrap()
|
||||||
|
.into_tensor()
|
||||||
|
.squeeze(0)
|
||||||
|
.unwrap(),
|
||||||
|
action,
|
||||||
|
state.into_tensor().squeeze(0).unwrap(),
|
||||||
|
reward,
|
||||||
|
));
|
||||||
|
self.memory.truncate(CONFIG.replay_size);
|
||||||
|
if self.step % CONFIG.update_frequency == 0 && self.memory.len() > CONFIG.batch_size {
|
||||||
|
self.train();
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
self.old_state = Some(state);
|
||||||
|
|
||||||
|
match action {
|
||||||
|
0 => AIAction::None,
|
||||||
|
1 => AIAction::Up,
|
||||||
|
2 => AIAction::Left,
|
||||||
|
3 => AIAction::Right,
|
||||||
|
_ => AIAction::Down,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fn train(&mut self) {
|
||||||
|
// Sample randomly from the memory.
|
||||||
|
let batch = thread_rng()
|
||||||
|
.sample_iter(Uniform::from(0..self.memory.len()))
|
||||||
|
.take(CONFIG.batch_size)
|
||||||
|
.map(|i| self.memory.get(i).unwrap().clone())
|
||||||
|
.collect::<Vec<_>>();
|
||||||
|
|
||||||
|
// Group all the samples together into tensors with the appropriate shape.
|
||||||
|
let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect();
|
||||||
|
let states = Tensor::stack(&states, 0).unwrap();
|
||||||
|
|
||||||
|
let actions = batch.iter().map(|e| e.1);
|
||||||
|
let actions = Tensor::from_iter(actions, &DEVICE)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(1)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect();
|
||||||
|
let next_states = Tensor::stack(&next_states, 0).unwrap();
|
||||||
|
|
||||||
|
let rewards = batch.iter().map(|e| e.3 as f32);
|
||||||
|
let rewards = Tensor::from_iter(rewards, &DEVICE)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(1)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
let non_final_mask = batch.iter().map(|_| true as u8 as f32);
|
||||||
|
let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE)
|
||||||
|
.unwrap()
|
||||||
|
.unsqueeze(1)
|
||||||
|
.unwrap();
|
||||||
|
|
||||||
|
// Get the estimated rewards for the actions that where taken at each step.
|
||||||
|
let estimated_rewards = self.model.forward(&states).unwrap();
|
||||||
|
let x = estimated_rewards.gather(&actions, 1).unwrap();
|
||||||
|
|
||||||
|
// Get the maximum expected rewards for the next state, apply them a discount rate
|
||||||
|
// GAMMA and add them to the rewards that were actually gathered on the current state.
|
||||||
|
// If the next state is a terminal state, just omit maximum estimated
|
||||||
|
// rewards for that state.
|
||||||
|
let expected_rewards = self.model.forward(&next_states).unwrap().detach();
|
||||||
|
let y = expected_rewards.max_keepdim(1).unwrap();
|
||||||
|
let y = (y * CONFIG.gamma * non_final_mask + rewards).unwrap();
|
||||||
|
|
||||||
|
// Compare the estimated rewards with the maximum expected rewards and
|
||||||
|
// perform the backward step.
|
||||||
|
let loss = huber_loss(1.0_f32)(&x, &y);
|
||||||
|
log::trace!("loss: {:?}", loss);
|
||||||
|
self.optimizer
|
||||||
|
.backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
|
||||||
|
.unwrap();
|
||||||
|
}
|
||||||
|
pub fn check_point(&mut self) {
|
||||||
|
self.memory.clear();
|
||||||
|
if CONFIG.train {
|
||||||
|
self.var_map.save("model.bin").unwrap();
|
||||||
|
log::info!("model.bin saved!");
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// impl Drop for AIAgent {
|
||||||
|
// fn drop(&mut self) {
|
||||||
|
// self.var_map.save("model.bin").unwrap();
|
||||||
|
// log::info!("model.bin saved!");
|
||||||
|
// log::info!("Rewards {}", self.accumulate_rewards as i64);
|
||||||
|
// }
|
||||||
|
// }
|
|
@ -0,0 +1,32 @@
|
||||||
|
use candle_core::{Tensor, WithDType};
|
||||||
|
|
||||||
|
pub trait Half
|
||||||
|
where
|
||||||
|
Self: WithDType + Copy,
|
||||||
|
{
|
||||||
|
const HALF: Self;
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Half for f64 {
|
||||||
|
const HALF: f64 = 0.5;
|
||||||
|
}
|
||||||
|
impl Half for f32 {
|
||||||
|
const HALF: f32 = 0.5;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub fn huber_loss<D: WithDType + Half>(threshold: D) -> impl Fn(&Tensor, &Tensor) -> D {
|
||||||
|
move |x: &Tensor, y: &Tensor| {
|
||||||
|
let diff = (x - y).unwrap();
|
||||||
|
let diff_scaler = diff
|
||||||
|
.abs()
|
||||||
|
.unwrap()
|
||||||
|
.sum_all()
|
||||||
|
.unwrap()
|
||||||
|
.to_scalar::<D>()
|
||||||
|
.unwrap();
|
||||||
|
match diff_scaler < threshold {
|
||||||
|
true => <D as Half>::HALF * diff_scaler,
|
||||||
|
false => threshold * (diff_scaler - <D as Half>::HALF * threshold),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,40 @@
|
||||||
|
mod action;
|
||||||
|
mod agent;
|
||||||
|
mod huber;
|
||||||
|
mod state;
|
||||||
|
|
||||||
|
use smol::block_on;
|
||||||
|
|
||||||
|
use crate::data::prelude::*;
|
||||||
|
|
||||||
|
use self::agent::AIAgent;
|
||||||
|
|
||||||
|
pub struct TickState {
|
||||||
|
pub frame: u64,
|
||||||
|
pub player: Player,
|
||||||
|
pub opponent: Opponent,
|
||||||
|
pub foods: Vec<Food>,
|
||||||
|
}
|
||||||
|
|
||||||
|
struct AppState {}
|
||||||
|
|
||||||
|
pub struct App {
|
||||||
|
state: AppState,
|
||||||
|
agent: AIAgent,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl App {
|
||||||
|
pub fn new() -> Self {
|
||||||
|
let agent = block_on(AIAgent::new());
|
||||||
|
Self {
|
||||||
|
state: AppState {},
|
||||||
|
agent,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn run(&mut self, tick: TickState) -> Direction {
|
||||||
|
self.agent.tick(tick.into()).into()
|
||||||
|
}
|
||||||
|
pub fn check_point(&mut self) {
|
||||||
|
self.agent.check_point();
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,109 @@
|
||||||
|
use candle_core::{Device, Tensor};
|
||||||
|
|
||||||
|
use crate::{Food, Opponent, Player};
|
||||||
|
|
||||||
|
use super::TickState;
|
||||||
|
|
||||||
|
pub const OBSERVATION_SPACE: usize = 14;
|
||||||
|
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct AIState {
|
||||||
|
pub frame: u64,
|
||||||
|
pub player: Player,
|
||||||
|
pub opponent: Opponent,
|
||||||
|
pub foods: Vec<Food>,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<TickState> for AIState {
|
||||||
|
fn from(value: TickState) -> Self {
|
||||||
|
Self {
|
||||||
|
player: value.player,
|
||||||
|
opponent: value.opponent,
|
||||||
|
foods: value.foods,
|
||||||
|
frame: value.frame,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
fn food_distance<'a>(player: &'a Player) -> impl FnMut(&&Food) -> i32 + 'a {
|
||||||
|
move |food: &&Food| {
|
||||||
|
let dx = player.x - food.x;
|
||||||
|
let dy = player.y - food.y;
|
||||||
|
((dx + dy) * 100.0) as i32
|
||||||
|
}
|
||||||
|
}
|
||||||
|
impl AIState {
|
||||||
|
pub fn get_postivie_food(&self) -> Option<&Food> {
|
||||||
|
self.foods
|
||||||
|
.iter()
|
||||||
|
.filter(|x| x.score.is_sign_positive())
|
||||||
|
.min_by_key(food_distance(&self.player))
|
||||||
|
}
|
||||||
|
pub fn get_negative_food(&self) -> Option<&Food> {
|
||||||
|
self.foods
|
||||||
|
.iter()
|
||||||
|
.filter(|x| x.score.is_sign_negative())
|
||||||
|
.min_by_key(food_distance(&self.player))
|
||||||
|
}
|
||||||
|
pub fn into_tensor(&self) -> Tensor {
|
||||||
|
Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap()
|
||||||
|
}
|
||||||
|
fn into_feature(&self) -> [f32; OBSERVATION_SPACE] {
|
||||||
|
let x = self.player.x;
|
||||||
|
let y = self.player.y;
|
||||||
|
// sort food into four group by two line (x+y=0, x-y=0)
|
||||||
|
let mut food_group = [
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
self.opponent.x - self.player.x / 700.0,
|
||||||
|
self.opponent.y - self.player.y / 700.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
0.0,
|
||||||
|
];
|
||||||
|
for food in self.foods.iter().filter(|x| x.score.is_sign_positive()) {
|
||||||
|
let dx = food.x - x;
|
||||||
|
let dy = food.y - y;
|
||||||
|
let group = match (dx + dy, dx - dy) {
|
||||||
|
(a, b) if a.is_sign_positive() && b.is_sign_positive() => 0,
|
||||||
|
(a, b) if a.is_sign_positive() && b.is_sign_positive() => 1,
|
||||||
|
(a, b) if a.is_sign_negative() && b.is_sign_negative() => 2,
|
||||||
|
_ => 3,
|
||||||
|
};
|
||||||
|
food_group[group] += 10.0 / (dx + dy);
|
||||||
|
}
|
||||||
|
for food in self.foods.iter().filter(|x| x.score.is_sign_negative()) {
|
||||||
|
let dx = food.x - x;
|
||||||
|
let dy = food.y - y;
|
||||||
|
let group = match (dx + dy, dx - dy) {
|
||||||
|
(a, b) if a.is_sign_positive() && b.is_sign_positive() => 6,
|
||||||
|
(a, b) if a.is_sign_positive() && b.is_sign_positive() => 7,
|
||||||
|
(a, b) if a.is_sign_negative() && b.is_sign_negative() => 8,
|
||||||
|
_ => 9,
|
||||||
|
};
|
||||||
|
food_group[group] += 10.0 / (dx + dy);
|
||||||
|
}
|
||||||
|
self.get_postivie_food().map(|food| {
|
||||||
|
let dx = food.x - x;
|
||||||
|
let dy = food.y - y;
|
||||||
|
food_group[10] = dx as f32;
|
||||||
|
food_group[11] = dy as f32;
|
||||||
|
});
|
||||||
|
self.get_negative_food().map(|food| {
|
||||||
|
let dx = food.x - x;
|
||||||
|
let dy = food.y - y;
|
||||||
|
food_group[12] = dx as f32;
|
||||||
|
food_group[13] = dy as f32;
|
||||||
|
});
|
||||||
|
|
||||||
|
|
||||||
|
food_group
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,35 @@
|
||||||
|
use serde::{Deserialize, Serialize};
|
||||||
|
|
||||||
|
lazy_static::lazy_static! {
|
||||||
|
pub static ref CONFIG: Config = {
|
||||||
|
match std::fs::read_to_string("config.toml"){
|
||||||
|
Ok(content)=>toml::from_str(&content).unwrap(),
|
||||||
|
Err(_)=>Config::default()
|
||||||
|
}
|
||||||
|
};
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Serialize, Deserialize)]
|
||||||
|
pub struct Config {
|
||||||
|
pub exploration_rate: u32,
|
||||||
|
pub update_frequency: usize,
|
||||||
|
pub batch_size: usize,
|
||||||
|
pub replay_size: usize,
|
||||||
|
pub learning_rate: f64,
|
||||||
|
pub gamma: f64,
|
||||||
|
pub train: bool,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Config {
|
||||||
|
fn default() -> Self {
|
||||||
|
Self {
|
||||||
|
exploration_rate: 1024,
|
||||||
|
update_frequency: 150,
|
||||||
|
batch_size: 32,
|
||||||
|
replay_size: 250,
|
||||||
|
learning_rate: 0.04,
|
||||||
|
gamma: 0.99,
|
||||||
|
train: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,37 @@
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Player {
|
||||||
|
pub x: f32,
|
||||||
|
pub y: f32,
|
||||||
|
pub height: f32,
|
||||||
|
pub width: f32,
|
||||||
|
pub level: f32,
|
||||||
|
pub velocity: f32,
|
||||||
|
pub score: f32,
|
||||||
|
}
|
||||||
|
#[derive(Clone)]
|
||||||
|
pub struct Opponent {
|
||||||
|
pub x: f32,
|
||||||
|
pub y: f32,
|
||||||
|
pub level: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[derive(Clone, Debug)]
|
||||||
|
pub struct Food {
|
||||||
|
pub x: f32,
|
||||||
|
pub y: f32,
|
||||||
|
pub width: f32,
|
||||||
|
pub height: f32,
|
||||||
|
pub score: f32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl Default for Food {
|
||||||
|
fn default() -> Self {
|
||||||
|
Food {
|
||||||
|
x: 1000000.0,
|
||||||
|
y: 1000000.0,
|
||||||
|
width: 1.0,
|
||||||
|
height: 1.0,
|
||||||
|
score: 0.0,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
|
@ -0,0 +1,13 @@
|
||||||
|
mod config;
|
||||||
|
mod internal;
|
||||||
|
mod raw;
|
||||||
|
|
||||||
|
pub mod parser {
|
||||||
|
pub use super::config::CONFIG;
|
||||||
|
pub use super::raw::*;
|
||||||
|
}
|
||||||
|
|
||||||
|
pub mod prelude {
|
||||||
|
pub use super::internal::*;
|
||||||
|
pub use super::raw::Direction;
|
||||||
|
}
|
|
@ -0,0 +1,82 @@
|
||||||
|
use super::internal::*;
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub struct RawOverall {
|
||||||
|
pub frame: u64,
|
||||||
|
score: i64,
|
||||||
|
score_to_pass: i64,
|
||||||
|
self_x: i64,
|
||||||
|
self_y: i64,
|
||||||
|
self_h: i64,
|
||||||
|
self_w: i64,
|
||||||
|
self_vel: i64,
|
||||||
|
self_lv: i64,
|
||||||
|
opponent_x: i64,
|
||||||
|
opponent_y: i64,
|
||||||
|
opponent_lv: i64,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl RawOverall {
|
||||||
|
pub fn get_player(&self) -> Player {
|
||||||
|
Player {
|
||||||
|
x: (self.self_x - 350) as f32,
|
||||||
|
y: (self.self_y - 350) as f32,
|
||||||
|
height: self.self_h as f32,
|
||||||
|
width: self.self_w as f32,
|
||||||
|
level: self.self_lv as f32,
|
||||||
|
velocity: self.self_vel as f32,
|
||||||
|
score: self.score as f32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pub fn get_opponent(&self) -> Opponent {
|
||||||
|
Opponent {
|
||||||
|
x: (self.opponent_x - 350) as f32,
|
||||||
|
y: (self.opponent_y - 350) as f32,
|
||||||
|
level: self.opponent_lv as f32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(C)]
|
||||||
|
#[derive(Debug, Clone)]
|
||||||
|
pub struct RawFood {
|
||||||
|
pub h: i64,
|
||||||
|
pub w: i64,
|
||||||
|
pub x: i64,
|
||||||
|
pub y: i64,
|
||||||
|
pub score: i64,
|
||||||
|
pub kind: i32,
|
||||||
|
}
|
||||||
|
|
||||||
|
impl From<RawFood> for Food {
|
||||||
|
fn from(value: RawFood) -> Self {
|
||||||
|
Food {
|
||||||
|
x: value.x as f32,
|
||||||
|
y: value.y as f32,
|
||||||
|
width: value.w as f32,
|
||||||
|
height: value.h as f32,
|
||||||
|
score: value.score as f32,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(i32)]
|
||||||
|
#[derive(Debug)]
|
||||||
|
pub enum FoodKind {
|
||||||
|
Food1 = 1,
|
||||||
|
Food2 = 2,
|
||||||
|
Food3 = 3,
|
||||||
|
Garbage1 = 4,
|
||||||
|
Garbage2 = 5,
|
||||||
|
Garbage3 = 6,
|
||||||
|
}
|
||||||
|
|
||||||
|
#[repr(i32)]
|
||||||
|
pub enum Direction {
|
||||||
|
Up = 1,
|
||||||
|
Down = 2,
|
||||||
|
Left = 3,
|
||||||
|
Right = 4,
|
||||||
|
None = 5,
|
||||||
|
}
|
|
@ -0,0 +1,53 @@
|
||||||
|
mod app;
|
||||||
|
mod data;
|
||||||
|
|
||||||
|
use std::slice;
|
||||||
|
|
||||||
|
use app::{App, TickState};
|
||||||
|
use data::parser::*;
|
||||||
|
use data::prelude::*;
|
||||||
|
use simple_logger::SimpleLogger;
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn tick(
|
||||||
|
app: *mut App,
|
||||||
|
overall: &RawOverall,
|
||||||
|
food: *mut RawFood,
|
||||||
|
len: u64,
|
||||||
|
) -> i32 {
|
||||||
|
let app = &mut *app;
|
||||||
|
|
||||||
|
let state = {
|
||||||
|
let foods: Vec<Food> = slice::from_raw_parts(food, len as usize)
|
||||||
|
.into_iter()
|
||||||
|
.map(|x| x.to_owned().into())
|
||||||
|
.collect();
|
||||||
|
TickState {
|
||||||
|
frame: overall.frame,
|
||||||
|
player: overall.get_player(),
|
||||||
|
opponent: overall.get_opponent(),
|
||||||
|
foods,
|
||||||
|
}
|
||||||
|
};
|
||||||
|
|
||||||
|
app.run(state) as i32
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn check_point(app: *mut App) {
|
||||||
|
let app = &mut *app;
|
||||||
|
app.check_point();
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn new_app() -> *const App {
|
||||||
|
SimpleLogger::new().init().unwrap();
|
||||||
|
log::info!("Initializing App...");
|
||||||
|
let a = Box::into_raw(Box::new(App::new()));
|
||||||
|
a
|
||||||
|
}
|
||||||
|
|
||||||
|
#[no_mangle]
|
||||||
|
pub unsafe extern "C" fn drop_app(app: *mut App) {
|
||||||
|
// drop(Box::from_raw(app))
|
||||||
|
}
|
|
@ -0,0 +1,25 @@
|
||||||
|
// use candle_core::{DType, Device};
|
||||||
|
// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap};
|
||||||
|
|
||||||
|
fn main() {
|
||||||
|
// let mut var_map = VarMap::new();
|
||||||
|
// var_map.load("model.bin").unwrap();
|
||||||
|
// let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu);
|
||||||
|
// let model = seq()
|
||||||
|
// .add(linear(14, 60, vb.pp("linear_in")).unwrap())
|
||||||
|
// .add(Activation::LeakyRelu(0.01))
|
||||||
|
// .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap())
|
||||||
|
// .add(Activation::LeakyRelu(0.01))
|
||||||
|
// .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap())
|
||||||
|
// .add(Activation::LeakyRelu(0.01))
|
||||||
|
// .add(linear(48, 5, vb.pp("linear_out")).unwrap())
|
||||||
|
// .add(Activation::LeakyRelu(0.01));
|
||||||
|
|
||||||
|
// let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap();
|
||||||
|
|
||||||
|
// let target = Tensor::new(&[0.0], &Device::Cpu).unwrap();
|
||||||
|
|
||||||
|
// self.optimizer
|
||||||
|
// .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap())
|
||||||
|
// .unwrap();
|
||||||
|
}
|
|
@ -0,0 +1,20 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
from ctypes import *
|
||||||
|
|
||||||
|
mylib = cdll.LoadLibrary("./target/release/libpyr.so")
|
||||||
|
|
||||||
|
class Point(Structure):
|
||||||
|
_fields_ = [("x", c_uint64), ("y", c_uint64)]
|
||||||
|
|
||||||
|
point=Point()
|
||||||
|
point.x=1
|
||||||
|
point.y=2
|
||||||
|
|
||||||
|
ptr=pointer(point)
|
||||||
|
|
||||||
|
print(ptr)
|
||||||
|
|
||||||
|
mylib.set_point(ptr)
|
||||||
|
|
||||||
|
print(point.x)
|
||||||
|
print(point.y)
|
Loading…
Reference in New Issue