diff --git a/.gitignore b/.gitignore index ce180a2..d41e835 100644 --- a/.gitignore +++ b/.gitignore @@ -1,3 +1,7 @@ +# model +/*.save +/*.bin + *.env *.swp *.pyproj diff --git a/config.toml b/config.toml new file mode 100644 index 0000000..7cddb77 --- /dev/null +++ b/config.toml @@ -0,0 +1,7 @@ +exploration_rate = 1224 +update_frequency = 60 +batch_size = 48 +replay_size = 300 +learning_rate = 0.01 +gamma = 0.97 +train = false diff --git a/justfile b/justfile new file mode 100644 index 0000000..3809fed --- /dev/null +++ b/justfile @@ -0,0 +1,9 @@ +test: + python -m mlgame -f 30 -i ./ml/ml_play_manual_1P.py -i ./ml/ml_play_manual_2P.py . --level 8 --game_times 3 +build: + cd pyr && cargo build --release +train level: +run level: + python -m mlgame -f 400 -i ./ml/ml_play_pyr_test.py -i ./ml/ml_play_pyr_test.py . --sound off --level {{level}} --game_times 3 +clean: + rm -r model.bin diff --git a/pyr/Cargo.lock b/pyr/Cargo.lock new file mode 100644 index 0000000..23b3c27 --- /dev/null +++ b/pyr/Cargo.lock @@ -0,0 +1,1379 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "async-channel" +version = "2.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "136d4d23bcc79e27423727b36823d86233aad06dfea531837b038394d11e9928" +dependencies = [ + "concurrent-queue", + "event-listener 5.3.0", + "event-listener-strategy 0.5.1", + "futures-core", + "pin-project-lite", +] + +[[package]] +name = "async-executor" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b10202063978b3351199d68f8b22c4e47e4b1b822f8d43fd862d5ea8c006b29a" +dependencies = [ + "async-task", + "concurrent-queue", + "fastrand", + "futures-lite", + "slab", +] + +[[package]] +name = "async-fs" +version = "2.1.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bc19683171f287921f2405677dd2ed2549c3b3bda697a563ebc3a121ace2aba1" +dependencies = [ + "async-lock", + "blocking", + "futures-lite", +] + +[[package]] +name = "async-io" +version = "2.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dcccb0f599cfa2f8ace422d3555572f47424da5648a4382a9dd0310ff8210884" +dependencies = [ + "async-lock", + "cfg-if", + "concurrent-queue", + "futures-io", + "futures-lite", + "parking", + "polling", + "rustix", + "slab", + "tracing", + "windows-sys 0.52.0", +] + +[[package]] +name = "async-lock" +version = "3.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d034b430882f8381900d3fe6f0aaa3ad94f2cb4ac519b429692a1bc2dda4ae7b" +dependencies = [ + "event-listener 4.0.3", + "event-listener-strategy 0.4.0", + "pin-project-lite", +] + +[[package]] +name = "async-net" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b948000fad4873c1c9339d60f2623323a0cfd3816e5181033c6a5cb68b2accf7" +dependencies = [ + "async-io", + "blocking", + "futures-lite", +] + +[[package]] +name = "async-process" +version = "2.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a53fc6301894e04a92cb2584fedde80cb25ba8e02d9dc39d4a87d036e22f397d" +dependencies = [ + "async-channel", + "async-io", + "async-lock", + "async-signal", + "async-task", + "blocking", + "cfg-if", + "event-listener 5.3.0", + "futures-lite", + "rustix", + "tracing", + "windows-sys 0.52.0", +] + +[[package]] +name = "async-signal" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "afe66191c335039c7bb78f99dc7520b0cbb166b3a1cb33a03f53d8a1c6f2afda" +dependencies = [ + "async-io", + "async-lock", + "atomic-waker", + "cfg-if", + "futures-core", + "futures-io", + "rustix", + "signal-hook-registry", + "slab", + "windows-sys 0.52.0", +] + +[[package]] +name = "async-task" +version = "4.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb36e985947064623dbd357f727af08ffd077f93d696782f3c56365fa2e2799" + +[[package]] +name = "atomic-waker" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1505bd5d3d116872e7271a6d4e16d81d0c8570876c8de68093a09ac269d8aac0" + +[[package]] +name = "autocfg" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f1fdabc7756949593fe60f30ec81974b613357de856987752631dea1e3394c80" + +[[package]] +name = "bitflags" +version = "1.3.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" + +[[package]] +name = "bitflags" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cf4b9d6a944f767f8e5e0db018570623c85f3d925ac718db4e06d0187adb21c1" + +[[package]] +name = "blocking" +version = "1.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6a37913e8dc4ddcc604f0c6d3bf2887c995153af3611de9e23c352b44c1b9118" +dependencies = [ + "async-channel", + "async-lock", + "async-task", + "fastrand", + "futures-io", + "futures-lite", + "piper", + "tracing", +] + +[[package]] +name = "bytemuck" +version = "1.15.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5d6d68c57235a3a081186990eca2867354726650f42f7516ca50c28d6281fd15" +dependencies = [ + "bytemuck_derive", +] + +[[package]] +name = "bytemuck_derive" +version = "1.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4da9a32f3fed317401fa3c862968128267c3106685286e15d5aaa3d7389c2f60" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "byteorder" +version = "1.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1fd0f2584146f6f2ef48085050886acf353beff7305ebd1ae69500e27c67f64b" + +[[package]] +name = "candle-core" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f1b20174c1707e20f4cb364a355b449803c03e9b0c9193324623cf9787a4e00" +dependencies = [ + "byteorder", + "gemm", + "half", + "memmap2", + "num-traits", + "num_cpus", + "rand", + "rand_distr", + "rayon", + "safetensors", + "thiserror", + "yoke", + "zip", +] + +[[package]] +name = "candle-nn" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "66a27533c8edfc915a6459f9850641ef523a829fa1a181c670766c1f752d873a" +dependencies = [ + "candle-core", + "half", + "num-traits", + "rayon", + "safetensors", + "serde", + "thiserror", +] + +[[package]] +name = "cfg-if" +version = "1.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" + +[[package]] +name = "colored" +version = "2.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "cbf2150cce219b664a8a70df7a1f933836724b503f8a413af9365b4dcc4d90b8" +dependencies = [ + "lazy_static", + "windows-sys 0.48.0", +] + +[[package]] +name = "concurrent-queue" +version = "2.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d16048cd947b08fa32c24458a22f5dc5e835264f689f4f5653210c69fd107363" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crc32fast" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b3855a8a784b474f333699ef2bbca9db2c4a1f6d9088a90a2d25b1eb53111eaa" +dependencies = [ + "cfg-if", +] + +[[package]] +name = "crossbeam-deque" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "613f8cc01fe9cf1a3eb3d7f488fd2fa8388403e97039e2f73692932e291a770d" +dependencies = [ + "crossbeam-epoch", + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-epoch" +version = "0.9.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b82ac4a3c2ca9c3460964f020e1402edd5753411d7737aa39c3714ad1b5420e" +dependencies = [ + "crossbeam-utils", +] + +[[package]] +name = "crossbeam-utils" +version = "0.8.19" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "248e3bacc7dc6baa3b21e405ee045c3047101a49145e7e9eca583ab4c2ca5345" + +[[package]] +name = "crunchy" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" + +[[package]] +name = "deranged" +version = "0.3.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" +dependencies = [ + "powerfmt", +] + +[[package]] +name = "dyn-stack" +version = "0.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "56e53799688f5632f364f8fb387488dd05db9fe45db7011be066fc20e7027f8b" +dependencies = [ + "bytemuck", + "reborrow", +] + +[[package]] +name = "either" +version = "1.11.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a47c1c47d2f5964e29c61246e81db715514cd532db6b5116a25ea3c03d6780a2" + +[[package]] +name = "enum-as-inner" +version = "0.6.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5ffccbb6966c05b32ef8fbac435df276c4ae4d3dc55a8cd0eb9745e6c12f546a" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "equivalent" +version = "1.0.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" + +[[package]] +name = "errno" +version = "0.3.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a258e46cdc063eb8519c00b9fc845fc47bcfca4130e2f08e88665ceda8474245" +dependencies = [ + "libc", + "windows-sys 0.52.0", +] + +[[package]] +name = "event-listener" +version = "4.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "67b215c49b2b248c855fb73579eb1f4f26c38ffdc12973e20e07b91d78d5646e" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6d9944b8ca13534cdfb2800775f8dd4902ff3fc75a50101466decadfdf322a24" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "958e4d70b6d5e81971bebec42271ec641e7ff4e170a6fa605f2b8a8b65cb97d3" +dependencies = [ + "event-listener 4.0.3", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "332f51cb23d20b0de8458b86580878211da09bcd4503cb579c225b3d124cabb3" +dependencies = [ + "event-listener 5.3.0", + "pin-project-lite", +] + +[[package]] +name = "fastrand" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "658bd65b1cf4c852a3cc96f18a8ce7b5640f6b703f905c7d74532294c2a63984" + +[[package]] +name = "futures-core" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dfc6580bb841c5a68e9ef15c77ccc837b40a7504914d52e47b8b0e9bbda25a1d" + +[[package]] +name = "futures-io" +version = "0.3.30" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a44623e20b9681a318efdd71c299b6b222ed6f231972bfe2f224ebad6311f0c1" + +[[package]] +name = "futures-lite" +version = "2.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "52527eb5074e35e9339c6b4e8d12600c7128b68fb25dcb9fa9dec18f7c25f3a5" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + +[[package]] +name = "gemm" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6ab24cc62135b40090e31a76a9b2766a501979f3070fa27f689c27ec04377d32" +dependencies = [ + "dyn-stack", + "gemm-c32", + "gemm-c64", + "gemm-common", + "gemm-f16", + "gemm-f32", + "gemm-f64", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b9c030d0b983d1e34a546b86e08f600c11696fde16199f971cd46c12e67512c0" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-c64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fbb5f2e79fefb9693d18e1066a557b4546cd334b226beadc68b11a8f9431852a" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-common" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a2e7ea062c987abcd8db95db917b4ffb4ecdfd0668471d8dc54734fdff2354e8" +dependencies = [ + "bytemuck", + "dyn-stack", + "half", + "num-complex", + "num-traits", + "once_cell", + "paste", + "pulp", + "raw-cpuid", + "rayon", + "seq-macro", + "sysctl", +] + +[[package]] +name = "gemm-f16" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7ca4c06b9b11952071d317604acb332e924e817bd891bec8dfb494168c7cedd4" +dependencies = [ + "dyn-stack", + "gemm-common", + "gemm-f32", + "half", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "rayon", + "seq-macro", +] + +[[package]] +name = "gemm-f32" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9a69f51aaefbd9cf12d18faf273d3e982d9d711f60775645ed5c8047b4ae113" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "gemm-f64" +version = "0.17.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "aa397a48544fadf0b81ec8741e5c0fba0043008113f71f2034def1935645d2b0" +dependencies = [ + "dyn-stack", + "gemm-common", + "num-complex", + "num-traits", + "paste", + "raw-cpuid", + "seq-macro", +] + +[[package]] +name = "getrandom" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "94b22e06ecb0110981051723910cbf0b5f5e09a2062dd7663334ee79a9d1286c" +dependencies = [ + "cfg-if", + "libc", + "wasi", +] + +[[package]] +name = "half" +version = "2.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6dd08c532ae367adf81c312a4580bc67f1d0fe8bc9c460520283f4c0ff277888" +dependencies = [ + "bytemuck", + "cfg-if", + "crunchy", + "num-traits", + "rand", + "rand_distr", +] + +[[package]] +name = "hashbrown" +version = "0.14.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "290f1a1d9242c78d09ce40a5e87e7554ee637af1351968159f4952f028f75604" + +[[package]] +name = "heck" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" + +[[package]] +name = "hermit-abi" +version = "0.3.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" + +[[package]] +name = "indexmap" +version = "2.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "168fb715dda47215e360912c096649d23d58bf392ac62f73919e831745e40f26" +dependencies = [ + "equivalent", + "hashbrown", +] + +[[package]] +name = "itoa" +version = "1.0.11" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "49f1f14873335454500d59611f1cf4a4b0f786f9ac11f4312a78e4cf2566695b" + +[[package]] +name = "lazy_static" +version = "1.4.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e2abad23fbc42b3700f2f279844dc832adb2b2eb069b2df918f455c4e18cc646" + +[[package]] +name = "libc" +version = "0.2.153" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c198f91728a82281a64e1f4f9eeb25d82cb32a5de251c6bd1b5154d63a8e7bd" + +[[package]] +name = "libm" +version = "0.2.8" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ec2a862134d2a7d32d7983ddcdd1c4923530833c9f2ea1a44fc5fa473989058" + +[[package]] +name = "linux-raw-sys" +version = "0.4.13" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "01cda141df6706de531b6c46c3a33ecca755538219bd484262fa09410c13539c" + +[[package]] +name = "log" +version = "0.4.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "90ed8c1e510134f979dbc4f070f87d4313098b704861a105fe34231c70a3901c" + +[[package]] +name = "memchr" +version = "2.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c8640c5d730cb13ebd907d8d04b52f55ac9a2eec55b440c8892f40d56c76c1d" + +[[package]] +name = "memmap2" +version = "0.9.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe751422e4a8caa417e13c3ea66452215d7d63e19e604f4980461212f3ae1322" +dependencies = [ + "libc", + "stable_deref_trait", +] + +[[package]] +name = "num-complex" +version = "0.4.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "23c6602fda94a57c990fe0df199a035d83576b496aa29f4e634a8ac6004e68a6" +dependencies = [ + "bytemuck", + "num-traits", +] + +[[package]] +name = "num-conv" +version = "0.1.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" + +[[package]] +name = "num-traits" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "da0df0e5185db44f69b44f26786fe401b6c293d1907744beaa7fa62b2e5a517a" +dependencies = [ + "autocfg", + "libm", +] + +[[package]] +name = "num_cpus" +version = "1.16.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4161fcb6d602d4d2081af7c3a45852d875a03dd337a6bfdd6e06407b61342a43" +dependencies = [ + "hermit-abi", + "libc", +] + +[[package]] +name = "num_threads" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5c7398b9c8b70908f6371f47ed36737907c87c52af34c268fed0bf0ceb92ead9" +dependencies = [ + "libc", +] + +[[package]] +name = "once_cell" +version = "1.19.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3fdb12b2476b595f9358c5161aa467c2438859caa136dec86c26fdd2efe17b92" + +[[package]] +name = "parking" +version = "2.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bb813b8af86854136c6922af0598d719255ecb2179515e6e7730d468f05c9cae" + +[[package]] +name = "paste" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" + +[[package]] +name = "pin-project-lite" +version = "0.2.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bda66fc9667c18cb2758a2ac84d1167245054bcf85d5d1aaa6923f45801bdd02" + +[[package]] +name = "piper" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "668d31b1c4eba19242f2088b2bf3316b82ca31082a8335764db4e083db7485d4" +dependencies = [ + "atomic-waker", + "fastrand", + "futures-io", +] + +[[package]] +name = "polling" +version = "3.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "645493cf344456ef24219d02a768cf1fb92ddf8c92161679ae3d91b91a637be3" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix", + "tracing", + "windows-sys 0.52.0", +] + +[[package]] +name = "powerfmt" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" + +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "proc-macro2" +version = "1.0.81" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3d1597b0c024618f09a9c3b8655b7e430397a36d23fdafec26d6965e9eec3eba" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "pulp" +version = "0.18.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e14989307e408d9f4245d4fda09a7b144a08114ba124e26cab60ab83dc98db10" +dependencies = [ + "bytemuck", + "libm", + "num-complex", + "reborrow", +] + +[[package]] +name = "pyr" +version = "0.1.0" +dependencies = [ + "candle-core", + "candle-nn", + "lazy_static", + "log", + "rand", + "serde", + "simple_logger", + "smol", + "toml", +] + +[[package]] +name = "quote" +version = "1.0.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0fa76aaf39101c457836aec0ce2316dbdc3ab723cdda1c6bd4e6ad4208acaca7" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "rand" +version = "0.8.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "34af8d1a0e25924bc5b7c43c079c942339d8f0a8b57c39049bef581b46327404" +dependencies = [ + "libc", + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6c10a63a0fa32252be49d21e7709d4d4baf8d231c2dbce1eaa8141b9b127d88" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.6.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec0be4795e2f6a28069bec0b5ff3e2ac9bafc99e6a9a7dc3547996c5c816922c" +dependencies = [ + "getrandom", +] + +[[package]] +name = "rand_distr" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32cb0b9bc82b0a0876c2dd994a7e7a2683d3e7390ca40e6886785ef0c7e3ee31" +dependencies = [ + "num-traits", + "rand", +] + +[[package]] +name = "raw-cpuid" +version = "10.7.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6c297679cb867470fa8c9f67dbba74a78d78e3e98d7cf2b08d6d71540f797332" +dependencies = [ + "bitflags 1.3.2", +] + +[[package]] +name = "rayon" +version = "1.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b418a60154510ca1a002a752ca9714984e21e4241e804d32555251faf8b78ffa" +dependencies = [ + "either", + "rayon-core", +] + +[[package]] +name = "rayon-core" +version = "1.12.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1465873a3dfdaa8ae7cb14b4383657caab0b3e8a0aa9ae8e04b044854c8dfce2" +dependencies = [ + "crossbeam-deque", + "crossbeam-utils", +] + +[[package]] +name = "reborrow" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "03251193000f4bd3b042892be858ee50e8b3719f2b08e5833ac4353724632430" + +[[package]] +name = "rustix" +version = "0.38.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "70dc5ec042f7a43c4a73241207cecc9873a06d45debb38b329f8541d85c2730f" +dependencies = [ + "bitflags 2.5.0", + "errno", + "libc", + "linux-raw-sys", + "windows-sys 0.52.0", +] + +[[package]] +name = "ryu" +version = "1.0.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e86697c916019a8588c99b5fac3cead74ec0b4b819707a682fd4d23fa0ce1ba1" + +[[package]] +name = "safetensors" +version = "0.4.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8ced76b22c7fba1162f11a5a75d9d8405264b467a07ae0c9c29be119b9297db9" +dependencies = [ + "serde", + "serde_json", +] + +[[package]] +name = "same-file" +version = "1.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" +dependencies = [ + "winapi-util", +] + +[[package]] +name = "seq-macro" +version = "0.3.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a3f0bf26fd526d2a95683cd0f87bf103b8539e2ca1ef48ce002d67aad59aa0b4" + +[[package]] +name = "serde" +version = "1.0.198" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9846a40c979031340571da2545a4e5b7c4163bdae79b301d5f86d03979451fcc" +dependencies = [ + "serde_derive", +] + +[[package]] +name = "serde_derive" +version = "1.0.198" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e88edab869b01783ba905e7d0153f9fc1a6505a96e4ad3018011eedb838566d9" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "serde_json" +version = "1.0.116" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3e17db7126d17feb94eb3fad46bf1a96b034e8aacbc2e775fe81505f8b0b2813" +dependencies = [ + "itoa", + "ryu", + "serde", +] + +[[package]] +name = "serde_spanned" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "eb3622f419d1296904700073ea6cc23ad690adbd66f13ea683df73298736f0c1" +dependencies = [ + "serde", +] + +[[package]] +name = "signal-hook-registry" +version = "1.4.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a9e9e0b4211b72e7b8b6e85c807d36c212bdb33ea8587f7569562a84df5465b1" +dependencies = [ + "libc", +] + +[[package]] +name = "simple_logger" +version = "4.3.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e7e46c8c90251d47d08b28b8a419ffb4aede0f87c2eea95e17d1d5bacbf3ef1" +dependencies = [ + "colored", + "log", + "time", + "windows-sys 0.48.0", +] + +[[package]] +name = "slab" +version = "0.4.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" +dependencies = [ + "autocfg", +] + +[[package]] +name = "smol" +version = "2.0.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e635339259e51ef85ac7aa29a1cd991b957047507288697a690e80ab97d07cad" +dependencies = [ + "async-channel", + "async-executor", + "async-fs", + "async-io", + "async-lock", + "async-net", + "async-process", + "blocking", + "futures-lite", +] + +[[package]] +name = "stable_deref_trait" +version = "1.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" + +[[package]] +name = "syn" +version = "2.0.60" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "909518bc7b1c9b779f1bbf07f2929d35af9f0f37e47c6e9ef7f9dddc1e1821f3" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "synstructure" +version = "0.13.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "sysctl" +version = "0.5.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ec7dddc5f0fee506baf8b9fdb989e242f17e4b11c61dfbb0635b705217199eea" +dependencies = [ + "bitflags 2.5.0", + "byteorder", + "enum-as-inner", + "libc", + "thiserror", + "walkdir", +] + +[[package]] +name = "thiserror" +version = "1.0.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0126ad08bff79f29fc3ae6a55cc72352056dfff61e3ff8bb7129476d44b23aa" +dependencies = [ + "thiserror-impl", +] + +[[package]] +name = "thiserror-impl" +version = "1.0.59" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d1cd413b5d558b4c5bf3680e324a6fa5014e7b7c067a51e69dbdf47eb7148b66" +dependencies = [ + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "time" +version = "0.3.36" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" +dependencies = [ + "deranged", + "itoa", + "libc", + "num-conv", + "num_threads", + "powerfmt", + "serde", + "time-core", + "time-macros", +] + +[[package]] +name = "time-core" +version = "0.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" + +[[package]] +name = "time-macros" +version = "0.2.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" +dependencies = [ + "num-conv", + "time-core", +] + +[[package]] +name = "toml" +version = "0.8.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e9dd1545e8208b4a5af1aa9bbd0b4cf7e9ea08fabc5d0a5c67fcaafa17433aa3" +dependencies = [ + "serde", + "serde_spanned", + "toml_datetime", + "toml_edit", +] + +[[package]] +name = "toml_datetime" +version = "0.6.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3550f4e9685620ac18a50ed434eb3aec30db8ba93b0287467bca5826ea25baf1" +dependencies = [ + "serde", +] + +[[package]] +name = "toml_edit" +version = "0.22.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3328d4f68a705b2a4498da1d580585d39a6510f98318a2cec3018a7ec61ddef" +dependencies = [ + "indexmap", + "serde", + "serde_spanned", + "toml_datetime", + "winnow", +] + +[[package]] +name = "tracing" +version = "0.1.40" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c3523ab5a71916ccf420eebdf5521fcef02141234bbc0b8a49f2fdc4544364ef" +dependencies = [ + "pin-project-lite", + "tracing-core", +] + +[[package]] +name = "tracing-core" +version = "0.1.32" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c06d3da6113f116aaee68e4d601191614c9053067f9ab7f6edbcb161237daa54" + +[[package]] +name = "unicode-ident" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" + +[[package]] +name = "walkdir" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "29790946404f91d9c5d06f9874efddea1dc06c5efe94541a7d6863108e3a5e4b" +dependencies = [ + "same-file", + "winapi-util", +] + +[[package]] +name = "wasi" +version = "0.11.0+wasi-snapshot-preview1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" + +[[package]] +name = "winapi-util" +version = "0.1.7" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "134306a13c5647ad6453e8deaec55d3a44d6021970129e6188735e74bf546697" +dependencies = [ + "windows-sys 0.52.0", +] + +[[package]] +name = "windows-sys" +version = "0.48.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" +dependencies = [ + "windows-targets 0.48.5", +] + +[[package]] +name = "windows-sys" +version = "0.52.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" +dependencies = [ + "windows-targets 0.52.5", +] + +[[package]] +name = "windows-targets" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" +dependencies = [ + "windows_aarch64_gnullvm 0.48.5", + "windows_aarch64_msvc 0.48.5", + "windows_i686_gnu 0.48.5", + "windows_i686_msvc 0.48.5", + "windows_x86_64_gnu 0.48.5", + "windows_x86_64_gnullvm 0.48.5", + "windows_x86_64_msvc 0.48.5", +] + +[[package]] +name = "windows-targets" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6f0713a46559409d202e70e28227288446bf7841d3211583a4b53e3f6d96e7eb" +dependencies = [ + "windows_aarch64_gnullvm 0.52.5", + "windows_aarch64_msvc 0.52.5", + "windows_i686_gnu 0.52.5", + "windows_i686_gnullvm", + "windows_i686_msvc 0.52.5", + "windows_x86_64_gnu 0.52.5", + "windows_x86_64_gnullvm 0.52.5", + "windows_x86_64_msvc 0.52.5", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7088eed71e8b8dda258ecc8bac5fb1153c5cffaf2578fc8ff5d61e23578d3263" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9985fd1504e250c615ca5f281c3f7a6da76213ebd5ccc9561496568a2752afb6" + +[[package]] +name = "windows_i686_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "88ba073cf16d5372720ec942a8ccbf61626074c6d4dd2e745299726ce8b89670" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "87f4261229030a858f36b459e748ae97545d6f1ec60e5e0d6a3d32e0dc232ee9" + +[[package]] +name = "windows_i686_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "db3c2bf3d13d5b658be73463284eaf12830ac9a26a90c717b7f771dfe97487bf" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4e4246f76bdeff09eb48875a0fd3e2af6aada79d409d33011886d3e1581517d9" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "852298e482cd67c356ddd9570386e2862b5673c85bd5f88df9ab6802b334c596" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.48.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "bec47e5bfd1bff0eeaf6d8b485cc1074891a197ab4225d504cb7a1ab88b02bf0" + +[[package]] +name = "winnow" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0c976aaaa0e1f90dbb21e9587cdaf1d9679a1cde8875c0d6bd83ab96a208352" +dependencies = [ + "memchr", +] + +[[package]] +name = "yoke" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "65e71b2e4f287f467794c671e2b8f8a5f3716b3c829079a1c44740148eff07e4" +dependencies = [ + "serde", + "stable_deref_trait", + "yoke-derive", + "zerofrom", +] + +[[package]] +name = "yoke-derive" +version = "0.7.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9e6936f0cce458098a201c245a11bef556c6a0181129c7034d10d76d1ec3a2b8" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zerofrom" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "655b0814c5c0b19ade497851070c640773304939a6c0fd5f5fb43da0696d05b7" +dependencies = [ + "zerofrom-derive", +] + +[[package]] +name = "zerofrom-derive" +version = "0.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e6a647510471d372f2e6c2e6b7219e44d8c574d24fdc11c610a61455782f18c3" +dependencies = [ + "proc-macro2", + "quote", + "syn", + "synstructure", +] + +[[package]] +name = "zip" +version = "0.6.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "760394e246e4c28189f19d488c058bf16f564016aefac5d32bb1f3b51d5e9261" +dependencies = [ + "byteorder", + "crc32fast", + "crossbeam-utils", +] diff --git a/pyr/Cargo.toml b/pyr/Cargo.toml new file mode 100644 index 0000000..218f251 --- /dev/null +++ b/pyr/Cargo.toml @@ -0,0 +1,24 @@ +[package] +name = "pyr" +version = "0.1.0" +edition = "2021" + +[lib] +name = "pyr" +crate-type = ["cdylib"] + +[profile.release] +strip = true +opt-level = 2 +# lto = true + +[dependencies] +smol = "2.0.0" +log = "0.4.21" +simple_logger = "4.3.3" +lazy_static = "1.4.0" +candle-nn = "0.4.1" +candle-core = "0.4.1" +rand = "0.8.5" +toml = "0.8.12" +serde = {version = "1.0.198", features = ["derive"]} diff --git a/pyr/src/app/action.rs b/pyr/src/app/action.rs new file mode 100644 index 0000000..c2b368c --- /dev/null +++ b/pyr/src/app/action.rs @@ -0,0 +1,22 @@ +use crate::Direction; + +#[derive(PartialEq, Eq, Hash, Clone)] +pub enum AIAction { + Up, + Down, + Left, + Right, + None, +} + +impl From for Direction { + fn from(value: AIAction) -> Self { + match value { + AIAction::Up => Direction::Up, + AIAction::Down => Direction::Down, + AIAction::Left => Direction::Left, + AIAction::Right => Direction::Right, + AIAction::None => Direction::None, + } + } +} diff --git a/pyr/src/app/agent.rs b/pyr/src/app/agent.rs new file mode 100644 index 0000000..8c0c88f --- /dev/null +++ b/pyr/src/app/agent.rs @@ -0,0 +1,202 @@ +use std::collections::VecDeque; +use std::path::Path; + +use rand::distributions::Uniform; +use rand::{thread_rng, Rng}; + +use candle_core::{DType, Device, Module, Tensor}; +use candle_nn::{linear, seq, Activation, AdamW, Optimizer, Sequential, VarBuilder, VarMap}; + +use crate::CONFIG; + +use super::state::OBSERVATION_SPACE; +use super::{action::AIAction, huber::huber_loss, state::AIState}; + +const DEVICE: Device = Device::Cpu; + +const ACTION_SPACE: usize = 5; + +pub struct AIAgent { + var_map: VarMap, + model: Sequential, + optimizer: AdamW, + memory: VecDeque<(Tensor, u32, Tensor, f64)>, + old_state: Option, + step: usize, + accumulate_rewards: f64, +} + +impl AIAgent { + pub async fn new() -> Self { + let mut var_map = VarMap::new(); + if Path::new("model.bin").exists() { + var_map.load("model.bin").unwrap(); + } + let vb = VarBuilder::from_varmap(&var_map, DType::F32, &DEVICE); + let model = seq() + .add(linear(OBSERVATION_SPACE, 60, vb.pp("linear_in")).unwrap()) + .add(Activation::LeakyRelu(0.01)) + .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap()) + .add(Activation::LeakyRelu(0.01)) + .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap()) + .add(Activation::LeakyRelu(0.01)) + .add(linear(48, ACTION_SPACE, vb.pp("linear_out")).unwrap()) + .add(Activation::LeakyRelu(0.01)); + + let optimizer = AdamW::new_lr(var_map.all_vars(), CONFIG.learning_rate).unwrap(); + + Self { + var_map, + model, + optimizer, + memory: VecDeque::new(), + old_state: None, + step: 0, + accumulate_rewards: 0.0, + } + } + fn get_reward(&self, new_state: &AIState) -> f64 { + let old_state = self.old_state.as_ref().unwrap(); + let new_positive_distance = new_state + .get_postivie_food() + .map(|food| food.x + food.y) + .unwrap_or(0.0); + let old_positive_distance = old_state + .get_postivie_food() + .map(|food| food.x + food.y) + .unwrap_or(0.0); + let new_negative_distance = new_state + .get_negative_food() + .map(|food| food.x + food.y) + .unwrap_or(0.0); + let old_negative_distance = old_state + .get_negative_food() + .map(|food| food.x + food.y) + .unwrap_or(0.0); + + return (old_positive_distance - new_positive_distance) as f64 + + (new_negative_distance - old_negative_distance) as f64 + + 100.0*(new_state.player.score - old_state.player.score) as f64; + } + pub fn tick(&mut self, state: AIState) -> AIAction { + self.step += 1; + if self.old_state.is_none() { + self.old_state = Some(state); + return AIAction::None; + } + let old_state = self.old_state.as_ref().unwrap(); + + let action: u32 = match thread_rng().gen_ratio(CONFIG.exploration_rate, 4096) { + true if CONFIG.train => thread_rng().gen_range(0..(ACTION_SPACE as u32)), + _ => self + .model + .forward(&old_state.into_tensor()) + .unwrap() + .squeeze(0) + .unwrap() + .argmax(0) + .unwrap() + .to_scalar() + .unwrap(), + }; + + if CONFIG.train { + let reward = self.get_reward(&state); + self.accumulate_rewards += reward; + + self.memory.push_front(( + self.old_state + .as_ref() + .unwrap() + .into_tensor() + .squeeze(0) + .unwrap(), + action, + state.into_tensor().squeeze(0).unwrap(), + reward, + )); + self.memory.truncate(CONFIG.replay_size); + if self.step % CONFIG.update_frequency == 0 && self.memory.len() > CONFIG.batch_size { + self.train(); + } + } + + self.old_state = Some(state); + + match action { + 0 => AIAction::None, + 1 => AIAction::Up, + 2 => AIAction::Left, + 3 => AIAction::Right, + _ => AIAction::Down, + } + } + fn train(&mut self) { + // Sample randomly from the memory. + let batch = thread_rng() + .sample_iter(Uniform::from(0..self.memory.len())) + .take(CONFIG.batch_size) + .map(|i| self.memory.get(i).unwrap().clone()) + .collect::>(); + + // Group all the samples together into tensors with the appropriate shape. + let states: Vec<_> = batch.iter().map(|e| e.0.clone()).collect(); + let states = Tensor::stack(&states, 0).unwrap(); + + let actions = batch.iter().map(|e| e.1); + let actions = Tensor::from_iter(actions, &DEVICE) + .unwrap() + .unsqueeze(1) + .unwrap(); + + let next_states: Vec<_> = batch.iter().map(|e| e.2.clone()).collect(); + let next_states = Tensor::stack(&next_states, 0).unwrap(); + + let rewards = batch.iter().map(|e| e.3 as f32); + let rewards = Tensor::from_iter(rewards, &DEVICE) + .unwrap() + .unsqueeze(1) + .unwrap(); + + let non_final_mask = batch.iter().map(|_| true as u8 as f32); + let non_final_mask = Tensor::from_iter(non_final_mask, &DEVICE) + .unwrap() + .unsqueeze(1) + .unwrap(); + + // Get the estimated rewards for the actions that where taken at each step. + let estimated_rewards = self.model.forward(&states).unwrap(); + let x = estimated_rewards.gather(&actions, 1).unwrap(); + + // Get the maximum expected rewards for the next state, apply them a discount rate + // GAMMA and add them to the rewards that were actually gathered on the current state. + // If the next state is a terminal state, just omit maximum estimated + // rewards for that state. + let expected_rewards = self.model.forward(&next_states).unwrap().detach(); + let y = expected_rewards.max_keepdim(1).unwrap(); + let y = (y * CONFIG.gamma * non_final_mask + rewards).unwrap(); + + // Compare the estimated rewards with the maximum expected rewards and + // perform the backward step. + let loss = huber_loss(1.0_f32)(&x, &y); + log::trace!("loss: {:?}", loss); + self.optimizer + .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap()) + .unwrap(); + } + pub fn check_point(&mut self) { + self.memory.clear(); + if CONFIG.train { + self.var_map.save("model.bin").unwrap(); + log::info!("model.bin saved!"); + } + } +} + +// impl Drop for AIAgent { +// fn drop(&mut self) { +// self.var_map.save("model.bin").unwrap(); +// log::info!("model.bin saved!"); +// log::info!("Rewards {}", self.accumulate_rewards as i64); +// } +// } diff --git a/pyr/src/app/huber.rs b/pyr/src/app/huber.rs new file mode 100644 index 0000000..7ab9d61 --- /dev/null +++ b/pyr/src/app/huber.rs @@ -0,0 +1,32 @@ +use candle_core::{Tensor, WithDType}; + +pub trait Half +where + Self: WithDType + Copy, +{ + const HALF: Self; +} + +impl Half for f64 { + const HALF: f64 = 0.5; +} +impl Half for f32 { + const HALF: f32 = 0.5; +} + +pub fn huber_loss(threshold: D) -> impl Fn(&Tensor, &Tensor) -> D { + move |x: &Tensor, y: &Tensor| { + let diff = (x - y).unwrap(); + let diff_scaler = diff + .abs() + .unwrap() + .sum_all() + .unwrap() + .to_scalar::() + .unwrap(); + match diff_scaler < threshold { + true => ::HALF * diff_scaler, + false => threshold * (diff_scaler - ::HALF * threshold), + } + } +} diff --git a/pyr/src/app/mod.rs b/pyr/src/app/mod.rs new file mode 100644 index 0000000..8fc5f66 --- /dev/null +++ b/pyr/src/app/mod.rs @@ -0,0 +1,40 @@ +mod action; +mod agent; +mod huber; +mod state; + +use smol::block_on; + +use crate::data::prelude::*; + +use self::agent::AIAgent; + +pub struct TickState { + pub frame: u64, + pub player: Player, + pub opponent: Opponent, + pub foods: Vec, +} + +struct AppState {} + +pub struct App { + state: AppState, + agent: AIAgent, +} + +impl App { + pub fn new() -> Self { + let agent = block_on(AIAgent::new()); + Self { + state: AppState {}, + agent, + } + } + pub fn run(&mut self, tick: TickState) -> Direction { + self.agent.tick(tick.into()).into() + } + pub fn check_point(&mut self) { + self.agent.check_point(); + } +} diff --git a/pyr/src/app/state.rs b/pyr/src/app/state.rs new file mode 100644 index 0000000..5cf98d4 --- /dev/null +++ b/pyr/src/app/state.rs @@ -0,0 +1,109 @@ +use candle_core::{Device, Tensor}; + +use crate::{Food, Opponent, Player}; + +use super::TickState; + +pub const OBSERVATION_SPACE: usize = 14; + +#[derive(Clone)] +pub struct AIState { + pub frame: u64, + pub player: Player, + pub opponent: Opponent, + pub foods: Vec, +} + +impl From for AIState { + fn from(value: TickState) -> Self { + Self { + player: value.player, + opponent: value.opponent, + foods: value.foods, + frame: value.frame, + } + } +} + +fn food_distance<'a>(player: &'a Player) -> impl FnMut(&&Food) -> i32 + 'a { + move |food: &&Food| { + let dx = player.x - food.x; + let dy = player.y - food.y; + ((dx + dy) * 100.0) as i32 + } +} +impl AIState { + pub fn get_postivie_food(&self) -> Option<&Food> { + self.foods + .iter() + .filter(|x| x.score.is_sign_positive()) + .min_by_key(food_distance(&self.player)) + } + pub fn get_negative_food(&self) -> Option<&Food> { + self.foods + .iter() + .filter(|x| x.score.is_sign_negative()) + .min_by_key(food_distance(&self.player)) + } + pub fn into_tensor(&self) -> Tensor { + Tensor::new(&[self.into_feature()], &Device::Cpu).unwrap() + } + fn into_feature(&self) -> [f32; OBSERVATION_SPACE] { + let x = self.player.x; + let y = self.player.y; + // sort food into four group by two line (x+y=0, x-y=0) + let mut food_group = [ + 0.0, + 0.0, + 0.0, + 0.0, + self.opponent.x - self.player.x / 700.0, + self.opponent.y - self.player.y / 700.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + 0.0, + ]; + for food in self.foods.iter().filter(|x| x.score.is_sign_positive()) { + let dx = food.x - x; + let dy = food.y - y; + let group = match (dx + dy, dx - dy) { + (a, b) if a.is_sign_positive() && b.is_sign_positive() => 0, + (a, b) if a.is_sign_positive() && b.is_sign_positive() => 1, + (a, b) if a.is_sign_negative() && b.is_sign_negative() => 2, + _ => 3, + }; + food_group[group] += 10.0 / (dx + dy); + } + for food in self.foods.iter().filter(|x| x.score.is_sign_negative()) { + let dx = food.x - x; + let dy = food.y - y; + let group = match (dx + dy, dx - dy) { + (a, b) if a.is_sign_positive() && b.is_sign_positive() => 6, + (a, b) if a.is_sign_positive() && b.is_sign_positive() => 7, + (a, b) if a.is_sign_negative() && b.is_sign_negative() => 8, + _ => 9, + }; + food_group[group] += 10.0 / (dx + dy); + } + self.get_postivie_food().map(|food| { + let dx = food.x - x; + let dy = food.y - y; + food_group[10] = dx as f32; + food_group[11] = dy as f32; + }); + self.get_negative_food().map(|food| { + let dx = food.x - x; + let dy = food.y - y; + food_group[12] = dx as f32; + food_group[13] = dy as f32; + }); + + + food_group + } +} diff --git a/pyr/src/data/config.rs b/pyr/src/data/config.rs new file mode 100644 index 0000000..5f49982 --- /dev/null +++ b/pyr/src/data/config.rs @@ -0,0 +1,35 @@ +use serde::{Deserialize, Serialize}; + +lazy_static::lazy_static! { + pub static ref CONFIG: Config = { + match std::fs::read_to_string("config.toml"){ + Ok(content)=>toml::from_str(&content).unwrap(), + Err(_)=>Config::default() + } + }; +} + +#[derive(Serialize, Deserialize)] +pub struct Config { + pub exploration_rate: u32, + pub update_frequency: usize, + pub batch_size: usize, + pub replay_size: usize, + pub learning_rate: f64, + pub gamma: f64, + pub train: bool, +} + +impl Default for Config { + fn default() -> Self { + Self { + exploration_rate: 1024, + update_frequency: 150, + batch_size: 32, + replay_size: 250, + learning_rate: 0.04, + gamma: 0.99, + train: true, + } + } +} diff --git a/pyr/src/data/internal.rs b/pyr/src/data/internal.rs new file mode 100644 index 0000000..851b1d0 --- /dev/null +++ b/pyr/src/data/internal.rs @@ -0,0 +1,37 @@ +#[derive(Clone)] +pub struct Player { + pub x: f32, + pub y: f32, + pub height: f32, + pub width: f32, + pub level: f32, + pub velocity: f32, + pub score: f32, +} +#[derive(Clone)] +pub struct Opponent { + pub x: f32, + pub y: f32, + pub level: f32, +} + +#[derive(Clone, Debug)] +pub struct Food { + pub x: f32, + pub y: f32, + pub width: f32, + pub height: f32, + pub score: f32, +} + +impl Default for Food { + fn default() -> Self { + Food { + x: 1000000.0, + y: 1000000.0, + width: 1.0, + height: 1.0, + score: 0.0, + } + } +} diff --git a/pyr/src/data/mod.rs b/pyr/src/data/mod.rs new file mode 100644 index 0000000..7db4c58 --- /dev/null +++ b/pyr/src/data/mod.rs @@ -0,0 +1,13 @@ +mod config; +mod internal; +mod raw; + +pub mod parser { + pub use super::config::CONFIG; + pub use super::raw::*; +} + +pub mod prelude { + pub use super::internal::*; + pub use super::raw::Direction; +} diff --git a/pyr/src/data/raw.rs b/pyr/src/data/raw.rs new file mode 100644 index 0000000..d34599e --- /dev/null +++ b/pyr/src/data/raw.rs @@ -0,0 +1,82 @@ +use super::internal::*; + +#[repr(C)] +#[derive(Debug)] +pub struct RawOverall { + pub frame: u64, + score: i64, + score_to_pass: i64, + self_x: i64, + self_y: i64, + self_h: i64, + self_w: i64, + self_vel: i64, + self_lv: i64, + opponent_x: i64, + opponent_y: i64, + opponent_lv: i64, +} + +impl RawOverall { + pub fn get_player(&self) -> Player { + Player { + x: (self.self_x - 350) as f32, + y: (self.self_y - 350) as f32, + height: self.self_h as f32, + width: self.self_w as f32, + level: self.self_lv as f32, + velocity: self.self_vel as f32, + score: self.score as f32, + } + } + pub fn get_opponent(&self) -> Opponent { + Opponent { + x: (self.opponent_x - 350) as f32, + y: (self.opponent_y - 350) as f32, + level: self.opponent_lv as f32, + } + } +} + +#[repr(C)] +#[derive(Debug, Clone)] +pub struct RawFood { + pub h: i64, + pub w: i64, + pub x: i64, + pub y: i64, + pub score: i64, + pub kind: i32, +} + +impl From for Food { + fn from(value: RawFood) -> Self { + Food { + x: value.x as f32, + y: value.y as f32, + width: value.w as f32, + height: value.h as f32, + score: value.score as f32, + } + } +} + +#[repr(i32)] +#[derive(Debug)] +pub enum FoodKind { + Food1 = 1, + Food2 = 2, + Food3 = 3, + Garbage1 = 4, + Garbage2 = 5, + Garbage3 = 6, +} + +#[repr(i32)] +pub enum Direction { + Up = 1, + Down = 2, + Left = 3, + Right = 4, + None = 5, +} diff --git a/pyr/src/lib.rs b/pyr/src/lib.rs new file mode 100644 index 0000000..887a429 --- /dev/null +++ b/pyr/src/lib.rs @@ -0,0 +1,53 @@ +mod app; +mod data; + +use std::slice; + +use app::{App, TickState}; +use data::parser::*; +use data::prelude::*; +use simple_logger::SimpleLogger; + +#[no_mangle] +pub unsafe extern "C" fn tick( + app: *mut App, + overall: &RawOverall, + food: *mut RawFood, + len: u64, +) -> i32 { + let app = &mut *app; + + let state = { + let foods: Vec = slice::from_raw_parts(food, len as usize) + .into_iter() + .map(|x| x.to_owned().into()) + .collect(); + TickState { + frame: overall.frame, + player: overall.get_player(), + opponent: overall.get_opponent(), + foods, + } + }; + + app.run(state) as i32 +} + +#[no_mangle] +pub unsafe extern "C" fn check_point(app: *mut App) { + let app = &mut *app; + app.check_point(); +} + +#[no_mangle] +pub unsafe extern "C" fn new_app() -> *const App { + SimpleLogger::new().init().unwrap(); + log::info!("Initializing App..."); + let a = Box::into_raw(Box::new(App::new())); + a +} + +#[no_mangle] +pub unsafe extern "C" fn drop_app(app: *mut App) { + // drop(Box::from_raw(app)) +} diff --git a/pyr/src/main.rs b/pyr/src/main.rs new file mode 100644 index 0000000..e6a37c9 --- /dev/null +++ b/pyr/src/main.rs @@ -0,0 +1,25 @@ +// use candle_core::{DType, Device}; +// use candle_nn::{linear, loss::mse, seq, Activation, AdamW, VarBuilder, VarMap}; + +fn main() { + // let mut var_map = VarMap::new(); + // var_map.load("model.bin").unwrap(); + // let vb = VarBuilder::from_varmap(&var_map, DType::F32, &Device::Cpu); + // let model = seq() + // .add(linear(14, 60, vb.pp("linear_in")).unwrap()) + // .add(Activation::LeakyRelu(0.01)) + // .add(linear(60, 48, vb.pp("linear_mid_1")).unwrap()) + // .add(Activation::LeakyRelu(0.01)) + // .add(linear(48, 48, vb.pp("linear_mid_2")).unwrap()) + // .add(Activation::LeakyRelu(0.01)) + // .add(linear(48, 5, vb.pp("linear_out")).unwrap()) + // .add(Activation::LeakyRelu(0.01)); + + // let optimizer = AdamW::new_lr(var_map.all_vars(), 0.5).unwrap(); + + // let target = Tensor::new(&[0.0], &Device::Cpu).unwrap(); + + // self.optimizer + // .backward_step(&Tensor::new(&[loss], &DEVICE).unwrap()) + // .unwrap(); +} diff --git a/pyr/test.py b/pyr/test.py new file mode 100755 index 0000000..ff6ccc2 --- /dev/null +++ b/pyr/test.py @@ -0,0 +1,20 @@ +#!/usr/bin/env python3 +from ctypes import * + +mylib = cdll.LoadLibrary("./target/release/libpyr.so") + +class Point(Structure): + _fields_ = [("x", c_uint64), ("y", c_uint64)] + +point=Point() +point.x=1 +point.y=2 + +ptr=pointer(point) + +print(ptr) + +mylib.set_point(ptr) + +print(point.x) +print(point.y)