From 1ce6cf706e74a310867f00670aafe21438c833ea Mon Sep 17 00:00:00 2001 From: Eason <30045503+Eason0729@users.noreply.github.com> Date: Tue, 16 Jan 2024 14:31:04 +0800 Subject: [PATCH] init --- .gitignore | 2 + Cargo.lock | 628 ++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 18 ++ cl/backward.cl | 0 cl/forward.cl | 15 ++ rust-toolchain.toml | 2 + src/layer.rs | 138 ++++++++++ src/main.rs | 28 ++ src/state.rs | 93 +++++++ 9 files changed, 924 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 cl/backward.cl create mode 100644 cl/forward.cl create mode 100644 rust-toolchain.toml create mode 100644 src/layer.rs create mode 100644 src/main.rs create mode 100644 src/state.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..46ac07e --- /dev/null +++ b/.gitignore @@ -0,0 +1,2 @@ +/target +/src/matrix \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..bfcde3d --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,628 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "auto-const-array" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "62f7df18977a1ee03650ee4b31b4aefed6d56bac188760b6e37610400fe8d4bb" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "autocfg" +version = "1.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa" + +[[package]] +name = "bincode" +version = "1.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b1f45e9417d87227c7a56d22e471c6206462cba514c7590c09aff4cf6d1ddcad" +dependencies = [ + "serde", +] + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "bytes" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2bd12c1caf447e69cd4528f47f94d203fd2582878ecb9e9465484c4148a8223" + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "cl-sys" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4febd824a957638c066180fbf72b2bed5bcee33740773f3dc59fe91f0a3e6595" +dependencies = [ + "libc", +] + +[[package]] +name = "crossbeam" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1137cd7e7fc0fb5d3c5a8678be38ec56e819125d8d7907411fe24ccb943faca8" +dependencies = [ + "crossbeam-channel", + "crossbeam-deque", + "crossbeam-epoch", + "crossbeam-queue", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-channel" +version = "0.5.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "176dc175b78f56c0f321911d9c8eb2b77a78a4860b9c19db83835fea1a46649b" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-queue" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "df0346b5d5e76ac2fe4e327c5fd1118d6be7c51dfb18f9b7922923f287471e35" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "enum_primitive" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "be4551092f4d519593039259a9ed8daedf0da12e5109c5280338073eaeb81180" +dependencies = [ + "num-traits 0.1.43", +] + +[[package]] +name = "futures" +version = "0.1.31" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3a471a38ef8ed83cd6e40aa59c1ffe17db6855c18e3604d9c4ed8c08ebc28678" + +[[package]] +name = "futures" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645c6916888f6cb6350d2550b80fb63e734897a8498abe35cfb732b6487804b0" +dependencies = [ + "futures-channel", + "futures-core", + "futures-executor", + "futures-io", + "futures-sink", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-channel" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eac8f7d7865dcb88bd4373ab671c8cf4508703796caa2b1985a9ca867b3fcb78" +dependencies = [ + "futures-core", + "futures-sink", +] + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-executor" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a576fc72ae164fca6b9db127eaa9a9dda0d61316034f33a0a0d4eda41f02b01d" +dependencies = [ + "futures-core", + "futures-task", + "futures-util", +] + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-macro" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87750cf4b7a4c0625b1529e4c543c2182106e4dedc60a2a6455e00d212c489ac" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "futures-sink" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9fb8e00e87438d937621c1c6269e53f536c14d3fbd6a042bb24879e57d474fb5" + +[[package]] +name = "futures-task" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "38d84fa142264698cdce1a9f9172cf383a0c82de1bddcf3092901442c4097004" + +[[package]] +name = "futures-util" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d6401deb83407ab3da39eba7e33987a73c3df0c82b4bb5813ee871c19c41d48" +dependencies = [ + "futures-channel", + "futures-core", + "futures-io", + "futures-macro", + "futures-sink", + "futures-task", + "memchr", + "pin-project-lite", + "pin-utils", + "slab", +] + +[[package]] +name = "fxhash" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c31b6d751ae2c7f11320402d34e41349dd1016f8d5d45e48c4312bc8625af50c" +dependencies = [ + "byteorder", +] + +[[package]] +name = "half" +version = "2.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc52e53916c08643f1b56ec082790d1e86a32e58dc5268f897f313fbae7b4872" +dependencies = [ + "cfg-if", + "crunchy", +] + +[[package]] +name = "io-uring" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "460648e47a07a43110fbfa2e0b14afb2be920093c31e5dccc50e49568e099762" +dependencies = [ + "bitflags", + "libc", +] + +[[package]] +name = "libc" +version = "0.2.152" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "13e3bf6590cbc649f4d1a3eefc9d5d6eb746f5200ffb04e5e142700b8faa56e7" + +[[package]] +name = "log" +version = "0.4.20" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" + +[[package]] +name = "memchr" +version = "2.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "523dc4f511e55ab87b694dc30d0f820d60906ef06413f93d4d7a1385599cc149" + +[[package]] +name = "memoffset" +version = "0.7.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5de893c32cde5f383baa4c04c5d6dbdd735cfd4a794b0debdb2bb1b421da5ff4" +dependencies = [ + "autocfg", +] + +[[package]] +name = "mio" +version = "0.8.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f3d0b296e374a4e6f3c7b0a1f5a51d748a0d34c85e7dc48fc3fa9a87657fe09" +dependencies = [ + "libc", + "log", + "wasi", + "windows-sys", +] + +[[package]] +name = "ml-rust" +version = "0.1.0" +dependencies = [ + "bincode", + "futures 0.3.30", + "half", + "monoio", + "ocl", + "serde", + "thiserror", +] + +[[package]] +name = "monoio" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3c7d2b656fb51eca3f9ef301f68435c789a6e08d200be1dc887a4c8b8d4b70fe" +dependencies = [ + "auto-const-array", + "bytes", + "fxhash", + "io-uring", + "libc", + "memchr", + "mio", + "monoio-macros", + "nix", + "pin-project-lite", + "socket2", + "windows-sys", +] + +[[package]] +name = "monoio-macros" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "176a5f5e69613d9e88337cf2a65e11135332b4efbcc628404a7c555e4452084c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "nix" +version = "0.26.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "598beaf3cc6fdd9a5dfb1630c2800c7acd31df7aaf0f565796fba2b53ca1af1b" +dependencies = [ + "bitflags", + "cfg-if", + "libc", + "memoffset", + "pin-utils", +] + +[[package]] +name = "nodrop" +version = "0.1.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "72ef4a56884ca558e5ddb05a1d1e7e1bfd9a68d9ed024c21704cc98872dae1bb" + +[[package]] +name = "num-complex" +version = "0.4.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1ba157ca0885411de85d6ca030ba7e2a83a28636056c7c699b07c8b6f7383214" +dependencies = [ + "num-traits 0.2.17", +] + +[[package]] +name = "num-traits" +version = "0.1.43" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "92e5113e9fd4cc14ded8e499429f396a20f98c772a47cc8622a736e1ec843c31" +dependencies = [ + "num-traits 0.2.17", +] + +[[package]] +name = "num-traits" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" +dependencies = [ + "autocfg", +] + +[[package]] +name = "ocl" +version = "0.19.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1c3ce118fd2f00eeb3c01f8073db1ee127cac0b2f79848192c7889b2bd7fe40" +dependencies = [ + "futures 0.1.31", + "nodrop", + "num-traits 0.2.17", + "ocl-core", + "qutex", + "thiserror", +] + +[[package]] +name = "ocl-core" +version = "0.11.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c145dd9f205b86611a5df15eb89517417b03005441cf6cec245c65a4b9248c52" +dependencies = [ + "bitflags", + "cl-sys", + "enum_primitive", + "num-complex", + "num-traits 0.2.17", + "ocl-core-vector", + "rustc_version", + "thiserror", +] + +[[package]] +name = "ocl-core-vector" +version = "0.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f562279e046ca160aeed5eaf6f7c4eb9fa56cb8fd9d038dbdbf56225caeb8074" +dependencies = [ + "num-traits 0.2.17", +] + +[[package]] +name = "pin-project-lite" +version = "0.2.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8afb450f006bf6385ca15ef45d71d2288452bc3683ce2e2cacc0d18e4be60b58" + +[[package]] +name = "pin-utils" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" + +[[package]] +name = "proc-macro2" +version = "1.0.76" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95fc56cda0b5c3325f5fbbd7ff9fda9e02bb00bb3dac51252d2f1bfa1cb8cc8c" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.35" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "291ec9ab5efd934aaf503a6466c5d5251535d108ee747472c3977cc5acc868ef" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "qutex" +version = "0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cda4a51ba3d773c196f9450a6b239077ad8dda608b15263b4c9f29e58909883f" +dependencies = [ + "crossbeam", + "futures 0.1.31", +] + +[[package]] +name = "rustc_version" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bfa0f585226d2e68097d4f95d113b15b83a82e819ab25717ec0590d9584ef366" +dependencies = [ + "semver", +] + +[[package]] +name = "semver" +version = "1.0.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b97ed7a9823b74f99c7742f5336af7be5ecd3eeafcb1507d1fa93347b1d589b0" + +[[package]] +name = "serde" +version = "1.0.195" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "63261df402c67811e9ac6def069e4786148c4563f4b50fd4bf30aa370d626b02" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.195" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "46fe8f8603d81ba86327b23a2e9cdf49e1255fb94a4c5f297f6ee0547178ea2c" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "socket2" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7b5fac59a5cb5dd637972e5fca70daf0523c9067fcdc4842f053dae04a18f8e9" +dependencies = [ + "libc", + "windows-sys", +] + +[[package]] +name = "syn" +version = "2.0.48" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0f3531638e407dfc0814761abb7c00a5b54992b849452a0646b7f65c9f770f3f" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "thiserror" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d54378c645627613241d077a3a79db965db602882668f9136ac42af9ecb730ad" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.56" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fa0faa943b50f3db30a20aa7e265dbc66076993efed8463e8de414e5d06d3471" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..0f7c081 --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,18 @@ +[package] +name = "ml-rust" +version = "0.1.0" +edition = "2021" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html + +[dependencies] +bincode = "1.3.3" +futures = "0.3.30" +half = "2.3.1" +monoio = "0.2.1" +ocl = "0.19.6" +thiserror = "1.0.56" + +[dependencies.serde] +version="1.0.195" +features=["derive"] diff --git a/cl/backward.cl b/cl/backward.cl new file mode 100644 index 0000000..e69de29 diff --git a/cl/forward.cl b/cl/forward.cl new file mode 100644 index 0000000..aaa49d1 --- /dev/null +++ b/cl/forward.cl @@ -0,0 +1,15 @@ +__kernel void forward(__global float* activate, __global float* output, __global float* input, __global float* mul, __global float* add, int input_width, int param_width) +{ + int current_node = get_global_id(0); + + float value = 0; + for (int input_node = 0; input_node < input_width; input_node++) + { + float input = input[input_node]; + float factor = mul[input_node * param_width + current_node]; + value += input * factor + add[current_node]; + } + + output[current_node] = value; + activate[current_node] = tanh(value); +} \ No newline at end of file diff --git a/rust-toolchain.toml b/rust-toolchain.toml new file mode 100644 index 0000000..271800c --- /dev/null +++ b/rust-toolchain.toml @@ -0,0 +1,2 @@ +[toolchain] +channel = "nightly" \ No newline at end of file diff --git a/src/layer.rs b/src/layer.rs new file mode 100644 index 0000000..2c71bda --- /dev/null +++ b/src/layer.rs @@ -0,0 +1,138 @@ +use ocl::{ + builders::{BufferBuilder, KernelBuilder}, + MemFlags, SpatialDims, +}; +use serde::{Deserialize, Serialize}; + +use crate::{ + state::{Context, PROGRAM_FORWARD, PROGRAM_BACKWARD}, + Error, +}; + +/// A Layer of NN +pub struct Layer { + /// vector(intermedate) for offset connection + offset_conn: ocl::Buffer, + /// row-base matrix(input,intermedate) for multiply connection + mul_conn: ocl::Buffer, + /// vector(intermedate), the output value + /// be processed by activation function + output: ocl::Buffer, + /// vector(intermedate), the output value + /// should be processed by activation function + pub activate: ocl::Buffer, + /// input width + input: usize, + /// intermedate width + inter: usize, +} + +#[derive(Deserialize, Serialize)] +pub struct LayerDump { + offset_conn: Vec, + mul_conn: Vec, +} + +impl Layer { + /// dump buffers in Layer into bytes + /// + /// FIXME: it's blocking + pub fn dump(self, que: &ocl::Queue) -> Result, Error> { + let mut offset_conn = Vec::with_capacity(self.input * self.inter); + let mut mul_conn = Vec::with_capacity(self.input * self.inter); + // let mut param = Vec::with_capacity(self.inter); + + unsafe { + // FIXME: should wait for queue to finish even error + self.offset_conn + .read(&mut offset_conn) + .queue(que) + .block(false) + .enq()?; + self.mul_conn + .read(&mut mul_conn) + .queue(que) + .block(false) + .enq()?; + // self.param.read(&mut param).queue(que).block(false).enq()?; + } + que.finish()?; + + let dump = LayerDump { + offset_conn, + mul_conn, + // param, + }; + Ok(bincode::serialize(&dump)?) + } + /// create new Layer with random paramter + pub fn new(input: usize, node_count: usize, state: &Context) -> Result { + let inter = node_count; + + // FIXME: should write a kernel to randomize value instead + Ok(Self { + offset_conn: BufferBuilder::new() + .context(&state.context) + .len(inter) + .build()?, + mul_conn: BufferBuilder::new() + .context(&state.context) + .len(input * inter) + .build()?, + output: BufferBuilder::new() + .context(&state.context) + .len(inter) + .build()?, + activate: BufferBuilder::new() + .context(&state.context) + .len(inter) + .build()?, + input, + inter, + }) + } + /// forward pagination + /// + /// FIXME: we should use host memory instead device memory (EG. GPU) + /// + /// MEM_USE_HOST_PTR: use host memory, cache by device memory + pub fn forward(&mut self, state: &Context, activation: &ocl::Buffer) -> Result<(), Error> { + let kernel = KernelBuilder::new() + .queue(state.queue.clone()) + .global_work_size(SpatialDims::One(self.inter)) + .program(&state.program[PROGRAM_FORWARD]) + .arg(&self.activate) + .arg(&self.output) + .arg(activation) + .arg(&self.mul_conn) + .arg(&self.offset_conn) + .arg(self.input) + .arg(self.inter) + .build()?; + + unsafe { + kernel.enq().unwrap(); + } + + Ok(()) + } + /// forward pagination + /// + /// FIXME: we should use host memory instead device memory (EG. GPU) + /// + /// MEM_USE_HOST_PTR: use host memory, cache by device memory + pub fn backward(&mut self, state: &Context) -> Result<(), Error> { + todo!(); + let kernel = KernelBuilder::new() + .queue(state.queue.clone()) + .global_work_size(SpatialDims::One(self.inter)) + .program(&state.program[PROGRAM_BACKWARD]) + .build()?; + + unsafe { + kernel.enq().unwrap(); + } + + Ok(()) + } +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c5eaf23 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,28 @@ +use std::path::Path; + +use crate::state::Layers; + +extern crate ocl; +pub mod layer; +pub mod state; + +static SAVE:&str="save.bin"; + +#[derive(Debug, thiserror::Error)] +pub enum Error { + #[error("`{0}`")] + OpenCL(#[from] ocl::Error), + #[error("`{0}`")] + Bincode(#[from] Box), + #[error("`{0}`")] + IO(#[from] std::io::Error), +} + +#[monoio::main] +async fn main() { + let layers=match Path::new(SAVE).exists(){ + true => Layers::new(), + false => Layers::load_from(SAVE).await.unwrap(), + }; + todo!() +} diff --git a/src/state.rs b/src/state.rs new file mode 100644 index 0000000..88ebf38 --- /dev/null +++ b/src/state.rs @@ -0,0 +1,93 @@ +use std::path::Path; + +use monoio::fs; +use ocl::{builders::BufferBuilder, Device, Program, Queue}; +use serde::{Deserialize, Serialize}; + +use crate::{layer::Layer, Error}; + +/// index of precompiled opencl program +pub const PROGRAM_FORWARD: usize = 0; + +/// index of precompiled opencl program +pub const PROGRAM_BACKWARD: usize = 1; + +pub const LEARN_RATE: f32 = 0.05; + +/// AppState +pub struct Context { + pub context: ocl::Context, + pub device: Device, + pub queue: Queue, + pub program: [Program; 2], +} + +impl Context { + pub fn new() -> Self { + let context = ocl::Context::builder() + .devices(Device::specifier().first()) + .build() + .unwrap(); + let device = context.devices()[0]; + let queue = Queue::new(&context, device, None).unwrap(); + + let forward = Program::builder() + .src(include_str!("../cl/forward.cl")) + .devices(device) + .build(&context) + .unwrap(); + let backward = Program::builder() + .src(include_str!("../cl/backward.cl")) + .devices(device) + .build(&context) + .unwrap(); + + Context { + context, + device, + queue, + program: [forward, backward], + } + } +} + +#[derive(Serialize, Deserialize)] +struct LayersDump(Vec); + +pub struct Layers(Vec); + +impl Layers { + /// load from file + pub async fn load_from(path: impl AsRef) -> Result { + let file = fs::File::open(path.as_ref()).await?; + + todo!() + } + /// create layer with random param + pub fn new() -> Self { + todo!() + } + /// train + pub fn train(&mut self, ctx: &Context, data: Vec) -> Result<(), Error> { + let input = unsafe { + BufferBuilder::new() + .context(&ctx.context) + .use_host_slice(&data) + .len(data.len()) + .build()? + }; + let mut input=&input; + + for layer in &mut self.0{ + layer.forward(ctx, input)?; + input=&layer.activate; + } + + for layer in self.0.iter_mut().rev(){ + layer.backward(ctx)?; + } + + drop(data); + Ok(()) + } +}