template/backward.cl (new file)
@@ -0,0 +1,11 @@
+__kernel void backward(__global float* activate,
+                       __global float* output,
+                       __global float* delta,
+                       __global float* input,
+                       __global float* mul,
+                       __global float* add,
+                       int input_width,
+                       int param_width)
+{
+
+}
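The kernel body is empty in this commit; only the interface is fixed. For orientation, a minimal CPU sketch in Rust of the dense-layer backward pass this signature suggests: `delta` arrives holding da^[L] and leaves holding da^[L-1], while `mul` and `add` take a gradient step, with `input_width` and `param_width` as the loop bounds. The sigmoid activation, row-major weight layout, and explicit learning rate `lr` are assumptions, not taken from the commit; `output` (the pre-activation z^[L]) is omitted here because sigmoid' can be computed from `activate` alone.

    // Hypothetical CPU reference for what backward.cl might eventually
    // compute on the GPU. Names mirror the kernel arguments.
    fn backward_cpu(
        activate: &[f32],     // a^[L] = f(z^[L]), length param_width
        delta: &mut Vec<f32>, // in: da^[L] (param_width); out: da^[L-1] (input_width)
        input: &[f32],        // a^[L-1], length input_width
        mul: &mut [f32],      // weights, param_width rows x input_width cols, row-major
        add: &mut [f32],      // biases, length param_width
        input_width: usize,
        param_width: usize,
        lr: f32,              // learning rate (assumed; the kernel may fold it in)
    ) {
        let mut next = vec![0.0f32; input_width];
        for i in 0..param_width {
            // sigmoid'(z) written in terms of a = sigmoid(z): a * (1 - a)
            let dz = delta[i] * activate[i] * (1.0 - activate[i]);
            for j in 0..input_width {
                next[j] += dz * mul[i * input_width + j]; // accumulate da^[L-1]
                mul[i * input_width + j] -= lr * dz * input[j]; // weight step
            }
            add[i] -= lr * dz; // bias step
        }
        *delta = next; // delta now carries da^[L-1] to the previous layer
    }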
src/layer.rs (28 changed lines)
@@ -1,11 +1,11 @@
 use ocl::{
     builders::{BufferBuilder, KernelBuilder},
-    MemFlags, SpatialDims,
+    Buffer, SpatialDims,
 };
 use serde::{Deserialize, Serialize};

 use crate::{
-    state::{Context, PROGRAM_FORWARD, PROGRAM_BACKWARD},
+    state::{Context, PROGRAM_BACKWARD, PROGRAM_FORWARD},
     Error,
 };

@@ -96,11 +96,11 @@ impl Layer {
     /// FIXME: we should use host memory instead of device memory (e.g. GPU)
     ///
     /// MEM_USE_HOST_PTR: use host memory, cached by device memory
-    pub fn forward(&mut self, state: &Context, activation: &ocl::Buffer<f32>) -> Result<(), Error> {
+    pub fn forward(&mut self, ctx: &Context, activation: &ocl::Buffer<f32>) -> Result<(), Error> {
         let kernel = KernelBuilder::new()
-            .queue(state.queue.clone())
+            .queue(ctx.queue.clone())
             .global_work_size(SpatialDims::One(self.inter))
-            .program(&state.program[PROGRAM_FORWARD])
+            .program(&ctx.program[PROGRAM_FORWARD])
             .arg(&self.activate)
             .arg(&self.output)
             .arg(activation)
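The FIXME recurs on backward() below: the layer buffers currently live in device memory, and the note proposes CL_MEM_USE_HOST_PTR so the device merely caches host-side storage. A sketch of what that could look like with the ocl crate; the helper and its caller are hypothetical, not code from this commit:

    use ocl::{builders::BufferBuilder, Buffer, MemFlags, Queue};

    // Back an OpenCL buffer with host memory via CL_MEM_USE_HOST_PTR so the
    // device only caches it. Safety: `host` must outlive the returned buffer,
    // which is exactly the contract that makes use_host_slice unsafe.
    fn host_backed(queue: &Queue, host: &[f32]) -> Result<Buffer<f32>, ocl::Error> {
        unsafe {
            BufferBuilder::<f32>::new()
                .queue(queue.clone())
                .flags(MemFlags::new().use_host_ptr())
                .len(host.len())
                .use_host_slice(host) // buffer now aliases `host` directly
                .build()
        }
    }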
@@ -116,17 +116,25 @@ impl Layer {

         Ok(())
     }
     /// backward propagation
+    ///
+    /// delta: da^[L]; the kernel must rewrite it to da^[L-1]
     ///
     /// FIXME: we should use host memory instead of device memory (e.g. GPU)
     ///
     /// MEM_USE_HOST_PTR: use host memory, cached by device memory
-    pub fn backward(&mut self, state: &Context) -> Result<(), Error> {
-        todo!();
+    pub fn backward(&mut self, ctx: &Context, delta: &mut Buffer<f32>) -> Result<(), Error> {
         let kernel = KernelBuilder::new()
-            .queue(state.queue.clone())
+            .queue(ctx.queue.clone())
             .global_work_size(SpatialDims::One(self.inter))
-            .program(&state.program[PROGRAM_BACKWARD])
+            .program(&ctx.program[PROGRAM_BACKWARD])
+            .arg(&self.activate)
+            .arg(&self.output)
+            .arg(delta)
+            .arg(&self.mul_conn)
+            .arg(&self.offset_conn)
+            .arg(self.input)
+            .arg(self.inter)
             .build()?;

         unsafe {
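In backprop notation, the rewrite the new doc comment asks of the kernel is the standard delta recursion. Assuming z^[L] = W^[L] a^[L-1] + b^[L] and a^[L] = f(z^[L]):

    dz^{[L]} = da^{[L]} \odot f'(z^{[L]}), \qquad
    da^{[L-1]} = (W^{[L]})^{\top} \, dz^{[L]}

which is why the kernel receives both the pre-activation (`output`) and the weights (`mul_conn`) in the argument list above.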
src/main.rs

@@ -6,7 +6,7 @@ extern crate ocl;
 pub mod layer;
 pub mod state;

-static SAVE:&str="save.bin";
+static SAVE: &str = "save.bin";

 #[derive(Debug, thiserror::Error)]
 pub enum Error {
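The derive line is all this hunk shows of Error. Since layer.rs bubbles `KernelBuilder::build()?` up through this type, it presumably carries at least an OpenCL variant; an illustrative shape, with variants that are guesses rather than source code:

    // Plausible sketch of the crate's error type; only the derive line
    // appears in the commit, the variants here are assumptions.
    #[derive(Debug, thiserror::Error)]
    pub enum Error {
        #[error("OpenCL error: {0}")]
        Ocl(#[from] ocl::Error),
        #[error("I/O error: {0}")]
        Io(#[from] std::io::Error),
    }

The #[from] attributes are what let the ? operator convert ocl and I/O errors into Error automatically.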
@@ -20,7 +20,7 @@ pub enum Error {

 #[monoio::main]
 async fn main() {
-    let layers=match Path::new(SAVE).exists(){
+    let layers = match Path::new(SAVE).exists() {
         true => Layers::load_from(SAVE).await.unwrap(),
         false => Layers::new(),
     };
src/state.rs (17 changed lines)
@@ -76,18 +76,25 @@ impl Layers {
             .len(data.len())
             .build()?
         };
-        let mut input=&input;
+        let mut input = &input;

-        for layer in &mut self.0{
+        for layer in &mut self.0 {
             layer.forward(ctx, input)?;
-            input=&layer.activate;
+            input = &layer.activate;
         }

-        for layer in self.0.iter_mut().rev(){
-            layer.backward(ctx)?;
+        let mut delta = BufferBuilder::<f32>::new()
+            .context(&ctx.context)
+            .len(1024)
+            .fill_val(1.0)
+            .build()?;
+
+        for layer in self.0.iter_mut().rev() {
+            layer.backward(ctx, &mut delta)?;
         }

         drop(data);

         Ok(())
     }
 }
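Seeding delta with 1.0 at a hard-coded length of 1024 reads like a placeholder: for real training the seed is the loss gradient dL/da at the output layer, sized to the last layer's width. A hedged sketch with an MSE loss; `output` and `target` as host slices, and a plain ocl::Context rather than this crate's Context wrapper, are all assumptions:

    use ocl::{builders::BufferBuilder, Buffer, Context};

    // Build the backward seed from dL/da for L = (1/n) * sum_i (a_i - t_i)^2.
    fn seed_delta(
        ctx: &Context,
        output: &[f32],
        target: &[f32],
    ) -> Result<Buffer<f32>, ocl::Error> {
        let n = output.len() as f32;
        let seed: Vec<f32> = output
            .iter()
            .zip(target)
            .map(|(a, t)| 2.0 * (a - t) / n) // elementwise MSE gradient
            .collect();
        BufferBuilder::<f32>::new()
            .context(ctx)
            .len(seed.len())
            .copy_host_slice(&seed) // upload the host-side gradient
            .build()
    }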