use std::{ env::{self}, fs::File, io::BufReader, }; use crate::ffi::prelude::*; use crate::ARTIFACT_DIR; use burn::{ data::{dataloader::batcher::Batcher, dataset::Dataset}, prelude::*, }; use super::feature::FEATRUE_SPACE; #[derive(Clone, Debug, serde::Serialize, serde::Deserialize)] pub struct TankItem { pub previous_state: [f32; FEATRUE_SPACE], pub new_state: [f32; FEATRUE_SPACE], pub action: Action, pub reward: f32, } pub struct TankDataset { pub dataset: Vec, } impl Dataset for TankDataset { fn get(&self, index: usize) -> Option { self.dataset.get(index).cloned() } fn len(&self) -> usize { self.dataset.len() } } impl TankDataset { pub fn new() -> Self { Self { dataset: Vec::new(), } } pub fn load() -> Self { let dataset_path = env::var("MODEL_PATH").unwrap_or_else(|_| ARTIFACT_DIR.to_string()); let mut dataset = Vec::new(); println!("Loading dataset from: {}", dataset_path); if let Ok(reader) = File::open(format!("{dataset_path}/dataset")) { let mut reader = BufReader::new(reader); while let Ok(item) = bincode::deserialize_from::<_, TankItem>(&mut reader) { dataset.push(item); } } TankDataset { dataset } } pub fn add(&mut self, item: TankItem) { self.dataset.push(item); } pub fn split(self, ratio: f32) -> (Self, Self) { let split = (self.dataset.len() as f32 * ratio).round() as usize; let (a, b) = self.dataset.split_at(split); ( TankDataset { dataset: a.to_vec(), }, TankDataset { dataset: b.to_vec(), }, ) } } #[derive(Clone, Debug)] pub struct TankBatcher { device: B::Device, } #[derive(Clone, Debug)] pub struct TankBatch { pub new_state: Tensor, pub old_state: Tensor, pub action: Tensor, pub reward: Tensor, } impl TankBatcher { pub fn new(device: B::Device) -> Self { Self { device } } } impl Batcher> for TankBatcher { fn batch(&self, items: Vec) -> TankBatch { let mut new_state: Vec> = Vec::new(); let mut old_state: Vec> = Vec::new(); for item in items.iter() { let new_state_tensor = Tensor::::from_floats(item.previous_state, &self.device); let old_state_tensor = Tensor::::from_floats(item.new_state, &self.device); new_state.push(new_state_tensor.unsqueeze()); old_state.push(old_state_tensor.unsqueeze()); } let new_state = Tensor::cat(new_state, 0); let old_state = Tensor::cat(old_state, 0); let reward = items .iter() .map(|item| Tensor::::from_floats([item.reward], &self.device)) .collect(); let reward = Tensor::cat(reward, 0); let response = items .iter() .map(|item| Tensor::::from_ints([item.action as i32], &self.device)) .collect(); let response = Tensor::cat(response, 0); TankBatch { new_state, old_state, reward, action: response, } } }