diff --git a/.github/workflows/demo.yml b/.github/workflows/demo.yml new file mode 100644 index 0000000..dcf5259 --- /dev/null +++ b/.github/workflows/demo.yml @@ -0,0 +1,18 @@ +# See https://fly.io/docs/app-guides/continuous-deployment-with-github-actions/ + +name: Demo +on: + push: + branches: + - main +jobs: + deploy: + name: Deploy app + runs-on: ubuntu-latest + concurrency: deploy-group # optional: ensure only one action runs at a time + steps: + - uses: actions/checkout@v4 + - uses: superfly/flyctl-actions/setup-flyctl@master + - run: flyctl deploy --remote-only + env: + FLY_API_TOKEN: ${{ secrets.FLY_API_TOKEN }} diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml index 3e2e21d..e6fc09c 100644 --- a/.github/workflows/release.yml +++ b/.github/workflows/release.yml @@ -93,12 +93,3 @@ jobs: make_latest: true env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} - - deploy: - name: Deploy demo app - needs: docker - runs-on: ubuntu-latest - steps: - - name: Deploy to Coolify - run: | - curl --request GET '${{ secrets.COOLIFY_WEBHOOK }}' --header 'Authorization: Bearer ${{ secrets.COOLIFY_TOKEN }}' diff --git a/Dockerfile b/Dockerfile index 8470909..6cbb3be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -13,9 +13,9 @@ COPY ./server/Cargo.toml ./server/Cargo.lock ./ ARG pkg=chat-rs-api RUN apt-get update -qq && apt-get install -y -qq pkg-config libpq-dev && apt-get clean -RUN --mount=type=cache,id=s/6916a5c5-fde9-46e7-934c-cf425dd70d0e-rust_target,target=/app/target \ - --mount=type=cache,id=s/6916a5c5-fde9-46e7-934c-cf425dd70d0e-cargo_registry,target=/usr/local/cargo/registry \ - --mount=type=cache,id=s/6916a5c5-fde9-46e7-934c-cf425dd70d0e-cargo_git,target=/usr/local/cargo/git \ +RUN --mount=type=cache,id=rust_target,target=/app/target \ + --mount=type=cache,id=cargo_registry,target=/usr/local/cargo/registry \ + --mount=type=cache,id=cargo_git,target=/usr/local/cargo/git \ set -eux; \ cargo build --release; \ objcopy --compress-debug-sections target/release/$pkg ./run-server @@ -29,7 +29,7 @@ ENV PNPM_HOME="/pnpm" ENV PATH="$PNPM_HOME:$PATH" COPY ./web/package.json ./web/pnpm-lock.yaml ./ -RUN --mount=type=cache,id=s/6916a5c5-fde9-46e7-934c-cf425dd70d0e-pnpm,target=/pnpm/store pnpm install --frozen-lockfile +RUN --mount=type=cache,id=pnpm,target=/pnpm/store pnpm install --frozen-lockfile COPY ./web/src src COPY ./web/public public diff --git a/README.md b/README.md index 570d6cd..642da3b 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ A fast, secure, self-hostable chat application built with Rust, TypeScript, and React. Chat with multiple AI providers using your own API keys, with real-time streaming built-in. -Demo link: https://rs-chat-demo.up.railway.app/ (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete all data. Please also don't enter any sensitive information or confidential data) +Demo link: https://rs-chat.fly.dev/ (⚠️ This is a demo - don't expect your account/chats to be there when you come back. It may intermittently delete all data. Please also don't enter any sensitive information or confidential data) ## ✨ Features diff --git a/fly.toml b/fly.toml new file mode 100644 index 0000000..6de2e03 --- /dev/null +++ b/fly.toml @@ -0,0 +1,40 @@ +app = "rs-chat" +primary_region = "ewr" + +[build] + +[env] +RUST_LOG = "info" +RS_CHAT_SERVER_ADDRESS = "https://rs-chat.fly.dev" +DOCKER_CERT_PATH = "/certs" +DOCKER_TLS_VERIFY = "1" + +[http_service] +internal_port = 8080 +force_https = true +auto_stop_machines = "stop" +auto_start_machines = true +min_machines_running = 0 +processes = ["app"] + +[[vm]] +memory = "1gb" +cpu_kind = "shared" +cpus = 2 + +[mounts] +source = "data" +destination = "/data" +initial_size = "2gb" + +[[files]] +guest_path = "/certs/ca.pem" +secret_name = "DOCKER_CA_CERT" + +[[files]] +guest_path = "/certs/cert.pem" +secret_name = "DOCKER_CLIENT_CERT" + +[[files]] +guest_path = "/certs/key.pem" +secret_name = "DOCKER_CLIENT_KEY" diff --git a/server/Cargo.lock b/server/Cargo.lock index 1f3a8a8..c4b9aca 100644 --- a/server/Cargo.lock +++ b/server/Cargo.lock @@ -75,12 +75,6 @@ dependencies = [ "memchr", ] -[[package]] -name = "android-tzdata" -version = "0.1.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e999941b234f3131b00bc13c22d06e8c5ff726d1b6318ac7eb276997bbb4fef0" - [[package]] name = "android_system_properties" version = "0.1.5" @@ -98,9 +92,9 @@ checksum = "69f7f8c3906b62b754cd5326047894316021dcfe5a194c8ea52bdd94934a3457" [[package]] name = "astral-tokio-tar" -version = "0.5.2" +version = "0.5.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1abb2bfba199d9ec4759b797115ba6ae435bdd920ce99783bb53aeff57ba919b" +checksum = "0036af73142caf1291d4ec8ed667d3a1145bd55c8189517bd5aa07b3167ae1e1" dependencies = [ "filetime", "futures-core", @@ -112,13 +106,33 @@ dependencies = [ "xattr", ] +[[package]] +name = "async-io" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "19634d6336019ef220f09fd31168ce5c184b295cbf80345437cc36094ef223ca" +dependencies = [ + "async-lock", + "cfg-if", + "concurrent-queue", + "futures-io", + "futures-lite", + "parking", + "polling", + "rustix", + "slab", + "windows-sys 0.60.2", +] + [[package]] name = "async-lock" -version = "2.8.0" +version = "3.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "287272293e9d8c41773cec55e365490fe034813a2f172f502d6ddcf75b2f582b" +checksum = "5fd03604047cee9b6ce9de9f70c6cd540a0520c813cbd49bae61f33ab80ed1dc" dependencies = [ "event-listener", + "event-listener-strategy", + "pin-project-lite", ] [[package]] @@ -143,17 +157,6 @@ dependencies = [ "syn 2.0.102", ] -[[package]] -name = "async-timer" -version = "0.7.4" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ba5fa6ed76cb2aa820707b4eb9ec46f42da9ce70b0eafab5e5e34942b38a44d5" -dependencies = [ - "libc", - "wasm-bindgen", - "winapi", -] - [[package]] name = "async-trait" version = "0.1.88" @@ -300,6 +303,31 @@ dependencies = [ "serde_with", ] +[[package]] +name = "bon" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c2529c31017402be841eb45892278a6c21a000c0a17643af326c73a73f83f0fb" +dependencies = [ + "bon-macros", + "rustversion", +] + +[[package]] +name = "bon-macros" +version = "3.7.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d82020dadcb845a345591863adb65d74fa8dc5c18a0b6d408470e13b7adc7005" +dependencies = [ + "darling 0.21.3", + "ident_case", + "prettyplease", + "proc-macro2", + "quote", + "rustversion", + "syn 2.0.102", +] + [[package]] name = "borrow-or-share" version = "0.2.2" @@ -389,7 +417,7 @@ dependencies = [ "fred", "hex", "jsonschema", - "rand 0.9.1", + "rand 0.9.2", "reqwest", "rocket", "rocket_flex_session", @@ -411,17 +439,16 @@ dependencies = [ [[package]] name = "chrono" -version = "0.4.41" +version = "0.4.42" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c469d952047f47f91b68d1cba3f10d63c11d73e4636f24f08daf0278abf01c4d" +checksum = "145052bdd345b87320e369255277e3fb5152762ad123a901ef5c262dd38fe8d2" dependencies = [ - "android-tzdata", "iana-time-zone", "js-sys", "num-traits", "serde", "wasm-bindgen", - "windows-link", + "windows-link 0.2.0", ] [[package]] @@ -434,6 +461,15 @@ dependencies = [ "inout", ] +[[package]] +name = "concurrent-queue" +version = "2.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ca0197aee26d1ae37445ee532fefce43251d24cc7c166799f4d46817f1d3973" +dependencies = [ + "crossbeam-utils", +] + [[package]] name = "const_format" version = "0.2.34" @@ -518,6 +554,12 @@ version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "338089f42c427b86394a5ee60ff321da23a5c89c9d89514c829687b26359fcff" +[[package]] +name = "crossbeam-utils" +version = "0.8.21" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d0a5c400df2834b80a4c3327b3aad3a4c4cd4de0629063962b03235697506a28" + [[package]] name = "crypto-common" version = "0.1.6" @@ -558,6 +600,16 @@ dependencies = [ "darling_macro 0.20.11", ] +[[package]] +name = "darling" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9cdf337090841a411e2a7f3deb9187445851f91b309c0c0a29e05f74a00a48c0" +dependencies = [ + "darling_core 0.21.3", + "darling_macro 0.21.3", +] + [[package]] name = "darling_core" version = "0.13.4" @@ -586,6 +638,20 @@ dependencies = [ "syn 2.0.102", ] +[[package]] +name = "darling_core" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1247195ecd7e3c85f83c8d2a366e4210d588e802133e1e355180a9870b517ea4" +dependencies = [ + "fnv", + "ident_case", + "proc-macro2", + "quote", + "strsim 0.11.1", + "syn 2.0.102", +] + [[package]] name = "darling_macro" version = "0.13.4" @@ -608,13 +674,25 @@ dependencies = [ "syn 2.0.102", ] +[[package]] +name = "darling_macro" +version = "0.21.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d38308df82d1080de0afee5d069fa14b0326a88c14f15c5ccda35b4a6c414c81" +dependencies = [ + "darling_core 0.21.3", + "quote", + "syn 2.0.102", +] + [[package]] name = "deadpool" -version = "0.12.2" +version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "5ed5957ff93768adf7a65ab167a17835c3d2c3c50d084fe305174c112f468e2f" +checksum = "0be2b1d1d6ec8d846f05e137292d0b89133caf95ef33695424c09568bdd39b1b" dependencies = [ "deadpool-runtime", + "lazy_static", "num_cpus", "tokio", ] @@ -673,9 +751,9 @@ dependencies = [ [[package]] name = "diesel" -version = "2.2.10" +version = "2.2.12" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "ff3e1edb1f37b4953dd5176916347289ed43d7119cc2e6c7c3f7849ff44ea506" +checksum = "229850a212cd9b84d4f0290ad9d294afc0ae70fccaa8949dbe8b43ffafa1e20c" dependencies = [ "bitflags", "byteorder", @@ -815,9 +893,9 @@ dependencies = [ [[package]] name = "dyn-clone" -version = "1.0.19" +version = "1.0.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1c7a8fb8a9fbf66c1f703fe16184d10ca0ee9d23be5b4436400408ba54a95005" +checksum = "d0881ea181b1df73ff77ffaaf9c7544ecc11e82fba9b5f27b262a3c73a332555" [[package]] name = "either" @@ -845,18 +923,18 @@ dependencies = [ [[package]] name = "enum-iterator" -version = "2.1.0" +version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "c280b9e6b3ae19e152d8e31cf47f18389781e119d4013a2a2bb0180e5facc635" +checksum = "a4549325971814bda7a44061bf3fe7e487d447cba01e4220a4b454d630d7a016" dependencies = [ "enum-iterator-derive", ] [[package]] name = "enum-iterator-derive" -version = "1.4.0" +version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "a1ab991c1362ac86c61ab6f556cff143daa22e5a15e4e189df818b2fd19fe65b" +checksum = "685adfa4d6f3d765a26bc5dbc936577de9abf756c1feeb3089b01dd395034842" dependencies = [ "proc-macro2", "quote", @@ -881,9 +959,24 @@ dependencies = [ [[package]] name = "event-listener" -version = "2.5.3" +version = "5.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "e13b66accf52311f30a0db42147dadea9850cb48cd070028831ae5f5d4b856ab" +dependencies = [ + "concurrent-queue", + "parking", + "pin-project-lite", +] + +[[package]] +name = "event-listener-strategy" +version = "0.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "0206175f82b8d6bf6652ff7d71a1e27fd2e4efde587fd368662814d6ec1d9ce0" +checksum = "8be9f3dfaaffdae2972880079a491a1a8bb7cbed0b8dd7a347f668b4150a3b93" +dependencies = [ + "event-listener", + "pin-project-lite", +] [[package]] name = "fallible-iterator" @@ -898,8 +991,8 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6e24cb5a94bcae1e5408b0effca5cd7172ea3c5755049c5f3af4cd283a165298" dependencies = [ "bit-set", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", + "regex-automata", + "regex-syntax", ] [[package]] @@ -1064,6 +1157,19 @@ version = "0.3.31" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9e5c1b78ca4aae1ac06c48a526a655760685149f0d465d21f37abfe57ce075c6" +[[package]] +name = "futures-lite" +version = "2.6.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f78e10609fe0e0b3f4157ffab1876319b5b0db102a2c60dc4626306dc46b44ad" +dependencies = [ + "fastrand", + "futures-core", + "futures-io", + "parking", + "pin-project-lite", +] + [[package]] name = "futures-macro" version = "0.3.31" @@ -1695,7 +1801,7 @@ dependencies = [ "percent-encoding", "referencing", "regex", - "regex-syntax 0.8.5", + "regex-syntax", "serde", "serde_json", "uuid-simd", @@ -1775,11 +1881,11 @@ checksum = "112b39cec0b298b6c1999fee3e31427f74f676e4cb9879ed1a121b43661a4154" [[package]] name = "matchers" -version = "0.1.0" +version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "8263075bb86c5a1b1427b5ae862e8889656f126e9f77c484496e8b47cf5c5558" +checksum = "d1525a2a28c7f4fa0fc98bb91ae755d1e2d1505079e05539e35bc876b5d65ae9" dependencies = [ - "regex-automata 0.1.10", + "regex-automata", ] [[package]] @@ -1861,12 +1967,11 @@ dependencies = [ [[package]] name = "nu-ansi-term" -version = "0.46.0" +version = "0.50.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77a8165726e8236064dbb45459242600304b42a5ea24ee2948e18e023bf7ba84" +checksum = "d4a28e057d01f97e61255210fcff094d74ed0466038633e95017f5beb68e4399" dependencies = [ - "overload", - "winapi", + "windows-sys 0.52.0", ] [[package]] @@ -2010,10 +2115,10 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e" [[package]] -name = "overload" -version = "0.1.1" +name = "parking" +version = "2.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "b15813163c1d831bf4a13c3610c05c0d03b39feb07f7e09fa234dac9b15aaf39" +checksum = "f38d5652c16fde515bb1ecef450ab0f6a219d619a7274976324d5e377f7dceba" [[package]] name = "parking_lot" @@ -2097,6 +2202,20 @@ version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" +[[package]] +name = "polling" +version = "3.10.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5bd19146350fe804f7cb2669c851c03d69da628803dab0d98018142aaa5d829" +dependencies = [ + "cfg-if", + "concurrent-queue", + "hermit-abi", + "pin-project-lite", + "rustix", + "windows-sys 0.60.2", +] + [[package]] name = "polyval" version = "0.6.2" @@ -2128,7 +2247,7 @@ dependencies = [ "hmac", "md-5", "memchr", - "rand 0.9.1", + "rand 0.9.2", "sha2", "stringprep", ] @@ -2178,6 +2297,16 @@ dependencies = [ "vcpkg", ] +[[package]] +name = "prettyplease" +version = "0.2.34" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6837b9e10d61f45f987d50808f83d1ee3d206c66acf650c3e4ae2e1f6ddedf55" +dependencies = [ + "proc-macro2", + "syn 2.0.102", +] + [[package]] name = "proc-macro2" version = "1.0.95" @@ -2229,7 +2358,7 @@ dependencies = [ "bytes", "getrandom 0.3.3", "lru-slab", - "rand 0.9.1", + "rand 0.9.2", "ring", "rustc-hash", "rustls 0.23.31", @@ -2283,9 +2412,9 @@ dependencies = [ [[package]] name = "rand" -version = "0.9.1" +version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "9fbfd9d094a40bf3ae768db9361049ace4c0e04a4fd6b359518bd7b73a73dd97" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" dependencies = [ "rand_chacha 0.9.0", "rand_core 0.9.3", @@ -2394,17 +2523,8 @@ checksum = "b544ef1b4eac5dc2db33ea63606ae9ffcfac26c1416a2806ae0bf5f56b201191" dependencies = [ "aho-corasick", "memchr", - "regex-automata 0.4.9", - "regex-syntax 0.8.5", -] - -[[package]] -name = "regex-automata" -version = "0.1.10" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "6c230d73fb8d8c1b9c0b3135c5142a8acee3a0558fb8db5cf1cb65f8d7862132" -dependencies = [ - "regex-syntax 0.6.29", + "regex-automata", + "regex-syntax", ] [[package]] @@ -2415,15 +2535,9 @@ checksum = "809e8dc61f6de73b46c85f4c96486310fe304c434cfa43669d7b40f711150908" dependencies = [ "aho-corasick", "memchr", - "regex-syntax 0.8.5", + "regex-syntax", ] -[[package]] -name = "regex-syntax" -version = "0.6.29" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f162c6dd7b008981e4d40210aca20b4bd0f9b60ca9271061b07f78537722f2e1" - [[package]] name = "regex-syntax" version = "0.8.5" @@ -2432,9 +2546,9 @@ checksum = "2b15c43186be67a4fd63bee50d0303afffcef381492ebe2c5d87f324e1b8815c" [[package]] name = "reqwest" -version = "0.12.20" +version = "0.12.23" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "eabf4c97d9130e2bf606614eb937e86edac8292eaa6f422f995d7e8de1eb1813" +checksum = "d429f34c8092b2d42c7c93cec323bb4adeb7c67698f70839adec842ec10c7ceb" dependencies = [ "base64 0.22.1", "bytes", @@ -2473,14 +2587,15 @@ dependencies = [ [[package]] name = "retainer" -version = "0.3.0" +version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "df8c01a8276c14d0f8d51ebcf8a48f0748f9f73f5f6b29e688126e6a52bcb145" +checksum = "7b071fe646a2ab077f74656a4602c16528829c1fafa81946c5e88eaeccf08d5b" dependencies = [ + "async-io", "async-lock", - "async-timer", + "futures-lite", "log", - "rand 0.8.5", + "rand 0.9.2", ] [[package]] @@ -2555,17 +2670,17 @@ dependencies = [ [[package]] name = "rocket_flex_session" -version = "0.1.1" -source = "git+https://github.com/fa-sharp/rocket-flex-session#fa04c65e8195135bfcf1c7a1cd28c3fc38b96c23" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a5a9f1ecfad02f26e916d6bdd25d1a9d50011e0e1d42c226bfdf447846426db7" dependencies = [ + "bon", "fred", - "rand 0.8.5", + "rand 0.9.2", "retainer", "rocket", "rocket_okapi", - "serde", "thiserror", - "time", ] [[package]] @@ -2944,9 +3059,9 @@ dependencies = [ [[package]] name = "serde_json" -version = "1.0.140" +version = "1.0.143" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "20068b6e96dc6c9bd23e01df8827e6c7e1f2fddd43c21810382803c136b99373" +checksum = "d401abef1d108fbd9cbaebc3e46611f4b1021f714a0597a71f41ee463f5f4a5a" dependencies = [ "itoa", "memchr", @@ -3199,18 +3314,18 @@ dependencies = [ [[package]] name = "thiserror" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "567b8a2dae586314f7be2a752ec7474332959c6460e02bde30d702a66d488708" +checksum = "3467d614147380f2e4e374161426ff399c91084acd2363eaf549172b3d5e60c0" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" -version = "2.0.12" +version = "2.0.16" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7f7cf42b4507d8ea322120659672cf1b9dbb93f8f2d4ecfd6e51350ff5b17a1d" +checksum = "6c5e1be1c48b9172ee610da68fd9cd2770e7a4056cb3fc98710ee6906f0c7960" dependencies = [ "proc-macro2", "quote", @@ -3330,7 +3445,7 @@ dependencies = [ "pin-project-lite", "postgres-protocol", "postgres-types", - "rand 0.9.1", + "rand 0.9.2", "socket2", "tokio", "tokio-util", @@ -3522,14 +3637,14 @@ dependencies = [ [[package]] name = "tracing-subscriber" -version = "0.3.19" +version = "0.3.20" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "e8189decb5ac0fa7bc8b96b7cb9b2701d60d48805aca84a238004d665fcc4008" +checksum = "2054a14f5307d601f88daf0553e1cbf472acc4f2c51afab632431cdcd72124d5" dependencies = [ "matchers", "nu-ansi-term", "once_cell", - "regex", + "regex-automata", "serde", "serde_json", "sharded-slab", @@ -3652,9 +3767,9 @@ checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "uuid" -version = "1.17.0" +version = "1.18.1" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "3cf4199d1e5d15ddd86a694e4d0dffa9c323ce759fea589f00fef9d81cc1931d" +checksum = "2f87b8aa10b915a06587d0dec516c282ff295b475d94abf425d62b57710070a2" dependencies = [ "getrandom 0.3.3", "js-sys", @@ -3881,7 +3996,7 @@ checksum = "c0fdd3ddb90610c7638aa2b3a3ab2904fb9e5cdbecc643ddb3647212781c4ae3" dependencies = [ "windows-implement", "windows-interface", - "windows-link", + "windows-link 0.1.1", "windows-result", "windows-strings", ] @@ -3914,13 +4029,19 @@ version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76840935b766e1b0a05c0066835fb9ec80071d4c09a16f6bd5f7e655e3c14c38" +[[package]] +name = "windows-link" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "45e46c0661abb7180e7b9c281db115305d49ca1709ab8242adf09666d2173c65" + [[package]] name = "windows-result" version = "0.3.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56f42bd332cc6c8eac5af113fc0c1fd6a8fd2aa08a0119358686e5160d0586c6" dependencies = [ - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -3929,7 +4050,7 @@ version = "0.4.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "56e6c93f3a0c3b36176cb1327a4958a0353d5d166c2a35cb268ace15e91d3b57" dependencies = [ - "windows-link", + "windows-link 0.1.1", ] [[package]] @@ -3950,6 +4071,15 @@ dependencies = [ "windows-targets 0.52.6", ] +[[package]] +name = "windows-sys" +version = "0.60.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f2f500e4d28234f72040990ec9d39e3a6b950f9f22d3dba18416c35882612bcb" +dependencies = [ + "windows-targets 0.53.2", +] + [[package]] name = "windows-targets" version = "0.48.5" @@ -3974,13 +4104,29 @@ dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", - "windows_i686_gnullvm", + "windows_i686_gnullvm 0.52.6", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] +[[package]] +name = "windows-targets" +version = "0.53.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c66f69fcc9ce11da9966ddb31a40968cad001c5bedeb5c2b82ede4253ab48aef" +dependencies = [ + "windows_aarch64_gnullvm 0.53.0", + "windows_aarch64_msvc 0.53.0", + "windows_i686_gnu 0.53.0", + "windows_i686_gnullvm 0.53.0", + "windows_i686_msvc 0.53.0", + "windows_x86_64_gnu 0.53.0", + "windows_x86_64_gnullvm 0.53.0", + "windows_x86_64_msvc 0.53.0", +] + [[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" @@ -3993,6 +4139,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "86b8d5f90ddd19cb4a147a5fa63ca848db3df085e25fee3cc10b39b6eebae764" + [[package]] name = "windows_aarch64_msvc" version = "0.48.5" @@ -4005,6 +4157,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" +[[package]] +name = "windows_aarch64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c7651a1f62a11b8cbd5e0d42526e55f2c99886c77e007179efff86c2b137e66c" + [[package]] name = "windows_i686_gnu" version = "0.48.5" @@ -4017,12 +4175,24 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" +[[package]] +name = "windows_i686_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c1dc67659d35f387f5f6c479dc4e28f1d4bb90ddd1a5d3da2e5d97b42d6272c3" + [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" +[[package]] +name = "windows_i686_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9ce6ccbdedbf6d6354471319e781c0dfef054c81fbc7cf83f338a4296c0cae11" + [[package]] name = "windows_i686_msvc" version = "0.48.5" @@ -4035,6 +4205,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" +[[package]] +name = "windows_i686_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "581fee95406bb13382d2f65cd4a908ca7b1e4c2f1917f143ba16efe98a589b5d" + [[package]] name = "windows_x86_64_gnu" version = "0.48.5" @@ -4047,6 +4223,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" +[[package]] +name = "windows_x86_64_gnu" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2e55b5ac9ea33f2fc1716d1742db15574fd6fc8dadc51caab1c16a3d3b4190ba" + [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" @@ -4059,6 +4241,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0a6e035dd0599267ce1ee132e51c27dd29437f63325753051e71dd9e42406c57" + [[package]] name = "windows_x86_64_msvc" version = "0.48.5" @@ -4071,6 +4259,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "windows_x86_64_msvc" +version = "0.53.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "271414315aff87387382ec3d271b52d7ae78726f5d44ac98b4f4030c91880486" + [[package]] name = "winnow" version = "0.7.10" diff --git a/server/Cargo.toml b/server/Cargo.toml index 35a060c..a0de9f2 100644 --- a/server/Cargo.toml +++ b/server/Cargo.toml @@ -4,66 +4,66 @@ version = "0.6.0" edition = "2021" publish = false -[profile.ci] -inherits = "dev" -codegen-units = 256 - -[profile.release] -opt-level = 2 -lto = "thin" -codegen-units = 1 -panic = "abort" -strip = true - [dependencies] aes-gcm = "0.10.3" -astral-tokio-tar = "0.5.2" +astral-tokio-tar = "0.5.3" base64 = "0.22.1" bollard = { version = "0.19.1", features = ["ssl"] } -chrono = { version = "0.4.41", features = ["serde"] } +chrono = { version = "0.4.42", features = ["serde"] } const_format = "0.2.34" -deadpool = { version = "0.12.2", features = ["rt_tokio_1"] } -diesel = { version = "2.2.10", features = [ - "postgres", - "chrono", - "uuid", - "serde_json", +deadpool = { version = "0.12.3", features = ["rt_tokio_1"] } +diesel = { version = "2.2.12", features = [ + "chrono", + "postgres", + "serde_json", + "uuid", ] } -diesel-async = { version = "0.5.2", features = ["deadpool", "postgres"] } -diesel-derive-enum = { version = "3.0.0-beta.1", features = ["postgres"] } diesel_as_jsonb = "1.0.1" diesel_async_migrations = "0.15.0" +diesel-async = { version = "0.5.2", features = ["deadpool", "postgres"] } +diesel-derive-enum = { version = "3.0.0-beta.1", features = ["postgres"] } dotenvy = "0.15.7" -dyn-clone = "1.0.19" -enum-iterator = "2.1.0" +dyn-clone = "1.0.20" +enum-iterator = "2.3.0" fred = { version = "10.1.0", default-features = false, features = [ - "i-keys", - "i-streams", + "i-keys", + "i-streams", ] } hex = "0.4.3" jsonschema = { version = "0.30.0", default-features = false } -rand = "0.9.1" -reqwest = { version = "0.12.20", default-features = false, features = [ - "json", - "stream", - "rustls-tls-native-roots", +rand = "0.9.2" +reqwest = { version = "0.12.23", default-features = false, features = [ + "json", + "rustls-tls-native-roots", + "stream", ] } rocket = { version = "0.5.1", features = ["json", "uuid"] } -rocket_flex_session = { version = "0.1.1", git = "https://github.com/fa-sharp/rocket-flex-session", features = [ - "redis_fred", - "rocket_okapi", +rocket_flex_session = { version = "0.2.0", features = [ + "redis_fred", + "rocket_okapi", ] } rocket_oauth2 = "0.5.0" rocket_okapi = { version = "0.9.0", features = ["rapidoc"] } schemars = { version = "0.8.22", features = ["chrono", "uuid1"] } serde = { version = "1.0.219" } -serde_json = "1.0.140" +serde_json = "1.0.143" subst = { version = "0.3.8", features = ["json"] } -thiserror = "2.0.12" +thiserror = "2.0.16" tokio = { version = "1.45.1" } tokio-stream = "0.1.17" tokio-util = { version = "0.7.16", features = ["io"] } tracing = "0.1.41" -tracing-subscriber = { version = "0.3.19", features = ["json", "env-filter"] } +tracing-subscriber = { version = "0.3.20", features = ["env-filter", "json"] } urlencoding = "2.1.3" -uuid = { version = "1.17.0", features = ["v4", "serde"] } +uuid = { version = "1.18.1", features = ["serde", "v4"] } + +[profile.ci] +inherits = "dev" +codegen-units = 256 + +[profile.release] +opt-level = 2 +strip = true +lto = "thin" +panic = "abort" +codegen-units = 1 diff --git a/server/migrations/2025-09-03-063406_remove_old_tools/down.sql b/server/migrations/2025-09-03-063406_remove_old_tools/down.sql new file mode 100644 index 0000000..dab05c2 --- /dev/null +++ b/server/migrations/2025-09-03-063406_remove_old_tools/down.sql @@ -0,0 +1,15 @@ +-- Add back tools table +CREATE TABLE tools ( + id UUID PRIMARY KEY DEFAULT gen_random_uuid(), + user_id UUID NOT NULL REFERENCES users (id), + name TEXT NOT NULL, + description TEXT NOT NULL, + config JSONB NOT NULL, + created_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_at TIMESTAMPTZ NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +SELECT + diesel_manage_updated_at ('tools'); + +CREATE INDEX tools_user_id_idx ON tools (user_id); diff --git a/server/migrations/2025-09-03-063406_remove_old_tools/up.sql b/server/migrations/2025-09-03-063406_remove_old_tools/up.sql new file mode 100644 index 0000000..f019705 --- /dev/null +++ b/server/migrations/2025-09-03-063406_remove_old_tools/up.sql @@ -0,0 +1,2 @@ +-- Drop old tools table +DROP TABLE tools; diff --git a/server/src/api/auth.rs b/server/src/api/auth.rs index 9cf1141..bafc4c6 100644 --- a/server/src/api/auth.rs +++ b/server/src/api/auth.rs @@ -19,7 +19,7 @@ use crate::{ /// Auth routes pub fn get_routes(settings: &OpenApiSettings) -> (Vec, OpenApi) { - openapi_get_routes_spec![settings: user, auth_config, logout, delete_account] + openapi_get_routes_spec![settings: user, get_sessions, auth_config, logout, delete_account] } /// # Get User @@ -30,6 +30,31 @@ async fn user(user: ChatRsUser) -> Result, ApiError> { Ok(Json(user)) } +/// # Get Sessions +/// Get current sessions +#[openapi(tag = "Auth")] +#[get("/sessions")] +async fn get_sessions( + user_id: ChatRsUserId, + session: Session<'_, ChatRsAuthSession>, +) -> Result>, ApiError> { + let sessions: Vec = session + .get_sessions_by_identifier(&user_id.simple().to_string()) + .await + .map_err(|err| ApiError::Server(err.to_string()))? + .into_iter() + .map(|(id, data, ttl)| SessionResponse { + id, + started: data.start_time, + ip: data.ip, + user_agent: data.user_agent, + ttl, + }) + .collect(); + + Ok(Json(sessions)) +} + /// The current auth configuration of the server #[derive(Debug, JsonSchema, OpenApiFromRequest, serde::Serialize)] struct AuthConfig { @@ -93,6 +118,16 @@ async fn auth_config(config: AuthConfig) -> Json { Json(config) } +#[derive(Debug, serde::Serialize, JsonSchema)] +struct SessionResponse { + id: String, + #[schemars(with = "chrono::DateTime")] + started: Option, + ttl: u32, + ip: Option, + user_agent: Option, +} + /// # Log out #[openapi(tag = "Auth")] #[post("/logout")] diff --git a/server/src/auth.rs b/server/src/auth.rs index 30784ff..5ebd47e 100644 --- a/server/src/auth.rs +++ b/server/src/auth.rs @@ -2,6 +2,7 @@ mod api_key; mod guard; mod oauth; mod session; +mod session_meta; mod sso_header; use rocket::fairing::AdHoc; diff --git a/server/src/auth/guard.rs b/server/src/auth/guard.rs index ebd3487..d620d4e 100644 --- a/server/src/auth/guard.rs +++ b/server/src/auth/guard.rs @@ -49,9 +49,7 @@ impl<'r> FromRequest<'r> for ChatRsUserId { // Try authentication via session let session = try_outcome!(req.guard::>().await); - if let Some(user_id) = - session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) - { + if let Some(user_id) = session.tap(|data| data.map(|auth_session| auth_session.user_id)) { return Outcome::Success(ChatRsUserId(user_id)); } diff --git a/server/src/auth/oauth.rs b/server/src/auth/oauth.rs index 25268ee..217cfae 100644 --- a/server/src/auth/oauth.rs +++ b/server/src/auth/oauth.rs @@ -3,14 +3,14 @@ mod github; mod google; mod oidc; -use rocket::{fairing::AdHoc, figment::Figment, http::CookieJar, response::Redirect, Route}; +use rocket::{fairing::AdHoc, figment::Figment, http::CookieJar, response::Redirect, Route, State}; use rocket_flex_session::Session; use rocket_oauth2::{HyperRustlsAdapter, OAuth2, OAuthConfig, StaticProvider, TokenResponse}; use serde::Deserialize; use std::future::Future; use crate::{ - auth::ChatRsAuthSession, + auth::{session_meta::SessionMeta, ChatRsAuthSession}, config::{get_app_config, get_config_provider}, db::{ models::{ChatRsUser, NewChatRsUser, UpdateChatRsUser}, @@ -135,11 +135,10 @@ async fn generic_login_callback( mut db: DbConnection, token: TokenResponse, config: &P::Config, + client: &State, mut session: Session<'_, ChatRsAuthSession>, + meta: SessionMeta<'_>, ) -> Result { - let client = reqwest::Client::builder() - .build() - .map_err(|e| ApiError::Authentication(format!("Failed to build reqwest client: {}", e)))?; let mut request = client .get(P::new(config).get_user_info_url()) .header("Authorization", format!("Bearer {}", token.access_token())); @@ -162,13 +161,14 @@ async fn generic_login_callback( match P::find_linked_user(&mut db_service, &user_data).await? { // Existing linked user found: create new session Some(existing_user) => { - session.set(ChatRsAuthSession::new(existing_user.id)); + session.set(ChatRsAuthSession::new(existing_user.id, meta)); } - None => match session.tap(|data| data.and_then(|auth_session| auth_session.user_id())) { + // No linked user found, check for active session + None => match session.tap(|data| data.map(|auth_session| auth_session.user_id)) { // No linked user and no session found: create new user and session None => { let new_user = db_service.create(P::create_new_user(&user_data)).await?; - session.set(ChatRsAuthSession::new(new_user.id)); + session.set(ChatRsAuthSession::new(new_user.id, meta)); } // No linked user but there is a current session Some(user_id) => { diff --git a/server/src/auth/oauth/discord.rs b/server/src/auth/oauth/discord.rs index 49852ec..328907b 100644 --- a/server/src/auth/oauth/discord.rs +++ b/server/src/auth/oauth/discord.rs @@ -4,6 +4,7 @@ use rocket_oauth2::{OAuth2, StaticProvider, TokenResponse}; use serde::Deserialize; use crate::{ + auth::session_meta::SessionMeta, db::{ models::{ChatRsUser, NewChatRsUser, UpdateChatRsUser}, services::UserDbService, @@ -130,7 +131,9 @@ async fn discord_login_callback( db: DbConnection, token: TokenResponse, config: &State, + client: &State, session: Session<'_, ChatRsAuthSession>, + meta: SessionMeta<'_>, ) -> Result { - generic_login_callback::(db, token, config, session).await + generic_login_callback::(db, token, config, client, session, meta).await } diff --git a/server/src/auth/oauth/github.rs b/server/src/auth/oauth/github.rs index ffb3657..ef81e2b 100644 --- a/server/src/auth/oauth/github.rs +++ b/server/src/auth/oauth/github.rs @@ -4,6 +4,7 @@ use rocket_oauth2::{OAuth2, StaticProvider, TokenResponse}; use serde::Deserialize; use crate::{ + auth::session_meta::SessionMeta, db::{ models::{ChatRsUser, NewChatRsUser, UpdateChatRsUser}, services::UserDbService, @@ -126,7 +127,9 @@ async fn github_login_callback( db: DbConnection, token: TokenResponse, config: &State, + client: &State, session: Session<'_, ChatRsAuthSession>, + meta: SessionMeta<'_>, ) -> Result { - generic_login_callback::(db, token, config, session).await + generic_login_callback::(db, token, config, client, session, meta).await } diff --git a/server/src/auth/oauth/google.rs b/server/src/auth/oauth/google.rs index 63a98bc..5267311 100644 --- a/server/src/auth/oauth/google.rs +++ b/server/src/auth/oauth/google.rs @@ -4,6 +4,7 @@ use rocket_oauth2::{OAuth2, StaticProvider, TokenResponse}; use serde::Deserialize; use crate::{ + auth::session_meta::SessionMeta, db::{ models::{ChatRsUser, NewChatRsUser, UpdateChatRsUser}, services::UserDbService, @@ -122,7 +123,9 @@ async fn google_login_callback( db: DbConnection, token: TokenResponse, config: &State, + client: &State, session: Session<'_, ChatRsAuthSession>, + meta: SessionMeta<'_>, ) -> Result { - generic_login_callback::(db, token, config.inner(), session).await + generic_login_callback::(db, token, config.inner(), client, session, meta).await } diff --git a/server/src/auth/oauth/oidc.rs b/server/src/auth/oauth/oidc.rs index 52af0b2..4f1cfbd 100644 --- a/server/src/auth/oauth/oidc.rs +++ b/server/src/auth/oauth/oidc.rs @@ -4,6 +4,7 @@ use rocket_oauth2::{OAuth2, StaticProvider, TokenResponse}; use serde::Deserialize; use crate::{ + auth::session_meta::SessionMeta, db::{ models::{ChatRsUser, NewChatRsUser, UpdateChatRsUser}, services::UserDbService, @@ -144,7 +145,9 @@ async fn oidc_login_callback( db: DbConnection, token: TokenResponse, config: &State, + client: &State, session: Session<'_, ChatRsAuthSession>, + meta: SessionMeta<'_>, ) -> Result { - generic_login_callback::(db, token, config, session).await + generic_login_callback::(db, token, config, client, session, meta).await } diff --git a/server/src/auth/session.rs b/server/src/auth/session.rs index 0e9a6c9..f414958 100644 --- a/server/src/auth/session.rs +++ b/server/src/auth/session.rs @@ -1,93 +1,118 @@ -use std::ops::Deref; - use chrono::Utc; use rocket::fairing::AdHoc; -use rocket_flex_session::{storage::redis::RedisFredStorage, RocketFlexSession}; +use rocket_flex_session::{ + error::SessionError, + storage::redis::{RedisFormat, RedisFredStorage, RedisValue, SessionRedis}, + RocketFlexSession, SessionIdentifier, +}; use uuid::Uuid; -use crate::{config::get_app_config, redis::build_redis_pool}; - -const USER_ID_KEY: &str = "user_id"; -const USER_ID_BYTES_KEY: &str = "user_id_bytes"; -const START_TIME_KEY: &str = "start_time"; +use crate::auth::session_meta::SessionMeta; /// Type representing the session data. #[derive(Debug, Clone)] -pub struct ChatRsAuthSession(fred::types::Map); +pub struct ChatRsAuthSession { + pub user_id: Uuid, + pub start_time: Option, + pub ip: Option, + pub user_agent: Option, +} -impl Deref for ChatRsAuthSession { - type Target = fred::types::Map; +/// Rocket fairing that sets up persistent sessions via Redis. +pub fn setup_session() -> AdHoc { + AdHoc::on_ignite("Sessions", |rocket| async { + let pool = rocket.state::().expect("pool exists"); + let storage = RedisFredStorage::builder() + .pool(pool.clone()) + .prefix("sess:") + .index_prefix("sess:user:") + .build(); + let session_fairing = RocketFlexSession::::builder() + .with_options(|opt| { + opt.cookie_name = "auth_rs_chat".to_string(); + opt.ttl = Some(60 * 60 * 24 * 2); // 2 days + opt.rolling = true; + }) + .storage(storage) + .build(); - fn deref(&self) -> &Self::Target { - &self.0 - } + rocket.attach(session_fairing) + }) } impl ChatRsAuthSession { - pub fn new(user_id: Uuid) -> Self { - let mut hash = fred::types::Map::new(); - hash.insert(USER_ID_KEY.into(), user_id.to_string().into()); - hash.insert( - USER_ID_BYTES_KEY.into(), - user_id.as_bytes().as_slice().into(), - ); - hash.insert(START_TIME_KEY.into(), Utc::now().to_rfc3339().into()); - ChatRsAuthSession(hash) + pub fn new(user_id: Uuid, meta: SessionMeta) -> Self { + ChatRsAuthSession { + user_id, + start_time: Some(Utc::now().to_rfc3339()), + ip: meta.ip.map(|ip| ip.to_string()), + user_agent: meta.user_agent.map(|ua| ua.to_owned()), + } } +} +impl SessionIdentifier for ChatRsAuthSession { + type Id = String; - pub fn user_id(&self) -> Option { - self.get(&fred::types::Key::from_static_str(USER_ID_BYTES_KEY)) - .and_then(|val| val.as_bytes()) - .and_then(|bytes| Uuid::from_slice(bytes).ok()) + /// Group sessions by user ID, using lowercase hex keys to track each user's sessions. + fn identifier(&self) -> Option { + Some(hex::encode(self.user_id.as_bytes())) } } -/// Possible errors when parsing session data from Redis hash. -#[derive(thiserror::Error, Debug)] -pub enum SessionParseError { - #[error("Failed to parse")] - ParsingError, +/// Keys used in the session data. +mod keys { + pub const USER_ID_HEX_KEY: &str = "user_id"; + pub const START_TIME_KEY: &str = "start"; + pub const IP_KEY: &str = "ip"; + pub const USER_AGENT_KEY: &str = "ua"; } -/// Convert from Redis hash to session data. -impl TryFrom for ChatRsAuthSession { - type Error = SessionParseError; +impl SessionRedis for ChatRsAuthSession { + const REDIS_FORMAT: RedisFormat = RedisFormat::Map; + type Error = SessionError; - fn try_from(value: fred::prelude::Value) -> Result { - let map = value - .into_map() - .map_err(|_| SessionParseError::ParsingError)?; - Ok(ChatRsAuthSession(map)) - } -} + fn into_redis(self) -> Result { + let user_id_bytes = self.user_id.as_bytes(); + let mut data_pairs = vec![(keys::USER_ID_HEX_KEY.into(), hex::encode(user_id_bytes))]; + for (key, optional_val) in [ + (keys::START_TIME_KEY.into(), self.start_time), + (keys::IP_KEY.into(), self.ip), + (keys::USER_AGENT_KEY.into(), self.user_agent), + ] { + if let Some(val) = optional_val { + data_pairs.push((key, val)); + } + } -/// Convert from session data to Redis hash. -impl From for fred::prelude::Value { - fn from(session: ChatRsAuthSession) -> Self { - fred::types::Value::Map(session.0) + Ok(RedisValue::Map(data_pairs)) } -} -/// Fairing that sets up persistent sessions via Redis. -pub fn setup_session() -> AdHoc { - AdHoc::on_ignite("Sessions", |rocket| async { - let app_config = get_app_config(&rocket); - let config = fred::prelude::Config::from_url(&app_config.redis_url) - .expect("RS_CHAT_REDIS_URL should be valid Redis URL"); - let session_redis_pool = build_redis_pool(config, 2).expect("Failed to build Redis pool"); - let session_fairing: RocketFlexSession = RocketFlexSession::builder() - .with_options(|opt| { - opt.cookie_name = "auth_rs_chat".to_string(); - opt.ttl = Some(60 * 60 * 24 * 2); // 2 days - opt.rolling = true; - }) - .storage(RedisFredStorage::new( - session_redis_pool, - rocket_flex_session::storage::redis::RedisType::Hash, - "sess:", - )) - .build(); + fn from_redis(value: RedisValue) -> Result { + let map = value.into_map().expect("should always be a map"); + let mut user_id = None; + let mut start_time = None; + let mut ip = None; + let mut user_agent = None; + for (key, val) in map { + match key.as_str() { + keys::USER_ID_HEX_KEY => { + let mut bytes = [0_u8; 16]; + hex::decode_to_slice(val, &mut bytes) + .map_err(|e| SessionError::Parsing(e.into()))?; + user_id = Some(Uuid::from_bytes(bytes)) + } + keys::START_TIME_KEY => start_time = Some(val), + keys::IP_KEY => ip = Some(val), + keys::USER_AGENT_KEY => user_agent = Some(val), + _ => (), + } + } - rocket.attach(session_fairing) - }) + Ok(Self { + user_id: user_id.ok_or(SessionError::InvalidData)?, + start_time, + ip, + user_agent, + }) + } } diff --git a/server/src/auth/session_meta.rs b/server/src/auth/session_meta.rs new file mode 100644 index 0000000..4d13496 --- /dev/null +++ b/server/src/auth/session_meta.rs @@ -0,0 +1,21 @@ +use std::net::IpAddr; + +use rocket::request::{FromRequest, Outcome}; + +/// Session metadata extracted from request headers. +pub struct SessionMeta<'r> { + pub ip: Option, + pub user_agent: Option<&'r str>, +} + +#[rocket::async_trait] +impl<'r> FromRequest<'r> for SessionMeta<'r> { + type Error = (); + + async fn from_request(req: &'r rocket::Request<'_>) -> Outcome { + Outcome::Success(SessionMeta { + ip: req.client_ip(), + user_agent: req.headers().get_one("User-Agent"), + }) + } +} diff --git a/server/src/db.rs b/server/src/db.rs index d971847..21048dd 100644 --- a/server/src/db.rs +++ b/server/src/db.rs @@ -22,12 +22,15 @@ use rocket_okapi::OpenApiFromRequest; use crate::config::get_app_config; +/** The PostgreSQL connection pool, stored in Rocket's managed state */ +pub type DbPool = Pool; + /// Database connection, available as a request guard. When used as a request parameter, /// it will retrieve a connection from the managed Postgres pool. #[derive(OpenApiFromRequest)] pub struct DbConnection(pub Object); impl Deref for DbConnection { - type Target = Object; + type Target = AsyncPgConnection; fn deref(&self) -> &Self::Target { &self.0 } @@ -38,12 +41,12 @@ impl DerefMut for DbConnection { } } -/// Retrieve a connection from the managed Postgres pool. Responds with an -/// internal server error if a connection couldn't be retrieved. #[rocket::async_trait] impl<'r> FromRequest<'r> for DbConnection { type Error = &'static str; + /// Retrieve a connection from the managed Postgres pool. Responds with an + /// internal server error if a connection couldn't be retrieved. async fn from_request(req: &'r Request<'_>) -> Outcome { let Some(pool) = req.rocket().state::() else { return Outcome::Error((Status::InternalServerError, "Database not initialized")); @@ -58,52 +61,35 @@ impl<'r> FromRequest<'r> for DbConnection { } } -/** The database pool stored in Rocket's managed state */ -pub type DbPool = Pool; - /// Fairing that sets up and initializes the Postgres database pub fn setup_db() -> AdHoc { AdHoc::on_ignite("Database", |rocket| async { rocket - .attach(AdHoc::on_ignite( - "Initialize database connection", - |rocket| async { - let config = AsyncDieselConnectionManager::::new( - &get_app_config(&rocket).database_url, - ); - let pool: DbPool = Pool::builder(config) - .build() - .expect("Failed to parse database URL"); - let mut conn = pool.get().await.expect("Failed to connect to database"); + .attach(AdHoc::on_ignite("Initialize database", |rocket| async { + let config = AsyncDieselConnectionManager::::new( + &get_app_config(&rocket).database_url, + ); + let pool: DbPool = Pool::builder(config) + .build() + .expect("Failed to parse database URL"); + let mut conn = pool.get().await.expect("Failed to connect to database"); - static MIGRATIONS: EmbeddedMigrations = embed_migrations!(); - MIGRATIONS - .pending_migrations(&mut conn) - .await - .expect("Failed to get pending migrations") - .iter() - .for_each(|migration| { - rocket::info!("Running migration: {}", migration.name); - }); - MIGRATIONS - .run_pending_migrations(&mut conn) - .await - .expect("Database migrations failed"); - rocket::info!("Migrations completed successfully"); + static MIGRATIONS: EmbeddedMigrations = embed_migrations!(); + MIGRATIONS + .run_pending_migrations(&mut conn) + .await + .expect("Database migrations failed"); + rocket::info!("Migrations completed successfully"); - rocket.manage(pool) - }, - )) - .attach(AdHoc::on_shutdown( - "Shutdown database connection", - |rocket| { - Box::pin(async { - if let Some(pool) = rocket.state::() { - rocket::info!("Shutting down database connection"); - pool.close(); - } - }) - }, - )) + rocket.manage(pool) + })) + .attach(AdHoc::on_shutdown("Shutdown database", |rocket| { + Box::pin(async { + if let Some(pool) = rocket.state::() { + rocket::info!("Shutting down database connection"); + pool.close(); + } + }) + })) }) } diff --git a/server/src/db/schema.rs b/server/src/db/schema.rs index 3c39fd1..2485dab 100644 --- a/server/src/db/schema.rs +++ b/server/src/db/schema.rs @@ -107,18 +107,6 @@ diesel::table! { } } -diesel::table! { - tools (id) { - id -> Uuid, - user_id -> Uuid, - name -> Text, - description -> Text, - config -> Jsonb, - created_at -> Timestamptz, - updated_at -> Timestamptz, - } -} - diesel::table! { users (id) { id -> Uuid, @@ -144,7 +132,6 @@ diesel::joinable!(providers -> secrets (api_key_id)); diesel::joinable!(providers -> users (user_id)); diesel::joinable!(secrets -> users (user_id)); diesel::joinable!(system_tools -> users (user_id)); -diesel::joinable!(tools -> users (user_id)); diesel::allow_tables_to_appear_in_same_query!( app_api_keys, @@ -155,6 +142,5 @@ diesel::allow_tables_to_appear_in_same_query!( providers, secrets, system_tools, - tools, users, ); diff --git a/server/src/errors.rs b/server/src/errors.rs index 41ca5f0..20dcc46 100644 --- a/server/src/errors.rs +++ b/server/src/errors.rs @@ -26,6 +26,8 @@ pub enum ApiError { Tool(#[from] ToolError), #[error(transparent)] Io(#[from] std::io::Error), + #[error("Server error: {0}")] + Server(String), } #[derive(Debug, JsonSchema, serde::Serialize)] diff --git a/server/src/provider.rs b/server/src/provider.rs index 9e3c520..85bef0b 100644 --- a/server/src/provider.rs +++ b/server/src/provider.rs @@ -1,13 +1,15 @@ //! LLM providers module -use uuid::Uuid; - mod core; pub use core::*; pub mod models; pub mod providers; mod utils; +use std::collections::HashMap; + +use uuid::Uuid; + use crate::{ db::{ models::{ChatRsMessage, ChatRsMessageRole, ChatRsProviderType}, @@ -17,6 +19,7 @@ use crate::{ errors::ApiError, provider::{models::LlmModel, providers::*}, storage::LocalStorage, + tools::ToolError, }; pub const DEFAULT_MAX_TOKENS: u32 = 2000; @@ -50,7 +53,8 @@ pub fn build_llm_provider_api( } } -/// Convert database messages to the generic messages to send to the provider implementation +/// Extract any attached files, then convert the database messages to the generic format +/// for sending to LLM providers pub async fn build_llm_messages( messages: Vec, user_id: &Uuid, @@ -58,51 +62,64 @@ pub async fn build_llm_messages( db: &mut DbConnection, storage: &LocalStorage, ) -> Result, ApiError> { - let mut llm_messages = Vec::with_capacity(messages.len()); + // Get content of any attached files in the messages + let mut file_map: HashMap = HashMap::new(); + let file_ids: Vec = messages.iter().fold(Vec::new(), |mut acc, message| { + if let Some(file_ids) = message.meta.user.as_ref().and_then(|u| u.files.as_ref()) { + acc.extend(file_ids); + } + acc + }); + for file_id in file_ids { + let file = FileDbService::new(db) + .find_session_file(user_id, session_id, &file_id) + .await?; + let (file_type, content) = file.read_to_string(Some(session_id), storage).await?; + file_map.insert( + file_id, + LlmFileInput { + name: file.path, + content_type: file.content_type, + file_type, + content, + }, + ); + } - for message in messages { - match message.role { + // Convert the messages + let llm_messages = messages + .into_iter() + .map(|message| match message.role { ChatRsMessageRole::User => { - let mut files: Option> = None; - if let Some(file_ids) = message.meta.user.and_then(|u| u.files) { - let mut file_db_service = FileDbService::new(db); - for file_id in file_ids { - let file = file_db_service - .find_session_file(user_id, session_id, &file_id) - .await?; - let (file_type, content) = - file.read_to_string(Some(session_id), storage).await?; - files.get_or_insert_default().push(LlmFileInput { - name: file.path, - content_type: file.content_type, - file_type, - content, - }); - } - } - llm_messages.push(LlmMessage::User(LlmUserMessage { + let files = message.meta.user.and_then(|u| u.files).map(|file_ids| { + file_ids + .iter() + .filter_map(|id| file_map.remove(id)) + .collect() + }); + Ok(LlmMessage::User(LlmUserMessage { text: message.content, files, })) } - ChatRsMessageRole::Assistant => { - llm_messages.push(LlmMessage::Assistant(LlmAssistantMessage { - text: message.content, - tool_calls: message.meta.assistant.and_then(|a| a.tool_calls), - })) - } - ChatRsMessageRole::System => llm_messages.push(LlmMessage::System(message.content)), + ChatRsMessageRole::Assistant => Ok(LlmMessage::Assistant(LlmAssistantMessage { + text: message.content, + tool_calls: message.meta.assistant.and_then(|a| a.tool_calls), + })), + ChatRsMessageRole::System => Ok(LlmMessage::System(message.content)), ChatRsMessageRole::Tool => { if let Some(tool_call) = message.meta.tool_call { - llm_messages.push(LlmMessage::Tool(LlmToolResult { + Ok(LlmMessage::Tool(LlmToolResult { tool_call_id: tool_call.id, tool_name: tool_call.tool_name, content: message.content, })) + } else { + Err(ToolError::ToolCallNotFound) } } - } - } + }) + .collect::, ToolError>>()?; Ok(llm_messages) } diff --git a/server/src/redis.rs b/server/src/redis.rs index 978f51b..baf56be 100644 --- a/server/src/redis.rs +++ b/server/src/redis.rs @@ -20,12 +20,12 @@ const REDIS_POOL_SIZE: usize = 4; const MAX_EXCLUSIVE_CLIENTS: usize = 20; /// Timeout for connecting and executing commands. const CLIENT_TIMEOUT: Duration = Duration::from_secs(6); -/// Interval for checking idle exclusive clients. -const IDLE_TASK_INTERVAL: Duration = Duration::from_secs(30); +/// Interval to check for idle exclusive clients. +const IDLE_TASK_INTERVAL: Duration = Duration::from_secs(60); /// Shut down exclusive clients after this period of inactivity. -const IDLE_TIME: Duration = Duration::from_secs(60); +const IDLE_TIME: Duration = Duration::from_secs(60 * 5); -/// Fairing that sets up and initializes the Redis connection pool. +/// Fairing that sets up and initializes the Redis connection pools. pub fn setup_redis() -> AdHoc { AdHoc::on_ignite("Redis", |rocket| async { rocket @@ -88,7 +88,7 @@ pub fn setup_redis() -> AdHoc { }) } -pub fn build_redis_pool( +fn build_redis_pool( redis_config: fred::prelude::Config, pool_size: usize, ) -> Result { @@ -181,7 +181,8 @@ impl managed::Manager for ExclusiveClientManager { } } -/// Request guard to get a Redis client with an exclusive connection for long-running operations. +/// Represents a Redis client with an exclusive connection for long-running operations. +/// Can be used as a request guard to retrieve a client from the exclusive pool. #[derive(Debug, OpenApiFromRequest)] pub struct ExclusiveRedisClient(pub managed::Object); impl Deref for ExclusiveRedisClient { diff --git a/server/src/storage.rs b/server/src/storage.rs index 5d90cc1..000cfd5 100644 --- a/server/src/storage.rs +++ b/server/src/storage.rs @@ -30,7 +30,7 @@ pub fn setup_storage() -> AdHoc { } impl ChatRsFile { - /// Get the file type and contents for LLM input. Uses base64 URLs for image and PDF files. + /// Get the file type and contents for LLM input. Uses base64 encoding for image and PDF files. pub async fn read_to_string( &self, session_id: Option<&Uuid>, diff --git a/server/src/stream.rs b/server/src/stream.rs index 65be5e6..8606156 100644 --- a/server/src/stream.rs +++ b/server/src/stream.rs @@ -1,7 +1,8 @@ mod llm_writer; mod reader; +mod test_utils; -use std::collections::HashMap; +use std::{collections::HashMap, time::Duration}; use fred::{ prelude::{FredResult, KeysInterface, StreamsInterface}, @@ -35,15 +36,39 @@ pub async fn get_current_chat_streams( redis: &fred::clients::Client, user_id: &Uuid, ) -> FredResult> { + use fred::bytes_utils::Str; + let prefix = get_chat_stream_prefix(user_id); let pattern = format!("{}*", prefix); - let (_, keys): (String, Vec) = redis + let (_, keys): (Str, Vec) = redis .scan_page("0", &pattern, Some(20), Some(ScanType::Stream)) .await?; - Ok(keys + + // Get the last 2 entries in each stream to check if they are still active + let pipeline = redis.pipeline(); + for key in &keys { + let _: () = pipeline.xrevrange(key, "+", "-", Some(2)).await?; + } + let streams: Vec)>>> = pipeline.all().await?; + + let active_keys = keys .into_iter() - .filter_map(|key| Some(key.strip_prefix(&prefix)?.to_string())) - .collect()) + .zip(streams.into_iter()) + .filter_map(|(key, stream)| { + // Filter out streams that have already ended or been cancelled + if let Some(events) = stream { + if events.iter().any(|(_, data)| { + data.get("type") + .is_some_and(|t| *t == "end" || *t == "cancel") + }) { + return None; + } + } + Some(key.strip_prefix(&prefix)?.to_string()) + }) + .collect(); + + Ok(active_keys) } /// Check if the chat stream exists. @@ -58,7 +83,6 @@ pub async fn check_chat_stream_exists( } /// Cancel a stream by adding a `cancel` event to the stream and then deleting it from Redis -/// (not using a pipeline since we need to ensure the `cancel` event is processed before deleting the stream). pub async fn cancel_current_chat_stream( redis: &fred::clients::Client, user_id: &Uuid, @@ -67,6 +91,7 @@ pub async fn cancel_current_chat_stream( let key = get_chat_stream_key(user_id, session_id); let entry: HashMap = RedisStreamChunk::Cancel.into(); let _: () = redis.xadd(&key, true, None, "*", entry).await?; + tokio::time::sleep(Duration::from_millis(500)).await; redis.del(&key).await } diff --git a/server/src/stream/llm_writer.rs b/server/src/stream/llm_writer.rs index 7f2b99c..f5a6dc1 100644 --- a/server/src/stream/llm_writer.rs +++ b/server/src/stream/llm_writer.rs @@ -101,10 +101,9 @@ impl LlmStreamWriter { /// delete the stream from Redis. pub async fn end(&self) -> FredResult<()> { let entry: HashMap = RedisStreamChunk::End.into(); - let pipeline = self.redis.pipeline(); - let _: () = pipeline.xadd(&self.key, true, None, "*", entry).await?; - let _: () = pipeline.del(&self.key).await?; - pipeline.all().await + let _: () = self.redis.xadd(&self.key, true, None, "*", entry).await?; + tokio::time::sleep(FLUSH_INTERVAL).await; + self.redis.del(&self.key).await } /// Process the incoming stream from the LLM provider, intermittently flushing @@ -256,8 +255,8 @@ impl LlmStreamWriter { self.add_to_redis_stream(entries).await } - /// Adds new entries to the Redis stream. Returns a `LlmStreamError::StreamCancelled` error if the - /// stream has been deleted or cancelled. + /// Adds new entries to the Redis stream, while also checking for cancellation. + /// Returns a `LlmStreamError::StreamCancelled` error if the stream has been cancelled. async fn add_to_redis_stream( &self, entries: Vec>, @@ -305,29 +304,13 @@ mod tests { use super::*; use crate::{ provider::{providers::LoremProvider, LlmApiProvider, LlmProviderOptions}, - redis::{ExclusiveClientManager, ExclusiveClientPool}, - stream::{cancel_current_chat_stream, check_chat_stream_exists}, + redis::ExclusiveClientPool, + stream::{ + cancel_current_chat_stream, check_chat_stream_exists, test_utils::setup_redis_pool, + }, }; - use fred::prelude::{Builder, ClientLike, Config}; use std::time::Duration; - async fn setup_redis_pool() -> ExclusiveClientPool { - let config = - Config::from_url("redis://127.0.0.1:6379").unwrap_or_else(|_| Config::default()); - let pool = Builder::from_config(config) - .build_pool(1) - .expect("Failed to build Redis pool"); - pool.init().await.expect("Failed to connect to Redis"); - - let manager = ExclusiveClientManager::new(pool.clone()); - let deadpool: ExclusiveClientPool = deadpool::managed::Pool::builder(manager) - .max_size(3) - .build() - .unwrap(); - - deadpool - } - async fn create_test_writer( redis: &ExclusiveClientPool, user_id: &Uuid, diff --git a/server/src/stream/reader.rs b/server/src/stream/reader.rs index c0eb870..0bd9f7c 100644 --- a/server/src/stream/reader.rs +++ b/server/src/stream/reader.rs @@ -8,7 +8,7 @@ use uuid::Uuid; use crate::{provider::LlmError, redis::ExclusiveRedisClient, stream::get_chat_stream_key}; /// Timeout in milliseconds for the blocking `xread` command. -const XREAD_BLOCK_TIMEOUT: u64 = 5_000; // 5 seconds +const XREAD_BLOCK_TIMEOUT: u64 = 10_000; // 10 seconds /// Utility for reading SSE events from a Redis stream. pub struct SseStreamReader { @@ -31,15 +31,11 @@ impl SseStreamReader { ) -> Result<(Vec, String, bool), LlmError> { let key = get_chat_stream_key(user_id, session_id); let start_event_id = start_event_id.unwrap_or("0-0"); - let (_, prev_events): (String, Vec<(String, HashMap)>) = self - .redis - .xread::>, _, _>(None, None, &key, start_event_id) - .await? - .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command - .ok_or(LlmError::StreamNotFound)?; + let prev_events = self.xread(&key, start_event_id, None, None).await?; + let (last_event_id, is_end) = prev_events .last() - .map(|(id, data)| (id.to_owned(), data.get("type").is_some_and(|t| t == "end"))) + .map(|(id, data)| (id.to_owned(), is_end_event(&data))) .unwrap_or_else(|| (start_event_id.into(), false)); let sse_events = prev_events .into_iter() @@ -49,7 +45,7 @@ impl SseStreamReader { Ok((sse_events, last_event_id, is_end)) } - /// Stream the events from the given Redis stream using a blocking `xread` command. + /// Stream SSE events from the given Redis stream using a blocking `xread` command. pub async fn stream( &self, user_id: &Uuid, @@ -60,7 +56,7 @@ impl SseStreamReader { let key = get_chat_stream_key(user_id, session_id); let mut last_event_id = last_event_id.to_owned(); loop { - match self.get_next_event(&key, &mut last_event_id, tx).await { + match self.next_event(&key, &mut last_event_id).await { Ok((id, data, is_end)) => { let event = convert_redis_event_to_sse((id, data)); if let Err(_) = tx.send(event).await { @@ -79,52 +75,137 @@ impl SseStreamReader { } } - /// Get the next event from the given Redis stream using a blocking `xread` command. - /// - Updates the last event ID - /// - Cancels waiting for the next event if the client disconnects + /// Wait for the next event from the given Redis stream using a blocking `xread` command. + /// - Cancels waiting for the next event upon the blocking timeout + /// - Updates the last event ID with the ID of the received event /// - Returns the event ID, data, and a `bool` indicating whether it's an ending event - async fn get_next_event( + async fn next_event( &self, key: &str, last_event_id: &mut String, - tx: &mpsc::Sender, ) -> Result<(String, HashMap, bool), LlmError> { - let (_, mut events): (String, Vec<(String, HashMap)>) = tokio::select! { - res = self.redis.xread::>, _, _>(Some(1), Some(XREAD_BLOCK_TIMEOUT), key, &*last_event_id) => { - match res?.as_mut().and_then(|streams| streams.pop()) { - Some(stream) => stream, - None => return Err(LlmError::StreamNotFound), - } - }, - _ = tx.closed() => return Err(LlmError::ClientDisconnected) - }; - match events.pop() { - Some((id, data)) => { - *last_event_id = id.clone(); - let is_end = data - .get("type") - .is_some_and(|t| t == "end" || t == "cancel"); - Ok((id, data, is_end)) - } - None => Err(LlmError::NoStreamEvent), - } + let (id, data) = self + .xread(key, last_event_id, Some(1), Some(XREAD_BLOCK_TIMEOUT)) + .await? + .pop() // only reading 1 event + .ok_or(LlmError::NoStreamEvent)?; + *last_event_id = id.clone(); + let is_end = is_end_event(&data); + Ok((id, data, is_end)) } + + /// Read events from the given stream (friendly `XREAD` wrapper that takes care of the + /// weird types). Returns `LlmError::StreamNotFound` if there is no stream. + async fn xread( + &self, + key: &str, + start_event_id: &str, + count: Option, + block: Option, + ) -> Result)>, LlmError> { + let (_key, events) = self + .redis + .xread::>, _, _>(count, block, key, start_event_id) + .await? + .and_then(|mut streams| streams.pop()) // should only be 1 stream since we're sending 1 key in the command + .ok_or(LlmError::StreamNotFound)?; + Ok(events) + } +} + +/// Check if this event is an ending event. +fn is_end_event(data: &HashMap) -> bool { + data.get("type") + .is_some_and(|t| t == "end" || t == "cancel") } /// Convert a Redis stream event into an SSE event. Expects the event hash map to contain /// a "type" and "data" field (e.g. serialized using the appropriate serde tag and content). -fn convert_redis_event_to_sse((id, event): (String, HashMap)) -> Event { - let mut r#type: Option = None; +fn convert_redis_event_to_sse((id, hash): (String, HashMap)) -> Event { + let mut event: Option = None; let mut data: Option = None; - for (key, value) in event { + for (key, value) in hash { match key.as_str() { - "type" => r#type = Some(value), + "type" => event = Some(value), "data" => data = Some(format!(" {value}")), // SSE spec: add space before data _ => {} } } Event::data(data.unwrap_or_default()) - .event(r#type.unwrap_or_else(|| "unknown".into())) + .event(event.unwrap_or_else(|| "unknown".into())) .id(id) } + +#[cfg(test)] +mod tests { + use super::*; + use crate::stream::test_utils::setup_redis_pool; + use fred::prelude::KeysInterface; + use rand::distr::{Alphanumeric, SampleString}; + + type RedisEvent = (String, HashMap); + + #[tokio::test] + async fn xread() -> Result<(), LlmError> { + let redis = setup_redis_pool().await; + let client = redis.get().await.expect("should get a client"); + let reader_client = redis.get().await.expect("should get a client"); + let reader = SseStreamReader::new(ExclusiveRedisClient(reader_client)); + + let stream_key = format!( + "stream_{}", + Alphanumeric.sample_string(&mut rand::rng(), 10) + ); + + let event_1: RedisEvent = ( + "1".into(), + HashMap::from([ + ("type".into(), "message".into()), + ("data".into(), "Hello, world!".into()), + ]), + ); + let event_2: RedisEvent = ( + "2".into(), + HashMap::from([ + ("type".into(), "message".into()), + ("data".into(), "Goodbye, world!".into()), + ]), + ); + let event_3: RedisEvent = ("3".into(), HashMap::from([("type".into(), "end".into())])); + for (id, data) in [event_1, event_2, event_3] { + let _: () = client + .xadd(&stream_key, false, None, id, data) + .await + .expect("should add event to Redis stream"); + } + + let mut event_1 = reader.xread(&stream_key, "0-0", Some(1), None).await?; + assert_eq!(event_1.len(), 1); + let (event_1_id, event_1_data) = event_1.pop().unwrap(); + assert_eq!(event_1_id, "1-0"); + assert_eq!(event_1_data["type"], "message"); + assert_eq!(event_1_data["data"], "Hello, world!"); + + let mut event_2 = reader + .xread(&stream_key, &event_1_id, Some(1), None) + .await?; + assert_eq!(event_2.len(), 1); + let (event_2_id, event_2_data) = event_2.pop().unwrap(); + assert_eq!(event_2_id, "2-0"); + assert_eq!(event_2_data["type"], "message"); + assert_eq!(event_2_data["data"], "Goodbye, world!"); + + let mut event_3 = reader + .xread(&stream_key, &event_2_id, Some(1), None) + .await?; + assert_eq!(event_3.len(), 1); + let (event_3_id, event_3_data) = event_3.pop().unwrap(); + assert_eq!(event_3_id, "3-0"); + assert_eq!(event_3_data["type"], "end"); + + let _: () = client.del(&stream_key).await?; + + Ok(()) + } +} diff --git a/server/src/stream/test_utils.rs b/server/src/stream/test_utils.rs new file mode 100644 index 0000000..aa850ce --- /dev/null +++ b/server/src/stream/test_utils.rs @@ -0,0 +1,20 @@ +#[cfg(test)] +pub(super) async fn setup_redis_pool() -> crate::redis::ExclusiveClientPool { + use crate::redis::{ExclusiveClientManager, ExclusiveClientPool}; + use fred::prelude::{Builder, ClientLike, Config}; + + let url = std::env::var("RS_CHAT_REDIS_URL").unwrap_or("redis://127.0.0.1".to_owned()); + let config = Config::from_url(&url).unwrap(); + let pool = Builder::from_config(config) + .build_pool(1) + .expect("Failed to build Redis pool"); + pool.init().await.expect("Failed to connect to Redis"); + + let manager = ExclusiveClientManager::new(pool.clone()); + let deadpool: ExclusiveClientPool = deadpool::managed::Pool::builder(manager) + .max_size(3) + .build() + .unwrap(); + + deadpool +} diff --git a/web/src/lib/api/chat.ts b/web/src/lib/api/chat.ts index eca7db0..f42018d 100644 --- a/web/src/lib/api/chat.ts +++ b/web/src/lib/api/chat.ts @@ -50,7 +50,7 @@ export async function createChatStream( .pipeThrough(new TextDecoderStream()) .pipeThrough(new EventSourceParserStream()) .getReader(); - loop: while (true) { + while (true) { const { done, value } = await eventStream.read(); if (done) break; @@ -70,7 +70,7 @@ export async function createChatStream( break; case "end": case "cancel": - break loop; + break; default: console.warn(`Unknown event type: ${value.event}`); break; diff --git a/web/src/lib/context/streamManagerData.ts b/web/src/lib/context/streamManagerData.ts index e74bd47..166703f 100644 --- a/web/src/lib/context/streamManagerData.ts +++ b/web/src/lib/context/streamManagerData.ts @@ -48,7 +48,7 @@ export function useStreamManagerData() { updatedData?.messages?.some( (msg) => msg.role === "Assistant" && - new Date(msg.created_at).getTime() > Date.now() - 5000, // Within last 5 seconds + new Date(msg.created_at).getTime() > Date.now() - 30_000, // Within last 30 seconds ) || false; // Retry if no new assistant message