This commit is contained in:
Eason 2024-01-16 14:31:04 +08:00
commit 1ce6cf706e
9 changed files with 924 additions and 0 deletions

2
.gitignore vendored Normal file
View File

@ -0,0 +1,2 @@
/target
/src/matrix

628
Cargo.lock generated Normal file
View File

@ -0,0 +1,628 @@
# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3
[[package]]
name = "auto-const-array"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "62f7df18977a1ee03650ee4b31b4aefed6d56bac188760b6e37610400fe8d4bb"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"
[[package]]
name = "bincode"
version = "1.3.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad"
dependencies = [
"serde",
]
[[package]]
name = "bitflags"
version = "1.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a"
[[package]]
name = "byteorder"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b"
[[package]]
name = "bytes"
version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223"
[[package]]
name = "cfg-if"
version = "1.0.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd"
[[package]]
name = "cl-sys"
version = "0.4.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4febd824a957638c066180fbf72b2bed5bcee33740773f3dc59fe91f0a3e6595"
dependencies = [
"libc",
]
[[package]]
name = "crossbeam"
version = "0.8.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8"
dependencies = [
"crossbeam-channel",
"crossbeam-deque",
"crossbeam-epoch",
"crossbeam-queue",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-channel"
version = "0.5.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-deque"
version = "0.8.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d"
dependencies = [
"crossbeam-epoch",
"crossbeam-utils",
]
[[package]]
name = "crossbeam-epoch"
version = "0.9.18"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-queue"
version = "0.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35"
dependencies = [
"crossbeam-utils",
]
[[package]]
name = "crossbeam-utils"
version = "0.8.19"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345"
[[package]]
name = "crunchy"
version = "0.2.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7"
[[package]]
name = "enum_primitive"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "be4551092f4d519593039259a9ed8daedf0da12e5109c5280338073eaeb81180"
dependencies = [
"num-traits 0.1.43",
]
[[package]]
name = "futures"
version = "0.1.31"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678"
[[package]]
name = "futures"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0"
dependencies = [
"futures-channel",
"futures-core",
"futures-executor",
"futures-io",
"futures-sink",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-channel"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78"
dependencies = [
"futures-core",
"futures-sink",
]
[[package]]
name = "futures-core"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d"
[[package]]
name = "futures-executor"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d"
dependencies = [
"futures-core",
"futures-task",
"futures-util",
]
[[package]]
name = "futures-io"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1"
[[package]]
name = "futures-macro"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "futures-sink"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5"
[[package]]
name = "futures-task"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004"
[[package]]
name = "futures-util"
version = "0.3.30"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48"
dependencies = [
"futures-channel",
"futures-core",
"futures-io",
"futures-macro",
"futures-sink",
"futures-task",
"memchr",
"pin-project-lite",
"pin-utils",
"slab",
]
[[package]]
name = "fxhash"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c"
dependencies = [
"byteorder",
]
[[package]]
name = "half"
version = "2.3.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872"
dependencies = [
"cfg-if",
"crunchy",
]
[[package]]
name = "io-uring"
version = "0.6.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "460648e47a07a43110fbfa2e0b14afb2be920093c31e5dccc50e49568e099762"
dependencies = [
"bitflags",
"libc",
]
[[package]]
name = "libc"
version = "0.2.152"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7"
[[package]]
name = "log"
version = "0.4.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f"
[[package]]
name = "memchr"
version = "2.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149"
[[package]]
name = "memoffset"
version = "0.7.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4"
dependencies = [
"autocfg",
]
[[package]]
name = "mio"
version = "0.8.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09"
dependencies = [
"libc",
"log",
"wasi",
"windows-sys",
]
[[package]]
name = "ml-rust"
version = "0.1.0"
dependencies = [
"bincode",
"futures 0.3.30",
"half",
"monoio",
"ocl",
"serde",
"thiserror",
]
[[package]]
name = "monoio"
version = "0.2.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3c7d2b656fb51eca3f9ef301f68435c789a6e08d200be1dc887a4c8b8d4b70fe"
dependencies = [
"auto-const-array",
"bytes",
"fxhash",
"io-uring",
"libc",
"memchr",
"mio",
"monoio-macros",
"nix",
"pin-project-lite",
"socket2",
"windows-sys",
]
[[package]]
name = "monoio-macros"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "176a5f5e69613d9e88337cf2a65e11135332b4efbcc628404a7c555e4452084c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "nix"
version = "0.26.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b"
dependencies = [
"bitflags",
"cfg-if",
"libc",
"memoffset",
"pin-utils",
]
[[package]]
name = "nodrop"
version = "0.1.14"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb"
[[package]]
name = "num-complex"
version = "0.4.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214"
dependencies = [
"num-traits 0.2.17",
]
[[package]]
name = "num-traits"
version = "0.1.43"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31"
dependencies = [
"num-traits 0.2.17",
]
[[package]]
name = "num-traits"
version = "0.2.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c"
dependencies = [
"autocfg",
]
[[package]]
name = "ocl"
version = "0.19.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d1c3ce118fd2f00eeb3c01f8073db1ee127cac0b2f79848192c7889b2bd7fe40"
dependencies = [
"futures 0.1.31",
"nodrop",
"num-traits 0.2.17",
"ocl-core",
"qutex",
"thiserror",
]
[[package]]
name = "ocl-core"
version = "0.11.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c145dd9f205b86611a5df15eb89517417b03005441cf6cec245c65a4b9248c52"
dependencies = [
"bitflags",
"cl-sys",
"enum_primitive",
"num-complex",
"num-traits 0.2.17",
"ocl-core-vector",
"rustc_version",
"thiserror",
]
[[package]]
name = "ocl-core-vector"
version = "0.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f562279e046ca160aeed5eaf6f7c4eb9fa56cb8fd9d038dbdbf56225caeb8074"
dependencies = [
"num-traits 0.2.17",
]
[[package]]
name = "pin-project-lite"
version = "0.2.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58"
[[package]]
name = "pin-utils"
version = "0.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184"
[[package]]
name = "proc-macro2"
version = "1.0.76"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c"
dependencies = [
"unicode-ident",
]
[[package]]
name = "quote"
version = "1.0.35"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef"
dependencies = [
"proc-macro2",
]
[[package]]
name = "qutex"
version = "0.2.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "cda4a51ba3d773c196f9450a6b239077ad8dda608b15263b4c9f29e58909883f"
dependencies = [
"crossbeam",
"futures 0.1.31",
]
[[package]]
name = "rustc_version"
version = "0.4.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366"
dependencies = [
"semver",
]
[[package]]
name = "semver"
version = "1.0.21"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0"
[[package]]
name = "serde"
version = "1.0.195"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02"
dependencies = [
"serde_derive",
]
[[package]]
name = "serde_derive"
version = "1.0.195"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "slab"
version = "0.4.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67"
dependencies = [
"autocfg",
]
[[package]]
name = "socket2"
version = "0.5.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9"
dependencies = [
"libc",
"windows-sys",
]
[[package]]
name = "syn"
version = "2.0.48"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f"
dependencies = [
"proc-macro2",
"quote",
"unicode-ident",
]
[[package]]
name = "thiserror"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad"
dependencies = [
"thiserror-impl",
]
[[package]]
name = "thiserror-impl"
version = "1.0.56"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471"
dependencies = [
"proc-macro2",
"quote",
"syn",
]
[[package]]
name = "unicode-ident"
version = "1.0.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b"
[[package]]
name = "wasi"
version = "0.11.0+wasi-snapshot-preview1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423"
[[package]]
name = "windows-sys"
version = "0.48.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9"
dependencies = [
"windows-targets",
]
[[package]]
name = "windows-targets"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c"
dependencies = [
"windows_aarch64_gnullvm",
"windows_aarch64_msvc",
"windows_i686_gnu",
"windows_i686_msvc",
"windows_x86_64_gnu",
"windows_x86_64_gnullvm",
"windows_x86_64_msvc",
]
[[package]]
name = "windows_aarch64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8"
[[package]]
name = "windows_aarch64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc"
[[package]]
name = "windows_i686_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e"
[[package]]
name = "windows_i686_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406"
[[package]]
name = "windows_x86_64_gnu"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e"
[[package]]
name = "windows_x86_64_gnullvm"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc"
[[package]]
name = "windows_x86_64_msvc"
version = "0.48.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538"

18
Cargo.toml Normal file
View File

@ -0,0 +1,18 @@
[package]
name = "ml-rust"
version = "0.1.0"
edition = "2021"
# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[dependencies]
bincode = "1.3.3"
futures = "0.3.30"
half = "2.3.1"
monoio = "0.2.1"
ocl = "0.19.6"
thiserror = "1.0.56"
[dependencies.serde]
version="1.0.195"
features=["derive"]

0
cl/backward.cl Normal file
View File

15
cl/forward.cl Normal file
View File

@ -0,0 +1,15 @@
__kernel void forward(__global float* activate, __global float* output, __global float* input, __global float* mul, __global float* add, int input_width, int param_width)
{
int current_node = get_global_id(0);
float value = 0;
for (int input_node = 0; input_node < input_width; input_node++)
{
float input = input[input_node];
float factor = mul[input_node * param_width + current_node];
value += input * factor + add[current_node];
}
output[current_node] = value;
activate[current_node] = tanh(value);
}

2
rust-toolchain.toml Normal file
View File

@ -0,0 +1,2 @@
[toolchain]
channel = "nightly"

138
src/layer.rs Normal file
View File

@ -0,0 +1,138 @@
use ocl::{
builders::{BufferBuilder, KernelBuilder},
MemFlags, SpatialDims,
};
use serde::{Deserialize, Serialize};
use crate::{
state::{Context, PROGRAM_FORWARD, PROGRAM_BACKWARD},
Error,
};
/// A Layer of NN
pub struct Layer {
/// vector(intermedate) for offset connection
offset_conn: ocl::Buffer<f32>,
/// row-base matrix(input,intermedate) for multiply connection
mul_conn: ocl::Buffer<f32>,
/// vector(intermedate), the output value
/// be processed by activation function
output: ocl::Buffer<f32>,
/// vector(intermedate), the output value
/// should be processed by activation function
pub activate: ocl::Buffer<f32>,
/// input width
input: usize,
/// intermedate width
inter: usize,
}
#[derive(Deserialize, Serialize)]
pub struct LayerDump {
offset_conn: Vec<f32>,
mul_conn: Vec<f32>,
}
impl Layer {
/// dump buffers in Layer into bytes
///
/// FIXME: it's blocking
pub fn dump(self, que: &ocl::Queue) -> Result<Vec<u8>, Error> {
let mut offset_conn = Vec::with_capacity(self.input * self.inter);
let mut mul_conn = Vec::with_capacity(self.input * self.inter);
// let mut param = Vec::with_capacity(self.inter);
unsafe {
// FIXME: should wait for queue to finish even error
self.offset_conn
.read(&mut offset_conn)
.queue(que)
.block(false)
.enq()?;
self.mul_conn
.read(&mut mul_conn)
.queue(que)
.block(false)
.enq()?;
// self.param.read(&mut param).queue(que).block(false).enq()?;
}
que.finish()?;
let dump = LayerDump {
offset_conn,
mul_conn,
// param,
};
Ok(bincode::serialize(&dump)?)
}
/// create new Layer with random paramter
pub fn new(input: usize, node_count: usize, state: &Context) -> Result<Self, Error> {
let inter = node_count;
// FIXME: should write a kernel to randomize value instead
Ok(Self {
offset_conn: BufferBuilder::new()
.context(&state.context)
.len(inter)
.build()?,
mul_conn: BufferBuilder::new()
.context(&state.context)
.len(input * inter)
.build()?,
output: BufferBuilder::new()
.context(&state.context)
.len(inter)
.build()?,
activate: BufferBuilder::new()
.context(&state.context)
.len(inter)
.build()?,
input,
inter,
})
}
/// forward pagination
///
/// FIXME: we should use host memory instead device memory (EG. GPU)
///
/// MEM_USE_HOST_PTR: use host memory, cache by device memory
pub fn forward(&mut self, state: &Context, activation: &ocl::Buffer<f32>) -> Result<(), Error> {
let kernel = KernelBuilder::new()
.queue(state.queue.clone())
.global_work_size(SpatialDims::One(self.inter))
.program(&state.program[PROGRAM_FORWARD])
.arg(&self.activate)
.arg(&self.output)
.arg(activation)
.arg(&self.mul_conn)
.arg(&self.offset_conn)
.arg(self.input)
.arg(self.inter)
.build()?;
unsafe {
kernel.enq().unwrap();
}
Ok(())
}
/// forward pagination
///
/// FIXME: we should use host memory instead device memory (EG. GPU)
///
/// MEM_USE_HOST_PTR: use host memory, cache by device memory
pub fn backward(&mut self, state: &Context) -> Result<(), Error> {
todo!();
let kernel = KernelBuilder::new()
.queue(state.queue.clone())
.global_work_size(SpatialDims::One(self.inter))
.program(&state.program[PROGRAM_BACKWARD])
.build()?;
unsafe {
kernel.enq().unwrap();
}
Ok(())
}
}

28
src/main.rs Normal file
View File

@ -0,0 +1,28 @@
use std::path::Path;
use crate::state::Layers;
extern crate ocl;
pub mod layer;
pub mod state;
static SAVE:&str="save.bin";
#[derive(Debug, thiserror::Error)]
pub enum Error {
#[error("`{0}`")]
OpenCL(#[from] ocl::Error),
#[error("`{0}`")]
Bincode(#[from] Box<bincode::ErrorKind>),
#[error("`{0}`")]
IO(#[from] std::io::Error),
}
#[monoio::main]
async fn main() {
let layers=match Path::new(SAVE).exists(){
true => Layers::new(),
false => Layers::load_from(SAVE).await.unwrap(),
};
todo!()
}

93
src/state.rs Normal file
View File

@ -0,0 +1,93 @@
use std::path::Path;
use monoio::fs;
use ocl::{builders::BufferBuilder, Device, Program, Queue};
use serde::{Deserialize, Serialize};
use crate::{layer::Layer, Error};
/// index of precompiled opencl program
pub const PROGRAM_FORWARD: usize = 0;
/// index of precompiled opencl program
pub const PROGRAM_BACKWARD: usize = 1;
pub const LEARN_RATE: f32 = 0.05;
/// AppState
pub struct Context {
pub context: ocl::Context,
pub device: Device,
pub queue: Queue,
pub program: [Program; 2],
}
impl Context {
pub fn new() -> Self {
let context = ocl::Context::builder()
.devices(Device::specifier().first())
.build()
.unwrap();
let device = context.devices()[0];
let queue = Queue::new(&context, device, None).unwrap();
let forward = Program::builder()
.src(include_str!("../cl/forward.cl"))
.devices(device)
.build(&context)
.unwrap();
let backward = Program::builder()
.src(include_str!("../cl/backward.cl"))
.devices(device)
.build(&context)
.unwrap();
Context {
context,
device,
queue,
program: [forward, backward],
}
}
}
#[derive(Serialize, Deserialize)]
struct LayersDump(Vec<u8>);
pub struct Layers(Vec<Layer>);
impl Layers {
/// load from file
pub async fn load_from(path: impl AsRef<Path>) -> Result<Self, Error> {
let file = fs::File::open(path.as_ref()).await?;
todo!()
}
/// create layer with random param
pub fn new() -> Self {
todo!()
}
/// train
pub fn train(&mut self, ctx: &Context, data: Vec<f32>) -> Result<(), Error> {
let input = unsafe {
BufferBuilder::new()
.context(&ctx.context)
.use_host_slice(&data)
.len(data.len())
.build()?
};
let mut input=&input;
for layer in &mut self.0{
layer.forward(ctx, input)?;
input=&layer.activate;
}
for layer in self.0.iter_mut().rev(){
layer.backward(ctx)?;
}
drop(data);
Ok(())
}
}