template backward.cl

commit a10938d1f5 (parent 1ce6cf706e)
Author: Eason
Date:   2024-01-16 14:54:19 +08:00

4 changed files with 43 additions and 17 deletions


@@ -0,0 +1,11 @@
+__kernel void backward(__global float* activate,
+                       __global float* output,
+                       __global float* delta,
+                       __global float* input,
+                       __global float* mul,
+                       __global float* add,
+                       int input_width,
+                       int param_width)
+{
+}
+
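The kernel body is left empty by this commit; only the signature is templated, and the Rust changes below merely wire up its arguments. As a rough sketch of what the body will eventually have to compute (assumptions, not this repo's code: a dense layer, sigmoid activation so g'(z) = a * (1 - a), `mul` stored as param_width rows of input_width weights, a launch covering input_width work items, and `input` used as the destination for da^[L-1], since rewriting `delta` in place within a single launch would race across work items):

    // Sketch only: buffer roles and the activation derivative are assumed.
    int j = get_global_id(0);          // one work item per upstream unit j
    if (j >= input_width) return;

    float sum = 0.0f;
    for (int i = 0; i < param_width; i++) {
        // dz_i = da^[L]_i * g'(z_i); sigmoid assumed: g'(z) = a * (1 - a)
        float dz = delta[i] * activate[i] * (1.0f - activate[i]);
        sum += mul[i * input_width + j] * dz;   // accumulate (W^T dz)_j
    }
    input[j] = sum;                    // da^[L-1]_j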


@@ -1,11 +1,11 @@
 use ocl::{
     builders::{BufferBuilder, KernelBuilder},
-    MemFlags, SpatialDims,
+    Buffer, SpatialDims,
 };
 use serde::{Deserialize, Serialize};

 use crate::{
-    state::{Context, PROGRAM_FORWARD, PROGRAM_BACKWARD},
+    state::{Context, PROGRAM_BACKWARD, PROGRAM_FORWARD},
     Error,
 };
@@ -96,11 +96,11 @@ impl Layer {
     /// FIXME: we should use host memory instead of device memory (e.g. GPU memory)
     ///
     /// MEM_USE_HOST_PTR: use host memory, cache by device memory
-    pub fn forward(&mut self, state: &Context, activation: &ocl::Buffer<f32>) -> Result<(), Error> {
+    pub fn forward(&mut self, ctx: &Context, activation: &ocl::Buffer<f32>) -> Result<(), Error> {
         let kernel = KernelBuilder::new()
-            .queue(state.queue.clone())
+            .queue(ctx.queue.clone())
             .global_work_size(SpatialDims::One(self.inter))
-            .program(&state.program[PROGRAM_FORWARD])
+            .program(&ctx.program[PROGRAM_FORWARD])
             .arg(&self.activate)
             .arg(&self.output)
             .arg(activation)
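The FIXME above refers to OpenCL's CL_MEM_USE_HOST_PTR allocation flag, which the ocl crate exposes through MemFlags. At the raw OpenCL C level, a host-backed buffer of the kind the comment asks for would be created like this (a sketch; make_host_backed is a hypothetical helper):

    #include <CL/cl.h>

    // CL_MEM_USE_HOST_PTR: the runtime uses host_data as the backing store
    // and may cache it in device memory, which is what the FIXME describes.
    cl_mem make_host_backed(cl_context context, float* host_data, size_t n) {
        cl_int err;
        cl_mem buf = clCreateBuffer(context,
                                    CL_MEM_READ_WRITE | CL_MEM_USE_HOST_PTR,
                                    n * sizeof(float), host_data, &err);
        return err == CL_SUCCESS ? buf : NULL;
    }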
@@ -118,15 +118,23 @@ impl Layer {
     }

     /// backward propagation
     ///
+    /// delta: da^[L]; the kernel is required to rewrite it into da^[L-1]
+    ///
     /// FIXME: we should use host memory instead of device memory (e.g. GPU memory)
     ///
     /// MEM_USE_HOST_PTR: use host memory, cache by device memory
-    pub fn backward(&mut self, state: &Context) -> Result<(), Error> {
-        todo!();
+    pub fn backward(&mut self, ctx: &Context, delta: &mut Buffer<f32>) -> Result<(), Error> {
         let kernel = KernelBuilder::new()
-            .queue(state.queue.clone())
+            .queue(ctx.queue.clone())
             .global_work_size(SpatialDims::One(self.inter))
-            .program(&state.program[PROGRAM_BACKWARD])
+            .program(&ctx.program[PROGRAM_BACKWARD])
+            .arg(&self.activate)
+            .arg(&self.output)
+            .arg(delta)
+            .arg(&self.mul_conn)
+            .arg(&self.offset_conn)
+            .arg(self.input)
+            .arg(self.inter)
             .build()?;
         unsafe {
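The rewrite the new doc comment asks of the kernel is the standard backpropagation identity for a dense layer, in the comment's own da^[L] notation:

    da^[L-1] = (W^[L])^T ( da^[L] ∘ g'(z^[L]) )

where, presumably, mul_conn plays the role of W^[L] and offset_conn the bias term, whose gradient is the parenthesized factor itself.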


@@ -6,7 +6,7 @@ extern crate ocl;
 pub mod layer;
 pub mod state;

-static SAVE:&str="save.bin";
+static SAVE: &str = "save.bin";

 #[derive(Debug, thiserror::Error)]
 pub enum Error {
@@ -20,7 +20,7 @@ pub enum Error {
 #[monoio::main]
 async fn main() {
-    let layers=match Path::new(SAVE).exists(){
+    let layers = match Path::new(SAVE).exists() {
         true => Layers::load_from(SAVE).await.unwrap(),
         false => Layers::new(),
     };


@@ -76,18 +76,25 @@ impl Layers {
             .len(data.len())
             .build()?
         };

-        let mut input=&input;
-        for layer in &mut self.0{
+        let mut input = &input;
+        for layer in &mut self.0 {
             layer.forward(ctx, input)?;
-            input=&layer.activate;
+            input = &layer.activate;
         }

-        for layer in self.0.iter_mut().rev(){
-            layer.backward(ctx)?;
+        let mut delta = BufferBuilder::<f32>::new()
+            .context(&ctx.context)
+            .len(1024)
+            .fill_val(1.0)
+            .build()?;
+        for layer in self.0.iter_mut().rev() {
+            layer.backward(ctx, &mut delta)?;
         }

         drop(data);
         Ok(())
     }
 }