From a10938d1f5750d40e0d5f267302dc8a2a706a0b2 Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:54:19 +0800 Subject: [PATCH] template backward.cl --- cl/backward.cl | 11 +++++++++++ src/layer.rs | 28 ++++++++++++++++++---------- src/main.rs | 4 ++-- src/state.rs | 17 ++++++++++++----- 4 files changed, 43 insertions(+), 17 deletions(-) diff --git a/cl/backward.cl b/cl/backward.cl index e69de29..898c660 100644 --- a/cl/backward.cl +++ b/cl/backward.cl @@ -0,0 +1,11 @@ +__kernel void backward(__global float* activate, +__global float* output, +__global float* delta, +__global float* input, +__global float* mul, +__global float* add, +int input_width, +int param_width) +{ + +} \ No newline at end of file diff --git a/src/layer.rs b/src/layer.rs index 2c71bda..cc7d089 100644 --- a/src/layer.rs +++ b/src/layer.rs @@ -1,11 +1,11 @@ use ocl::{ builders::{BufferBuilder, KernelBuilder}, - MemFlags, SpatialDims, + Buffer, SpatialDims, }; use serde::{Deserialize, Serialize}; use crate::{ - state::{Context, PROGRAM_FORWARD, PROGRAM_BACKWARD}, + state::{Context, PROGRAM_BACKWARD, PROGRAM_FORWARD}, Error, }; @@ -96,11 +96,11 @@ impl Layer { /// FIXME: we should use host memory instead device memory (EG. 
GPU) /// /// MEM_USE_HOST_PTR: use host memory, cache by device memory - pub fn forward(&mut self, state: &Context, activation: &ocl::Buffer) -> Result<(), Error> { + pub fn forward(&mut self, ctx: &Context, activation: &ocl::Buffer) -> Result<(), Error> { let kernel = KernelBuilder::new() - .queue(state.queue.clone()) + .queue(ctx.queue.clone()) .global_work_size(SpatialDims::One(self.inter)) - .program(&state.program[PROGRAM_FORWARD]) + .program(&ctx.program[PROGRAM_FORWARD]) .arg(&self.activate) .arg(&self.output) .arg(activation) @@ -116,17 +116,25 @@ impl Layer { Ok(()) } - /// forward pagination + /// backward propagation + /// + /// delta: da superscript [L]; the kernel is required to rewrite it to da superscript [L-1] /// /// FIXME: we should use host memory instead device memory (EG. GPU) /// /// MEM_USE_HOST_PTR: use host memory, cache by device memory - pub fn backward(&mut self, state: &Context) -> Result<(), Error> { - todo!(); + pub fn backward(&mut self, ctx: &Context, delta: &mut Buffer) -> Result<(), Error> { let kernel = KernelBuilder::new() - .queue(state.queue.clone()) + .queue(ctx.queue.clone()) .global_work_size(SpatialDims::One(self.inter)) - .program(&state.program[PROGRAM_BACKWARD]) + .program(&ctx.program[PROGRAM_BACKWARD]) + .arg(&self.activate) + .arg(&self.output) + .arg(delta) + .arg(&self.mul_conn) + .arg(&self.offset_conn) + .arg(self.input) + .arg(self.inter) .build()?; unsafe { diff --git a/src/main.rs b/src/main.rs index c5eaf23..cba9733 100644 --- a/src/main.rs +++ b/src/main.rs @@ -6,7 +6,7 @@ extern crate ocl; pub mod layer; pub mod state; -static SAVE:&str="save.bin"; +static SAVE: &str = "save.bin"; #[derive(Debug, thiserror::Error)] pub enum Error { @@ -20,7 +20,7 @@ pub enum Error { #[monoio::main] async fn main() { - let layers=match Path::new(SAVE).exists(){ + let layers = match Path::new(SAVE).exists() { true => Layers::new(), false => Layers::load_from(SAVE).await.unwrap(), }; diff --git a/src/state.rs b/src/state.rs index 
88ebf38..4d8633c 100644 --- a/src/state.rs +++ b/src/state.rs @@ -76,18 +76,25 @@ impl Layers { .len(data.len()) .build()? }; - let mut input=&input; + let mut input = &input; - for layer in &mut self.0{ + for layer in &mut self.0 { layer.forward(ctx, input)?; - input=&layer.activate; + input = &layer.activate; } - for layer in self.0.iter_mut().rev(){ - layer.backward(ctx)?; + let mut delta = BufferBuilder::::new() + .context(&ctx.context) + .len(1024) + .fill_val(1.0) + .build()?; + + for layer in self.0.iter_mut().rev() { + layer.backward(ctx, &mut delta)?; } drop(data); + Ok(()) } }