pushpin-1.41.0/.github/Dockerfile:

FROM --platform=linux/amd64 ubuntu:24.04

ENV TZ=US \
    DEBIAN_FRONTEND=noninteractive

RUN apt-get -qq update && apt-get install -y zstd git pkg-config curl make g++ libssl-dev libzmq3-dev qtbase5-dev libboost-dev black

# install toolchain
RUN curl https://sh.rustup.rs -sSf | \
    sh -s -- --default-toolchain stable-x86_64-unknown-linux-gnu -y

ENV RUSTUP_HOME="/root/.rustup" \
    CARGO_HOME="/root/.cargo" \
    PATH=/root/.cargo/bin:$PATH

# Keep in sync with the minimum supported Rust version for Pushpin
RUN rustup install 1.75.0-x86_64-unknown-linux-gnu
RUN cargo install cargo-audit
RUN rustup component add rustfmt
RUN rustup component add clippy

pushpin-1.41.0/.github/workflows/codeql.yml:

name: "CodeQL"

on:
  push:
    branches: [ "main" ]
  pull_request:
    branches: [ "main" ]
  schedule:
    - cron: '21 13 * * 6'

jobs:
  analyze:
    name: Analyze (${{ matrix.language }})
    runs-on: 'ubuntu-latest'
    container: fanout/build-base:latest
    timeout-minutes: 15
    permissions:
      # required for all workflows
      security-events: write
      # required to fetch internal or private CodeQL packs
      packages: read
      # only required for workflows in private repositories
      actions: read
      contents: read
    strategy:
      fail-fast: true
      matrix:
        include:
          - language: c-cpp
            build-mode: manual
          - language: javascript-typescript
            build-mode: none
          - language: python
            build-mode: none
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "codeql"
      # Initializes the CodeQL tools for scanning.
      - name: Initialize CodeQL
        uses: github/codeql-action/init@v3
        with:
          languages: ${{ matrix.language }}
          build-mode: ${{ matrix.build-mode }}
      - if: matrix.build-mode == 'manual'
        shell: bash
        run: make build
      - name: Perform CodeQL Analysis
        uses: github/codeql-action/analyze@v3
        with:
          category: "/language:${{matrix.language}}"

pushpin-1.41.0/.github/workflows/suite.yml:

name: CI

on:
  pull_request:
    branches:
      - main
      - v1

defaults:
  run:
    shell: bash -leo pipefail {0}

jobs:
  check:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: check
        run: RUSTFLAGS="-D warnings" cargo +${{ matrix.rust-version }} c
  clippy:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: check
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: clippy
        run: RUSTFLAGS="-D warnings" cargo +${{ matrix.rust-version }} clippy -- -D warnings
  lint:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: check
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: fmt
        run: RUSTFLAGS="-D warnings" cargo +${{ matrix.rust-version }} fmt --check
      - name: format python
        run: black --check .
  build:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: check
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: build
        run: RUSTFLAGS="-D warnings" TOOLCHAIN=${{ matrix.rust-version }} make build
  audit:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: build
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: audit
        run: RUSTFLAGS="-D warnings" cargo +${{ matrix.rust-version }} audit
  test:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: build
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: test
        run: TOOLCHAIN=${{ matrix.rust-version }} make cargo-test
  benchmark:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: build
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: bench
        run: RUSTFLAGS="-D warnings" cargo +${{ matrix.rust-version }} bench --no-run
  build-full:
    strategy:
      fail-fast: true
      matrix:
        rust-version: [stable, 1.75.0]
    runs-on: ubuntu-latest
    container: fanout/build-base:latest
    needs: [audit, test]
    steps:
      - name: Checkout code
        uses: actions/checkout@v4
        with:
          submodules: recursive
      - name: Cache
        uses: Swatinem/rust-cache@v2
        with:
          shared-key: "CI-Suite"
      - name: cargo fetch
        run: cargo +${{ matrix.rust-version }} fetch
      - name: build release
        run: TOOLCHAIN=${{ matrix.rust-version }} RELEASE=1 RUSTFLAGS="-D warnings" make build

pushpin-1.41.0/.gitignore:

*.swp
*~
qrc_*
*.moc
*.o
*.pdb
ui_*
moc_*
*.pyc
.qmake.stash
target_wrapper.sh
/target
**/*.rs.bk
.cargo
vendor
.vscode
conf.pri
/pushpin
/pushpin-legacy
/bin/*
/src/runner/certs/*.crt
/src/runner/certs/*.key
/postbuild/Makefile
/postbuild/*.inst
/config
/run
/log

pushpin-1.41.0/CHANGELOG.md:

Pushpin Changelog
=================

v. 1.41.0 (2025-08-08)

  * Add support for gone links.
  * WebSocket-Over-HTTP: allow keeping back data (Content-Bytes-Accepted).
  * Support IPv6 addresses in route targets.
  * Cache root certs when making outbound connections.
  * New config option: update_on_first_subscription.
  * New route condition option: log_level.
  * Correctly read push_in_http_max_{headers,body}_size options.
  * Apply next link response's own filters to its content.
  * runner: look for bins in target dir (makes "cargo run" work).
  * Remove use of Qt event loop, to reduce dependence on Qt.

v. 1.40.1 (2024-08-13)

  * Fix build with recent rustc.

v. 1.40.0 (2024-07-17)

  * Restore support for receiving chunked responses from backends.
  * For early errors such as malformed requests, respond instead of disconnecting.
  * Rename Condure to connmgr.
  * Fix crash under high load.
  * Fix hang when receiving large WebSocket rejections from the backend.

v. 1.39.1 (2024-03-18)

  * Regenerate pushpin.conf.inst post-build to ensure up-to-date configuration.
  * Update the legacy runner to use the revised Qt linking logic, aligning with main branch improvements.

v. 1.39.0 (2024-03-14)

  * Add support for multiple proxy worker threads.
  * New config option: workers (under [proxy]).
  * Fix memory leak when proxying requests.
  * Various build system fixes/improvements.
  * Use Boost for signals & slots to reduce dependence on Qt's event loop.

v. 1.38.0 (2024-01-08)

  * Publish refresh action for triggering WebSocket-over-HTTP requests.
  * Ability to read signing secrets from files.
  * Move Condure into Pushpin and name the program pushpin-condure.
  * Fix crash when writing partial uncompressed WebSocket frame.
  * Fix WebSocket proxying flow control.
  * Support receiving non-chunked responses of indefinite length.
  * Qt 6 compatibility.
  * Remove configure script. Configure via environment variables instead.

v. 1.37.0 (2023-06-29)

  * Ability to use Condure instead of Zurl for outgoing connections.
  * Ability to set mode/user/group when listening on Unix socket.
  * WebSocket performance optimizations.
  * New config options: allow_compression, stats_connection_send, cdn_loop.
  * Relicense to Apache 2.0.

v. 1.36.0 (2022-11-14)

  * Ability to accept client connections via IPv6.
  * Ability to sign requests using EC or RSA public keys.
  * Include bytes/messages in report stats.

v. 1.35.0 (2022-03-11)

  * Add support for Prometheus metrics.
  * Ability to listen on a Unix socket for client connections.
  * New config options: prometheus_port, prometheus_prefix.
  * New config option: local_ports.
  * New config option: accept_pushpin_route.
  * New route condition option: no_grip.
  * Use the route of the initial request for retries and link requests.
  * pushpin-publish: fix sending hint action for http-response format.

v. 1.34.0 (2021-11-30)

  * New config option: message_wait.
  * Publish command for publishing via command socket.

v. 1.33.1 (2021-08-09)

  * Build system fixes.

v. 1.33.0 (2021-08-08)

  * Performance optimizations.
  * New config option: sig_iss.

v. 1.32.2 (2021-06-09)

  * Fix publishing to SockJS WebSocket connections.

v. 1.32.1 (2021-05-13)

  * Build system fixes.

v. 1.32.0 (2021-05-11)

  * pushpin-publish: support sending via HTTP, and do this by default.
  * pushpin-publish: support authentication.
  * pushpin-publish: use GRIP_URL environment variable if present.
  * Add Rust code to the build process.

v. 1.31.0 (2020-11-06)

  * Use Condure instead of Mongrel2, by default.
  * Ability to refresh WebSocket-over-HTTP sessions by channel.
  * Fix crash when sending delayed WebSocket messages.

v. 1.30.0 (2020-07-29)

  * Optional support for Condure instead of Mongrel2.
  * ZHTTP compatibility fixes.

v. 1.29.0 (2020-07-15)

  * Fix crash when parsing Accept header received on control port.
  * Fix crash when response hold times out while pausing.
  * Fix handling of hints in response mode.
  * Fix handling of ZeroMQ errors, including EINTR.
  * ZHTTP compatibility fixes.

v. 1.28.0 (2020-04-08)

  * New route target option: one_event.

v. 1.27.0 (2020-03-10)

  * WebSocket: ability to publish close reason.
  * WebSocket: proxy the content of ping and pong frames.

v. 1.26.0 (2019-12-11)

  * Respond with status 200 on HTTP control port root path.

v. 1.25.0 (2019-11-20)

  * Set the Mongrel2 log level and capture debug output.
  * Ability to set different log levels per subprocess.

v. 1.24.0 (2019-08-06)

  * runner: capture Mongrel2 logs when --merge-output is used.

v. 1.23.0 (2019-07-03)

  * Support log levels 0 and 1.
  * Don't write to Mongrel2 access log for log levels < 2.
  * Support JSON framing on the input PULL and SUB sockets.
  * New config option: push_in_sub_specs.
  * New config option: push_in_sub_connect.

v. 1.22.0 (2019-06-17)

  * New filter: var-subst.
  * Support content-filters field in ws-message format.

v. 1.21.0 (2019-05-01)

  * GRIP keep-alive modes: idle (default) and interval.
  * Don't put GRIP headers in Access-Control-Expose-Headers.

v. 1.20.3 (2019-04-08)

  * Fix Grip-Last values when route prefix is used.

v. 1.20.2 (2019-03-25)

  * WebSocket-Over-HTTP: fix mem leak when clients disconnect during close.

v. 1.20.1 (2019-02-20)

  * WebSocket-Over-HTTP: don't forward Content-Length header.

v. 1.20.0 (2019-02-19)

  * WebSocket-Over-HTTP: break up response messages to fit session buffers.
  * New config option: stats_format.
  * New config option: client_buffer_size.

v. 1.19.1 (2019-01-10)

  * WebSocket: fix crash when receiving frames after close frame.
  * WebSocket: include reason and headers in rejection responses.

v. 1.19.0 (2018-12-18)

  * WebSocket: support close reasons.

v. 1.18.0 (2018-08-20)

  * WebSocket-Over-HTTP: update headers (mainly Grip-Sig) for each request.
  * WebSocket-Over-HTTP: properly report errors and handle target failover.
  * WebSocket: support debug responses.
  * Option to not send non-standard X-Forwarded-Protocol header.
  * Increase default request buffer size to 8k.
  * Make http_port optional.
  * runner: remove mongrel2 pid file before starting.
  * runner: return non-zero status code if failing due to subprocess error.
  * runner: prevent SIGINT from being copied to subprocesses.

v. 1.17.2 (2018-01-11)

  * Fix close actions with HTTP streaming and WebSockets.

v. 1.17.1 (2017-12-12)

  * Fix compilation with Qt 5.10.

v. 1.17.0 (2017-11-06)

  * De-dup published messages based on recently seen IDs (default 60s).
  * Limit number of subscriptions per connection (default 20).
  * Ensure filters update after following next links.
  * Support content-filters field in http-stream and http-response formats.
  * Include subscribers field in subscription stats.
  * Include duration field in report stats.
  * New config options: connection_subscription_max, subscription_linger.
  * New config options: stats_connection_ttl, stats_subscription_ttl.
  * New config option: stats_report_interval.

v. 1.16.0 (2017-07-14)

  * Reliable streaming fixes.
  * SockJS: XHR transport fixes.
  * WebSocket-Over-HTTP: more fixes to ensure DISCONNECT events get sent.
  * Fix routes file change detection when file is replaced.
  * Set Grip-Last headers when retrying long-polling request.
  * Enable client-side TCP keep-alives.
  * Stats: report logical IP address rather than physical.
  * Published items can include no-seq flag to bypass sequencing buffer.
  * New config options: log_from, log_user_agent.
  * New filters: skip-users, build-id, require-sub.
  * Add randomness to stream keep alives.
  * pushpin-publish: --meta option.
  * pushpin-publish: --no-seq option.
  * Announce more features using Grip-Feature request header.
  * Fix GRIP session detection.
  * sub target parameter works for both HTTP and WebSocket, forbids unsub.
  * Packet logging uses new format that only trims content, not headers.

v. 1.15.0 (2017-01-22)

  * Publish hint action for triggering recovery requests.
  * Recover command for triggering recovery requests.
  * Refresh command for triggering WebSocket-Over-HTTP requests.
  * Improve reliability of long-polling when previous ID is used.
  * WebSocket-Over-HTTP: ensure DISCONNECT events get sent.
  * WebSocket: new control messages: send-delayed, flush-delayed.
  * WebSocket: break large published messages into frames.
  * Allow unknown previous ID for first message to channel.
  * Forget previous ID when channel has no subscribers.
  * Reduce timeout of out-of-order messages to 5 seconds.
  * pushpin-publish: --hint option
  * pushpin-publish: --no-eol option
  * pushpin-publish: ability to use file source (@filename)
  * New config option: message_block_size
  * Remove docs files from repository. Content moved to pushpin.org.

v. 1.14.0 (2016-11-15)

  * Reliable HTTP streaming (stream hold + "GRIP Next").
  * Process messages in order if received out of order.

v. 1.13.1 (2016-10-27)

  * Fix crash when publishing to a long-polling client that is closing.
  * More conservative message_rate default.

v. 1.13.0 (2016-10-22)

  * Optimizations for higher concurrent connections.
  * New config options: message_rate and message_hwm.
  * New stats message: report.
  * Handle next links internally if relative.
  * Log accepted requests as "accept", not "hold".
  * Log handler-initiated requests in handler, not proxy.
  * Fix memory leaks.
  * Send anonymous usage statistics to Fanout.

v. 1.12.0 (2016-09-03)

  * "GRIP Next" feature for streaming many responses as a single response.
  * header route parameter for sending custom headers when proxying.
  * trust_connect_host target parameter for trusting cert of connect host.
  * SockJS: fix bug with not receiving messages from client.
  * More correct handling of Host header.
  * Set X-Forwarded-Proto in addition to X-Forwarded-Protocol.
  * Various bugfixes.

v. 1.11.0 (2016-07-11)

  * Debug mode, to get more information about errors while proxying.
  * Command line option to log subprocess output: --merge-output.
  * Command line option to log merged output to file: --logfile.
  * Command line options for quick config: --port, --route.
  * Command line option to easily run multiple instances: --id.
  * Rewrite runner from Python to C++.
  * Don't relay Content-Encoding (fixes compressed long-polling timeouts).
  * Fixes to log output.

v. 1.10.1 (2016-05-30)

  * Fix SockJS crash.
  * Fix bug that logged successful requests as errors.

v. 1.10.0 (2016-05-25)

  * Streaming: initial response now has no size limit.
  * WebSocket-Over-HTTP: retry requests to the origin server.
  * WebSocket: ability to disconnect clients by publishing a close action.
  * WebSocket: ability to publish ping/pong frames.
  * WebSocket: keep-alives.
  * New route target "test", for testing without an origin server.
  * Fix publishing of large payloads through HTTP control port.
  * New config option: log_level.
  * Ability to set bind interface in config (use addr:port form).
  * Grip-Status header, for setting alternate response code and reason.

v. 1.9.0 (2016-04-14)

  * More practical logging. Non-verbose output more informative.
  * New config option: accept_x_forwarded_protocol.
  * Support JSON responses in HTTP control endpoint.
  * More accurate WebSocket activity counting.

v. 1.8.0 (2016-02-22)

  * Fix issue proxying large responses.
  * Refactor README.
  * Port server code to Qt 5.
  * Rewrite pushpin-publish tool from Python to C++.
  * Move internal.conf into LIBDIR.

v. 1.7.0 (2016-01-10)

  * Rewrite pushpin-handler from Python to C++.
  * Initial support for subscription filters and skip-self filter.
  * Fix sending of large responses when flow control not used.
  * Speed up shutdown.
  * Pass WebSocket GRIP logic upstream if GRIP proxy detected.
  * Don't forward WebSocket-Over-HTTP requests unless client trusted.
  * WebSocket-Over-HTTP: strip private headers from responses.
  * Long-polling: finish support for JSON patch.
  * m2adapter: dynamically enable/disable control port as needed.
  * publish tool: add id, prev-id, patch, and sender options.
  * Add monitorsubsock tool for monitoring SUB socket.
  * Refactor docs/grip-protocol.md.

v. 1.6.0 (2015-09-24)

  * Fix rare assert when publishing to a WebSocket.
  * Remove libdir from pushpin.conf.
  * Mongrel2: use download flow control.
  * Mongrel2: enable relaxed parsing.
  * Auto Cross-Origin: include Access-Control-Max-Age.
  * Throw error if can't create runtime directories on startup.
  * Various cleanups.

v. 1.5.0 (2015-07-23)

  * replace_beg route parameter.
  * Fixed bug where non-persistent connections were closed before data sent.
  * Accept invalid characters in request URIs and URL-encode them.

v. 1.4.0 (2015-07-16)

  * Improved handling of streamed input while proxying.
  * WebSocket over_http mode: relay error responses rather than 502.
  * Various WebSocket bugfixes.
  * Prefer using sortedcontainers.SortedDict rather than blist.sorteddict.

v. 1.3.3 (2015-07-05)

  * Fix crash on conflict retry introduced in previous version.

v. 1.3.2 (2015-07-05)

  * Better handling of responses with no explicit body (HEAD, 204, 304).
  * Persistent connection fixes.
  * Proxy flow control fixes.
  * WebSocket over_http mode: buffer fragmented messages before sending.

v. 1.3.1 (2015-06-19)

  * Fix http-response conflict recovery.
  * Correctly proxy WebSocket ping and pong frames.
  * Fix WebSocket compatibility with latest Zurl.

v. 1.3.0 (2015-06-03)

  * Many fixes with subscription reporting via stats and SUB socket.
  * Tweaks to enable higher concurrent connection counts.
  * WebSocket over_http mode sends DISCONNECT events.

v. 1.2.0 (2015-05-09)

  * http-stream: close action, keep-alive.
  * Check for new pushpin versions.
  * ZeroMQ endpoint discovery via command socket.
  * pushpin-publish command line tool.

v. 1.1.1 (2015-04-17)

  * Fix auto-cross-origin feature.

v. 1.1.0 (2015-03-08)

  * SUB socket input.
  * SockJS client support.

v. 1.0.0 (2014-09-16)

  * Stable version.

pushpin-1.41.0/Cargo.lock:

# This file is automatically @generated by Cargo.
# It is not intended for manual editing.
version = 3

[[package]]
name = "adler"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "f26201604c87b1e01bd3d98f8d5d9a8fcbb815e8cedb41ffccbeb4bf593a35fe"

[[package]]
name = "ahash"
version = "0.8.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "e89da841a80418a9b391ebaea17f5c112ffaaa96f621d2c285b5174da76b9011"
dependencies = [
 "cfg-if",
 "once_cell",
 "version_check",
 "zerocopy",
]

[[package]]
name = "aho-corasick"
version = "1.1.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "b2969dcb958b36655471fc61f7e416fa76033bdd4bfed0678d8fee1e2d07a1f0"
dependencies = [
 "memchr",
]

[[package]]
name = "allocator-api2"
version = "0.2.20"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "45862d1c77f2228b9e10bc609d5bc203d86ebc9b87ad8d5d5167a6c9abf739d9"

[[package]]
name = "anes"
version = "0.1.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "4b46cbb362ab8752921c97e041f5e366ee6297bd428a31275b9fcf1e380f7299"

[[package]]
name = "anstream"
version = "0.3.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0ca84f3628370c59db74ee214b3263d58f9aadd9b4fe7e711fd87dc452b7f163"
dependencies = [
 "anstyle",
 "anstyle-parse",
 "anstyle-query",
 "anstyle-wincon",
 "colorchoice",
 "is-terminal",
 "utf8parse",
]

[[package]]
name = "anstyle"
version = "1.0.10"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9"

[[package]]
name = "anstyle-parse"
version = "0.2.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9"
dependencies = [
 "utf8parse",
]

[[package]]
name = "anstyle-query"
version = "1.1.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "6d36fc52c7f6c869915e99412912f22093507da8d9e942ceaf66fe4b7c14422a"
dependencies = [
 "windows-sys 0.52.0",
]

[[package]]
name = "anstyle-wincon"
version = "1.0.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c677ab05e09154296dd37acecd46420c17b9713e8366facafa8fc0885167cf4c"
dependencies = [
 "anstyle",
 "windows-sys 0.48.0",
]

[[package]]
name = "arraydeque"
version = "0.5.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7d902e3d592a523def97af8f317b08ce16b7ab854c1985a0c671e6f15cebc236"

[[package]]
name = "arrayvec"
version = "0.7.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96d30a06541fbafbc7f82ed10c06164cfbd2c401138f6addd8404629c4b16711"

[[package]]
name = "async-trait"
version = "0.1.74"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a66537f1bb974b254c98ed142ff995236e81b9d0fe4db0575f46612cb15eb0f9"
dependencies = [
 "proc-macro2",
 "quote",
 "syn",
]

[[package]]
name = "autocfg"
version = "1.1.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d468802bab17cbc0cc575e9b053f41e72aa36bfa6b7f55e3529ffa43161b97fa"

[[package]]
name = "base64"
version = "0.13.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum =
"9e1b586273c5702936fe7b7d6896644d8be71e6314cfe09d3167c95f712589e8" [[package]] name = "base64" version = "0.21.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35636a1494ede3b646cc98f74f8e62c773a38a659ebc777a2cf26b9b74171df9" [[package]] name = "base64" version = "0.22.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6" [[package]] name = "bitflags" version = "1.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bef38d45163c2f1dde094a7dfd33ccf595c92905c8f8f4fdc18d06fb1037718a" [[package]] name = "bitflags" version = "2.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c8214115b7bf84099f1309324e63141d4c5d7cc26862f97a0a857dbefe165bd" dependencies = [ "serde", ] [[package]] name = "block-buffer" version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3078c7629b62d3f0439517fa394996acacc5cbc91c5a20d8c658e77abd503a71" dependencies = [ "generic-array", ] [[package]] name = "bumpalo" version = "3.14.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7f30e7476521f6f8af1a1c4c0b8cc94f0bee37d91763d0ca2665f299b6cd8aec" [[package]] name = "cast" version = "0.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37b2a672a2cb129a2e41c10b1224bb368f9f37a2b16b612598138befd7b37eb5" [[package]] name = "cbindgen" version = "0.27.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3fce8dd7fcfcbf3a0a87d8f515194b49d6135acab73e18bd380d1d93bb1a15eb" dependencies = [ "clap", "heck", "indexmap", "log", "proc-macro2", "quote", "serde", "serde_json", "syn", "tempfile", "toml 0.8.19", ] [[package]] name = "cc" version = "1.2.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be714c154be609ec7f5dad223a33bf1482fff90472de28f7362806e6d4832b8c" dependencies = [ "shlex", ] [[package]] name = "cfg-if" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "baf1de4339761588bc0619e3cbc0120ee582ebb74b53b4efbf79117bd2da40fd" [[package]] name = "ciborium" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "effd91f6c78e5a4ace8a5d3c0b6bfaec9e2baaef55f3efc00e45fb2e477ee926" dependencies = [ "ciborium-io", "ciborium-ll", "serde", ] [[package]] name = "ciborium-io" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cdf919175532b369853f5d5e20b26b43112613fd6fe7aee757e35f7a44642656" [[package]] name = "ciborium-ll" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "defaa24ecc093c77630e6c15e17c51f5e187bf35ee514f4e2d67baaa96dae22b" dependencies = [ "ciborium-io", "half", ] [[package]] name = "clap" version = "4.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fb690e81c7840c0d7aade59f242ea3b41b9bc27bcd5997890e7702ae4b32e487" dependencies = [ "clap_builder", "clap_derive", "once_cell", ] [[package]] name = "clap_builder" version = "4.3.24" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ed2e96bc16d8d740f6f48d663eddf4b8a0983e79210fd55479b7bcd0a69860e" dependencies = [ "anstream", "anstyle", "clap_lex", "once_cell", "strsim", "terminal_size", ] [[package]] name = "clap_derive" version = "4.3.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"54a9bb5758fc5dfe728d1019941681eccaf0cf8a4189b692a0ee2f2ecf90a050" dependencies = [ "heck", "proc-macro2", "quote", "syn", ] [[package]] name = "clap_lex" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cd7cc57abe963c6d3b9d8be5b06ba7c8957a930305ca90304f24ef040aa6f961" [[package]] name = "colorchoice" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "acbf1af155f9b9ef647e42cdc158db4b64a1b61f743629225fde6f3e0be2a7c7" [[package]] name = "config" version = "0.14.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68578f196d2a33ff61b27fae256c3164f65e36382648e30666dde05b8cc9dfdf" dependencies = [ "async-trait", "convert_case", "json5", "nom", "pathdiff", "ron", "rust-ini", "serde", "serde_json", "toml 0.8.19", "yaml-rust2", ] [[package]] name = "const-random" version = "0.1.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87e00182fe74b066627d63b85fd550ac2998d4b0bd86bfed477a0ae4c7c71359" dependencies = [ "const-random-macro", ] [[package]] name = "const-random-macro" version = "0.1.16" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9d839f2a20b0aee515dc581a6172f2321f96cab76c1a38a4c584a194955390e" dependencies = [ "getrandom", "once_cell", "tiny-keccak", ] [[package]] name = "convert_case" version = "0.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ec182b0ca2f35d8fc196cf3404988fd8b8c739a4d270ff118a398feb0cbec1ca" dependencies = [ "unicode-segmentation", ] [[package]] name = "core-foundation" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "194a7a9e6de53fa55116934067c844d9d749312f75c6f6d0980e8c252f8c2146" dependencies = [ "core-foundation-sys", "libc", ] [[package]] name = "core-foundation-sys" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e496a50fda8aacccc86d7529e2c1e0892dbd0f898a6b5645b5561b89c3210efa" [[package]] name = "cpufeatures" version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce420fe07aecd3e67c5f910618fe65e94158f6dcc0adf44e00d69ce2bdfe0fd0" dependencies = [ "libc", ] [[package]] name = "criterion" version = "0.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f2b12d017a929603d80db1831cd3a24082f8137ce19c69e6447f54f5fc8d692f" dependencies = [ "anes", "cast", "ciborium", "clap", "criterion-plot", "is-terminal", "itertools", "num-traits", "once_cell", "oorandom", "plotters", "rayon", "regex", "serde", "serde_derive", "serde_json", "tinytemplate", "walkdir", ] [[package]] name = "criterion-plot" version = "0.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6b50826342786a51a89e2da3a28f1c32b06e387201bc2d19791f622c673706b1" dependencies = [ "cast", "itertools", ] [[package]] name = "crossbeam-deque" version = "0.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ce6fd6f855243022dcecf8702fef0c297d4338e226845fe067f6341ad9fa0cef" dependencies = [ "cfg-if", "crossbeam-epoch", "crossbeam-utils", ] [[package]] name = "crossbeam-epoch" version = "0.9.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae211234986c545741a7dc064309f67ee1e5ad243d0e48335adc0484d960bcc7" dependencies = [ "autocfg", "cfg-if", "crossbeam-utils", "memoffset", "scopeguard", ] [[package]] name = "crossbeam-utils" version = "0.8.16" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5a22b2d63d4d1dc0b7f1b6b2747dd0088008a9be28b6ddf0b1e7d335e3037294" dependencies = [ "cfg-if", ] [[package]] name = "crunchy" version = "0.2.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7a81dae078cea95a014a339291cec439d2f232ebe854a9d672b796c6afafa9b7" [[package]] name = "crypto-common" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1bfb12502f3fc46cca1bb51ac28df9d618d813cdc3d2f25b9fe775a34af26bb3" dependencies = [ "generic-array", "typenum", ] [[package]] name = "deranged" version = "0.3.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b42b6fa04a440b495c8b04d0e71b707c585f83cb9cb28cf8cd0d976c315e31b4" dependencies = [ "powerfmt", ] [[package]] name = "digest" version = "0.10.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9ed9a281f7bc9b7576e61468ba615a66a5c8cfdff42420a70aa82701a3b1e292" dependencies = [ "block-buffer", "crypto-common", ] [[package]] name = "displaydoc" version = "0.2.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "97369cbbc041bc366949bc74d34658d6cda5621039731c6310521892a3a20ae0" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "dlv-list" version = "0.5.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "442039f5147480ba31067cb00ada1adae6892028e40e45fc5de7b7df6dcc1b5f" dependencies = [ "const-random", ] [[package]] name = "either" version = "1.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a26ae43d7bcc3b814de94796a5e736d4029efb0ee900c12e2d54c993ad1a1e07" [[package]] name = "encoding_rs" version = "0.8.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "75030f3c4f45dafd7586dd6780965a8c7e8e285a5ecb86713e63a79c5b2766f3" dependencies = [ "cfg-if", ] [[package]] name = "env_logger" version = "0.9.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a12e6657c4c97ebab115a42dcee77225f7f482cdd841cf7088c657a42e9e00e7" dependencies = [ "log", ] [[package]] name = "equivalent" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5443807d6dff69373d433ab9ef5378ad8df50ca6298caf15de6e52e24aaf54d5" [[package]] name = "errno" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "534c5cf6194dfab3db3242765c03bbe257cf92f22b38f6bc0c58d59108a820ba" dependencies = [ "libc", "windows-sys 0.52.0", ] [[package]] name = "error-chain" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d9435d864e017c3c6afeac1654189b06cdb491cf2ff73dbf0d73b0f292f42ff8" [[package]] name = "fastrand" version = "2.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9fc0510504f03c51ada170672ac806f1f105a88aa97a5281117e1ddc3368e51a" [[package]] name = "filetime" version = "0.2.25" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "35c0522e981e68cbfa8c3f978441a5f34b30b96e146b33cd3359176b50fe8586" dependencies = [ "cfg-if", "libc", "libredox", "windows-sys 0.59.0", ] [[package]] name = "foreign-types" version = "0.3.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6f339eb8adc052cd2ca78910fda869aefa38d22d5cb648e6485e4d3fc06f3b1" dependencies = [ "foreign-types-shared", ] [[package]] name = "foreign-types-shared" version = "0.1.1" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "00b0228411908ca8685dba7fc2cdd70ec9990a6e753e89b6ac91a84c40fbaf4b" [[package]] name = "form_urlencoded" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e13624c2627564efccf4934284bdd98cbaa14e79b0b5a141218e507b3a823456" dependencies = [ "percent-encoding", ] [[package]] name = "fsevent-sys" version = "4.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "76ee7a02da4d231650c7cea31349b889be2f45ddb3ef3032d2ec8185f6313fd2" dependencies = [ "libc", ] [[package]] name = "generic-array" version = "0.14.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "85649ca51fd72272d7821adaf274ad91c288277713d9c18820d8499a7ff69e9a" dependencies = [ "typenum", "version_check", ] [[package]] name = "getrandom" version = "0.2.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fe9006bed769170c11f845cf00c7c1e9092aeb3f268e007c3e760ac68008070f" dependencies = [ "cfg-if", "js-sys", "libc", "wasi", "wasm-bindgen", ] [[package]] name = "half" version = "1.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eabb4a44450da02c90444cf74558da904edde8fb4e9035a9a6a4e15445af0bd7" [[package]] name = "hashbrown" version = "0.14.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e5274423e17b7c9fc20b6e7e208532f9b19825d82dfd615708b70edd83df41f1" dependencies = [ "ahash", "allocator-api2", ] [[package]] name = "hashbrown" version = "0.15.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bf151400ff0baff5465007dd2f3e717f3fe502074ca563069ce3a6629d07b289" [[package]] name = "hashlink" version = "0.8.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e8094feaf31ff591f651a2664fb9cfd92bba7a60ce3197265e9482ebe753c8f7" dependencies = [ "hashbrown 0.14.5", ] [[package]] name = "heck" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" [[package]] name = "hermit-abi" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d231dfb89cfffdbc30e7fc41579ed6066ad03abda9e567ccafae602b97ec5024" [[package]] name = "httparse" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d897f394bad6a705d5f4104762e116a75639e470d80901eed05a860a95cb1904" [[package]] name = "icu_collections" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "db2fa452206ebee18c4b5c2274dbf1de17008e874b4dc4f0aea9d01ca79e4526" dependencies = [ "displaydoc", "yoke", "zerofrom", "zerovec", ] [[package]] name = "icu_locid" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "13acbb8371917fc971be86fc8057c41a64b521c184808a698c02acc242dbf637" dependencies = [ "displaydoc", "litemap", "tinystr", "writeable", "zerovec", ] [[package]] name = "icu_locid_transform" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01d11ac35de8e40fdeda00d9e1e9d92525f3f9d887cdd7aa81d727596788b54e" dependencies = [ "displaydoc", "icu_locid", "icu_locid_transform_data", "icu_provider", "tinystr", "zerovec", ] [[package]] name = "icu_locid_transform_data" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"fdc8ff3388f852bede6b579ad4e978ab004f139284d7b28715f773507b946f6e" [[package]] name = "icu_normalizer" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "19ce3e0da2ec68599d193c93d088142efd7f9c5d6fc9b803774855747dc6a84f" dependencies = [ "displaydoc", "icu_collections", "icu_normalizer_data", "icu_properties", "icu_provider", "smallvec", "utf16_iter", "utf8_iter", "write16", "zerovec", ] [[package]] name = "icu_normalizer_data" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f8cafbf7aa791e9b22bec55a167906f9e1215fd475cd22adfcf660e03e989516" [[package]] name = "icu_properties" version = "1.5.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93d6020766cfc6302c15dbbc9c8778c37e62c14427cb7f6e601d849e092aeef5" dependencies = [ "displaydoc", "icu_collections", "icu_locid_transform", "icu_properties_data", "icu_provider", "tinystr", "zerovec", ] [[package]] name = "icu_properties_data" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "67a8effbc3dd3e4ba1afa8ad918d5684b8868b3b26500753effea8d2eed19569" [[package]] name = "icu_provider" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6ed421c8a8ef78d3e2dbc98a973be2f3770cb42b606e3ab18d6237c4dfde68d9" dependencies = [ "displaydoc", "icu_locid", "icu_provider_macros", "stable_deref_trait", "tinystr", "writeable", "yoke", "zerofrom", "zerovec", ] [[package]] name = "icu_provider_macros" version = "1.5.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ec89e9337638ecdc08744df490b221a7399bf8d164eb52a665454e60e075ad6" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "idna" version = "1.0.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "686f825264d630750a544639377bae737628043f20d38bbc029e8f29ea968a7e" dependencies = [ "idna_adapter", "smallvec", "utf8_iter", ] [[package]] name = "idna_adapter" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "daca1df1c957320b2cf139ac61e7bd64fed304c5040df000a745aa1de3b4ef71" dependencies = [ "icu_normalizer", "icu_properties", ] [[package]] name = "indexmap" version = "2.6.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "707907fe3c25f5424cce2cb7e1cbcafee6bdbe735ca90ef77c29e84591e5b9da" dependencies = [ "equivalent", "hashbrown 0.15.2", ] [[package]] name = "inotify" version = "0.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fdd168d97690d0b8c412d6b6c10360277f4d7ee495c5d0d5d5fe0854923255cc" dependencies = [ "bitflags 1.3.2", "inotify-sys", "libc", ] [[package]] name = "inotify-sys" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e05c02b5e89bff3b946cedeca278abc628fe811e604f027c45a8aa3cf793d0eb" dependencies = [ "libc", ] [[package]] name = "instant" version = "0.1.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e0242819d153cba4b4b05a5a8f2a7e9bbf97b6055b2a002b395c96b5ff3c0222" dependencies = [ "cfg-if", ] [[package]] name = "io-lifetimes" version = "1.0.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "eae7b9aee968036d54dce06cebaefd919e4472e753296daccd6d344e3e2df0c2" dependencies = [ "hermit-abi", "libc", "windows-sys 0.48.0", ] [[package]] name = "ipnet" version = "2.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "8f518f335dce6725a761382244631d86cf0ccb2863413590b31338feb467f9c3" [[package]] name = "is-terminal" version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cb0889898416213fab133e1d33a0e5858a48177452750691bde3666d0fdbaf8b" dependencies = [ "hermit-abi", "rustix 0.38.28", "windows-sys 0.48.0", ] [[package]] name = "itertools" version = "0.10.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b0fd2260e829bddf4cb6ea802289de2f86d6a7a690192fbe91b3f46e0f2c8473" dependencies = [ "either", ] [[package]] name = "itoa" version = "1.0.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af150ab688ff2122fcef229be89cb50dd66af9e01a4ff320cc137eecc9bacc38" [[package]] name = "js-sys" version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "54c0c35952f67de54bb584e9fd912b3023117cbafc0a77d8f3dee1fb5f572fe8" dependencies = [ "wasm-bindgen", ] [[package]] name = "json5" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "96b0db21af676c1ce64250b5f40f3ce2cf27e4e47cb91ed91eb6fe9350b430c1" dependencies = [ "pest", "pest_derive", "serde", ] [[package]] name = "jsonwebtoken" version = "9.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a87cc7a48537badeae96744432de36f4be2b4a34a05a5ef32e9dd8a1c169dde" dependencies = [ "base64 0.22.1", "js-sys", "pem", "ring", "serde", "serde_json", "simple_asn1", ] [[package]] name = "kqueue" version = "1.0.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7447f1ca1b7b563588a205fe93dea8df60fd981423a768bc1c0ded35ed147d0c" dependencies = [ "kqueue-sys", "libc", ] [[package]] name = "kqueue-sys" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed9625ffda8729b85e45cf04090035ac368927b8cebc34898e7c120f52e4838b" dependencies = [ "bitflags 1.3.2", "libc", ] [[package]] name = "libc" version = "0.2.170" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "875b3680cb2f8f71bdcf9a30f38d48282f5d3c95cbf9b3fa57269bb5d5c06828" [[package]] name = "libredox" version = "0.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c0ff37bd590ca25063e35af745c343cb7a0271906fb7b37e4813e8f79f00268d" dependencies = [ "bitflags 2.9.0", "libc", "redox_syscall 0.5.10", ] [[package]] name = "linux-raw-sys" version = "0.3.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef53942eb7bf7ff43a617b3e2c1c4a5ecf5944a7c1bc12d7ee39bbb15e5c1519" [[package]] name = "linux-raw-sys" version = "0.4.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "969488b55f8ac402214f3f5fd243ebb7206cf82de60d3172994707a4bcc2b829" [[package]] name = "litemap" version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ee93343901ab17bd981295f2cf0026d4ad018c7c31ba84549a4ddbb47a45104" [[package]] name = "log" version = "0.4.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5e6163cb8c49088c2c36f57875e58ccd8c87c7427f7fbd50ea6710b2f3f2e8f" [[package]] name = "memchr" version = "2.6.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f665ee40bc4a3c5590afb1e9677db74a508659dfd71e126420da8274909a0167" [[package]] name = "memoffset" version = "0.9.0" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "5a634b1c61a95585bd15607c6ab0c4e5b226e695ff2800ba0cdccddf208c406c" dependencies = [ "autocfg", ] [[package]] name = "metadeps" version = "1.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73b122901b3a675fac8cecf68dcb2f0d3036193bc861d1ac0e1c337f7d5254c2" dependencies = [ "error-chain", "pkg-config", "toml 0.2.1", ] [[package]] name = "minimal-lexical" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68354c5c6bd36d73ff3feceb05efa59b6acb7626617f4962be322a825e61f79a" [[package]] name = "miniz_oxide" version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b275950c28b37e794e8c55d88aeb5e139d0ce23fdbbeda68f8d7174abdf9e8fa" dependencies = [ "adler", ] [[package]] name = "mio" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4929e1f84c5e54c3ec6141cd5d8b5a5c055f031f80cf78f2072920173cb4d880" dependencies = [ "hermit-abi", "libc", "log", "wasi", "windows-sys 0.52.0", ] [[package]] name = "nom" version = "7.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d273983c5a657a70a3e8f2a01329822f3b8c8172b73826411a55751e404a0a4a" dependencies = [ "memchr", "minimal-lexical", ] [[package]] name = "notify" version = "7.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c533b4c39709f9ba5005d8002048266593c1cfaf3c5f0739d5b8ab0c6c504009" dependencies = [ "bitflags 2.9.0", "filetime", "fsevent-sys", "inotify", "kqueue", "libc", "log", "mio", "notify-types", "walkdir", "windows-sys 0.52.0", ] [[package]] name = "notify-types" version = "1.0.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "585d3cb5e12e01aed9e8a1f70d5c6b5e86fe2a6e48fc8cd0b3e0b8df6f6eb174" dependencies = [ "instant", ] [[package]] name = "num-bigint" version = "0.4.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "608e7659b5c3d7cba262d894801b9ec9d00de989e8a82bd4bef91d08da45cdc0" dependencies = [ "autocfg", "num-integer", "num-traits", ] [[package]] name = "num-conv" version = "0.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "51d515d32fb182ee37cda2ccdcb92950d6a3c2893aa280e540671c2cd0f3b1d9" [[package]] name = "num-integer" version = "0.1.45" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "225d3389fb3509a24c93f5c29eb6bde2586b98d9f016636dff58d7c6f7569cd9" dependencies = [ "autocfg", "num-traits", ] [[package]] name = "num-traits" version = "0.2.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "39e3200413f237f41ab11ad6d161bc7239c84dcb631773ccd7de3dfe4b5c267c" dependencies = [ "autocfg", ] [[package]] name = "num_threads" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2819ce041d2ee131036f4fc9d6ae7ae125a3a40e97ba64d04fe799ad9dabbb44" dependencies = [ "libc", ] [[package]] name = "once_cell" version = "1.18.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dd8b5dd2ae5ed71462c540258bedcb51965123ad7e7ccf4b9a8cafaa4a63576d" [[package]] name = "oorandom" version = "11.1.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0ab1bc2a289d34bd04a330323ac98a1b4bc82c9d9fcb1e66b63caa84da26b575" [[package]] name = "openssl" version = "0.10.72" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"fedfea7d58a1f73118430a55da6a286e7b044961736ce96a16a17068ea25e5da" dependencies = [ "bitflags 2.9.0", "cfg-if", "foreign-types", "libc", "once_cell", "openssl-macros", "openssl-sys", ] [[package]] name = "openssl-macros" version = "0.1.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a948666b637a0f465e8564c73e89d4dde00d72d4d473cc972f390fc3dcee7d9c" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "openssl-probe" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ff011a302c396a5197692431fc1948019154afc178baf7d8e37367442a4601cf" [[package]] name = "openssl-sys" version = "0.9.107" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8288979acd84749c744a9014b4382d42b8f7b2592847b5afb2ed29e5d16ede07" dependencies = [ "cc", "libc", "pkg-config", "vcpkg", ] [[package]] name = "ordered-multimap" version = "0.7.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49203cdcae0030493bad186b28da2fa25645fa276a51b6fec8010d281e02ef79" dependencies = [ "dlv-list", "hashbrown 0.14.5", ] [[package]] name = "paste" version = "1.0.14" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "de3145af08024dea9fa9914f381a17b8fc6034dfb00f3a84013f7ff43f29ed4c" [[package]] name = "pathdiff" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8835116a5c179084a830efb3adc117ab007512b535bc1a21c991d3b32a6b44dd" [[package]] name = "pem" version = "3.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38af38e8470ac9dee3ce1bae1af9c1671fffc44ddfd8bd1d0a3445bf349a8ef3" dependencies = [ "base64 0.22.1", "serde", ] [[package]] name = "percent-encoding" version = "2.3.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3148f5046208a5d56bcfc03053e3ca6334e51da8dfb19b6cdc8b306fae3283e" [[package]] name = "pest" version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ae9cee2a55a544be8b89dc6848072af97a20f2422603c10865be2a42b580fff5" dependencies = [ "memchr", "thiserror", "ucd-trie", ] [[package]] name = "pest_derive" version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "81d78524685f5ef2a3b3bd1cafbc9fcabb036253d9b1463e726a91cd16e2dfc2" dependencies = [ "pest", "pest_generator", ] [[package]] name = "pest_generator" version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "68bd1206e71118b5356dae5ddc61c8b11e28b09ef6a31acbd15ea48a28e0c227" dependencies = [ "pest", "pest_meta", "proc-macro2", "quote", "syn", ] [[package]] name = "pest_meta" version = "2.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7c747191d4ad9e4a4ab9c8798f1e82a39affe7ef9648390b7e5548d18e099de6" dependencies = [ "once_cell", "pest", "sha2", ] [[package]] name = "pkg-config" version = "0.3.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "26072860ba924cbfa98ea39c8c19b4dd6a4a25423dbdf219c1eca91aa0cf6964" [[package]] name = "plotters" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d2c224ba00d7cadd4d5c660deaf2098e5e80e07846537c51f9cfa4be50c1fd45" dependencies = [ "num-traits", "plotters-backend", "plotters-svg", "wasm-bindgen", "web-sys", ] [[package]] name = "plotters-backend" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = 
"9e76628b4d3a7581389a35d5b6e2139607ad7c75b17aed325f210aa91f4a9609" [[package]] name = "plotters-svg" version = "0.3.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "38f6d39893cca0701371e3c27294f09797214b86f1fb951b89ade8ec04e2abab" dependencies = [ "plotters-backend", ] [[package]] name = "powerfmt" version = "0.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "439ee305def115ba05938db6eb1644ff94165c5ab5e9420d1c1bcedbba909391" [[package]] name = "proc-macro2" version = "1.0.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" dependencies = [ "unicode-ident", ] [[package]] name = "pushpin" version = "1.41.0-dev" dependencies = [ "arrayvec", "base64 0.13.1", "cbindgen", "clap", "config", "criterion", "env_logger", "httparse", "ipnet", "jsonwebtoken", "libc", "log", "miniz_oxide", "mio", "notify", "openssl", "paste", "pkg-config", "rustls", "rustls-native-certs", "serde", "serde_json", "sha1", "signal-hook", "slab", "socket2", "test-log", "thiserror", "time", "url", "zmq", ] [[package]] name = "quote" version = "1.0.37" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" dependencies = [ "proc-macro2", ] [[package]] name = "rayon" version = "1.8.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c27db03db7734835b3f53954b534c91069375ce6ccaa2e065441e07d9b6cdb1" dependencies = [ "either", "rayon-core", ] [[package]] name = "rayon-core" version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5ce3fb6ad83f861aac485e76e1985cd109d9a3713802152be56c3b1f0e0658ed" dependencies = [ "crossbeam-deque", "crossbeam-utils", ] [[package]] name = "redox_syscall" version = "0.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4722d768eff46b75989dd134e5c353f0d6296e5aaa3132e776cbdb56be7731aa" dependencies = [ "bitflags 1.3.2", ] [[package]] name = "redox_syscall" version = "0.5.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b8c0c260b63a8219631167be35e6a988e9554dbd323f8bd08439c8ed1302bd1" dependencies = [ "bitflags 2.9.0", ] [[package]] name = "regex" version = "1.10.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "380b951a9c5e80ddfd6136919eef32310721aa4aacd4889a8d39124b026ab343" dependencies = [ "aho-corasick", "memchr", "regex-automata", "regex-syntax", ] [[package]] name = "regex-automata" version = "0.4.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5f804c7828047e88b2d32e2d7fe5a105da8ee3264f01902f796c8e067dc2483f" dependencies = [ "aho-corasick", "memchr", "regex-syntax", ] [[package]] name = "regex-syntax" version = "0.8.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c08c74e62047bb2de4ff487b251e4a92e24f48745648451635cec7d591162d9f" [[package]] name = "ring" version = "0.17.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed9b823fa29b721a59671b41d6b06e66b29e0628e207e8b1c3ceeda701ec928d" dependencies = [ "cc", "cfg-if", "getrandom", "libc", "untrusted", "windows-sys 0.52.0", ] [[package]] name = "ron" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b91f7eff05f748767f183df4320a63d6936e9c6107d97c9e6bdd9784f4289c94" dependencies = [ "base64 0.21.5", "bitflags 2.9.0", 
"serde", "serde_derive", ] [[package]] name = "rust-ini" version = "0.20.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3e0698206bcb8882bf2a9ecb4c1e7785db57ff052297085a6efd4fe42302068a" dependencies = [ "cfg-if", "ordered-multimap", ] [[package]] name = "rustix" version = "0.37.27" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fea8ca367a3a01fe35e6943c400addf443c0f57670e6ec51196f71a4b8762dd2" dependencies = [ "bitflags 1.3.2", "errno", "io-lifetimes", "libc", "linux-raw-sys 0.3.8", "windows-sys 0.48.0", ] [[package]] name = "rustix" version = "0.38.28" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "72e572a5e8ca657d7366229cdde4bd14c4eb5499a9573d4d366fe1b599daa316" dependencies = [ "bitflags 2.9.0", "errno", "libc", "linux-raw-sys 0.4.11", "windows-sys 0.52.0", ] [[package]] name = "rustls" version = "0.21.11" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7fecbfb7b1444f477b345853b1fce097a2c6fb637b2bfb87e6bc5db0f043fae4" dependencies = [ "log", "ring", "rustls-webpki", "sct", ] [[package]] name = "rustls-native-certs" version = "0.6.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a9aace74cb666635c918e9c12bc0d348266037aa8eb599b5cba565709a8dff00" dependencies = [ "openssl-probe", "rustls-pemfile", "schannel", "security-framework", ] [[package]] name = "rustls-pemfile" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1c74cae0a4cf6ccbbf5f359f08efdf8ee7e1dc532573bf0db71968cb56b1448c" dependencies = [ "base64 0.21.5", ] [[package]] name = "rustls-webpki" version = "0.101.7" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8b6275d1ee7a1cd780b64aca7726599a1dbc893b1e64144529e55c3c2f745765" dependencies = [ "ring", "untrusted", ] [[package]] name = "ryu" version = "1.0.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1ad4cc8da4ef723ed60bced201181d83791ad433213d8c24efffda1eec85d741" [[package]] name = "same-file" version = "1.0.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "93fc1dc3aaa9bfed95e02e6eadabb4baf7e3078b0bd1b4d7b6b0b68378900502" dependencies = [ "winapi-util", ] [[package]] name = "schannel" version = "0.1.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0c3733bf4cf7ea0880754e19cb5a462007c4a8c1914bff372ccc95b464f1df88" dependencies = [ "windows-sys 0.48.0", ] [[package]] name = "scopeguard" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "94143f37725109f92c262ed2cf5e59bce7498c01bcc1502d7b9afe439a4e9f49" [[package]] name = "sct" version = "0.7.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "da046153aa2352493d6cb7da4b6e5c0c057d8a1d0a9aa8560baffdd945acd414" dependencies = [ "ring", "untrusted", ] [[package]] name = "security-framework" version = "2.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "05b64fb303737d99b81884b2c63433e9ae28abebe5eb5045dcdd175dc2ecf4de" dependencies = [ "bitflags 1.3.2", "core-foundation", "core-foundation-sys", "libc", "security-framework-sys", ] [[package]] name = "security-framework-sys" version = "2.9.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e932934257d3b408ed8f30db49d85ea163bfe74961f017f405b025af298f0c7a" dependencies = [ "core-foundation-sys", "libc", ] [[package]] name = "serde" version = 
"1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "bca2a08484b285dcb282d0f67b26cadc0df8b19f8c12502c13d966bf9482f001" dependencies = [ "serde_derive", ] [[package]] name = "serde_derive" version = "1.0.192" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d6c7207fbec9faa48073f3e3074cbe553af6ea512d7c21ba46e434e70ea9fbc1" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "serde_json" version = "1.0.108" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3d1c7e3eac408d115102c4c24ad393e0821bb3a5df4d506a80f85f7a742a526b" dependencies = [ "itoa", "ryu", "serde", ] [[package]] name = "serde_spanned" version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "87607cb1398ed59d48732e575a4c28a7a8ebf2454b964fe3f224f2afc07909e1" dependencies = [ "serde", ] [[package]] name = "sha1" version = "0.10.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e3bf829a2d51ab4a5ddf1352d8470c140cadc8301b2ae1789db023f01cedd6ba" dependencies = [ "cfg-if", "cpufeatures", "digest", ] [[package]] name = "sha2" version = "0.10.8" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "793db75ad2bcafc3ffa7c68b215fee268f537982cd901d132f89c6343f3a3dc8" dependencies = [ "cfg-if", "cpufeatures", "digest", ] [[package]] name = "shlex" version = "1.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0fda2ff0d084019ba4d7c6f371c95d8fd75ce3524c3cb8fb653a3023f6323e64" [[package]] name = "signal-hook" version = "0.3.17" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8621587d4798caf8eb44879d42e56b9a93ea5dcd315a6487c357130095b62801" dependencies = [ "libc", "signal-hook-registry", ] [[package]] name = "signal-hook-registry" version = "1.4.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d8229b473baa5980ac72ef434c4415e70c4b5e71b423043adb4ba059f89c99a1" dependencies = [ "libc", ] [[package]] name = "simple_asn1" version = "0.6.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "adc4e5204eb1910f40f9cfa375f6f05b68c3abac4b6fd879c8ff5e7ae8a0a085" dependencies = [ "num-bigint", "num-traits", "thiserror", "time", ] [[package]] name = "slab" version = "0.4.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f92a496fb766b417c996b9c5e57daf2f7ad3b0bebe1ccfca4856390e3d3bb67" dependencies = [ "autocfg", ] [[package]] name = "smallvec" version = "1.13.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3c5e1a9a646d36c3599cd173a41282daf47c44583ad367b8e6837255952e5c67" [[package]] name = "socket2" version = "0.4.10" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7916fc008ca5542385b89a3d3ce689953c143e9304a9bf8beec1de48994c0d" dependencies = [ "libc", "winapi", ] [[package]] name = "stable_deref_trait" version = "1.2.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a8f112729512f8e442d81f95a8a7ddf2b7c6b8a1a6f509a95864142b30cab2d3" [[package]] name = "strsim" version = "0.10.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "73473c0e59e6d5812c5dfe2a064a6444949f089e20eec9a2e5506596494e4623" [[package]] name = "syn" version = "2.0.89" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "44d46482f1c1c87acd84dea20c1bf5ebff4c757009ed6bf19cfd36fb10e92c4e" dependencies = [ "proc-macro2", 
"quote", "unicode-ident", ] [[package]] name = "synstructure" version = "0.13.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8af7666ab7b6390ab78131fb5b0fce11d6b7a6951602017c35fa82800708971" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "tempfile" version = "3.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "01ce4141aa927a6d1bd34a041795abd0db1cccba5d5f24b009f694bdf3a1f3fa" dependencies = [ "cfg-if", "fastrand", "redox_syscall 0.4.1", "rustix 0.38.28", "windows-sys 0.52.0", ] [[package]] name = "terminal_size" version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e6bf6f19e9f8ed8d4048dc22981458ebcf406d67e94cd422e5ecd73d63b3237" dependencies = [ "rustix 0.37.27", "windows-sys 0.48.0", ] [[package]] name = "test-log" version = "0.2.13" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f66edd6b6cd810743c0c71e1d085e92b01ce6a72782032e3f794c8284fe4bcdd" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "thiserror" version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f9a7210f5c9a7156bb50aa36aed4c95afb51df0df00713949448cf9e97d382d2" dependencies = [ "thiserror-impl", ] [[package]] name = "thiserror-impl" version = "1.0.50" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "266b2e40bc00e5a6c09c3584011e08b06f123c00362c92b975ba9843aaaa14b8" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "time" version = "0.3.36" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5dfd88e563464686c916c7e46e623e520ddc6d79fa6641390f2e3fa86e83e885" dependencies = [ "deranged", "itoa", "libc", "num-conv", "num_threads", "powerfmt", "serde", "time-core", "time-macros", ] [[package]] name = "time-core" version = "0.1.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ef927ca75afb808a4d64dd374f00a2adf8d0fcff8e7b184af886c3c87ec4a3f3" [[package]] name = "time-macros" version = "0.2.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3f252a68540fde3a3877aeea552b832b40ab9a69e318efd078774a01ddee1ccf" dependencies = [ "num-conv", "time-core", ] [[package]] name = "tiny-keccak" version = "2.0.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2c9d3793400a45f954c52e73d068316d76b6f4e36977e3fcebb13a2721e80237" dependencies = [ "crunchy", ] [[package]] name = "tinystr" version = "0.7.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9117f5d4db391c1cf6927e7bea3db74b9a1c1add8f7eda9ffd5364f40f57b82f" dependencies = [ "displaydoc", "zerovec", ] [[package]] name = "tinytemplate" version = "1.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "be4d6b5f19ff7664e8c98d03e2139cb510db9b0a60b55f8e8709b689d939b6bc" dependencies = [ "serde", "serde_json", ] [[package]] name = "toml" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "736b60249cb25337bc196faa43ee12c705e426f3d55c214d73a4e7be06f92cb4" [[package]] name = "toml" version = "0.8.19" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1ed1f98e3fdc28d6d910e6737ae6ab1a93bf1985935a1193e68f93eeb68d24e" dependencies = [ "serde", "serde_spanned", "toml_datetime", "toml_edit", ] [[package]] name = "toml_datetime" version = "0.6.8" source = "registry+https://github.com/rust-lang/crates.io-index" 
checksum = "0dd7358ecb8fc2f8d014bf86f6f638ce72ba252a2c3a2572f2a795f1d23efb41" dependencies = [ "serde", ] [[package]] name = "toml_edit" version = "0.22.22" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4ae48d6208a266e853d946088ed816055e556cc6028c5e8e2b84d9fa5dd7c7f5" dependencies = [ "indexmap", "serde", "serde_spanned", "toml_datetime", "winnow", ] [[package]] name = "typenum" version = "1.17.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42ff0bf0c66b8238c6f3b578df37d0b7848e55df8577b3f74f92a69acceeb825" [[package]] name = "ucd-trie" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed646292ffc8188ef8ea4d1e0e0150fb15a5c2e12ad9b8fc191ae7a8a7f3c4b9" [[package]] name = "unicode-ident" version = "1.0.12" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "3354b9ac3fae1ff6755cb6db53683adb661634f67557942dea4facebec0fee4b" [[package]] name = "unicode-segmentation" version = "1.12.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f6ccf251212114b54433ec949fd6a7841275f9ada20dddd2f29e9ceea4501493" [[package]] name = "untrusted" version = "0.9.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8ecb6da28b8a351d773b68d5825ac39017e680750f980f3a1a85cd8dd28a47c1" [[package]] name = "url" version = "2.5.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32f8b686cadd1473f4bd0117a5d28d36b1ade384ea9b5069a1c40aefed7fda60" dependencies = [ "form_urlencoded", "idna", "percent-encoding", ] [[package]] name = "utf16_iter" version = "1.0.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "c8232dd3cdaed5356e0f716d285e4b40b932ac434100fe9b7e0e8e935b9e6246" [[package]] name = "utf8_iter" version = "1.0.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b6c140620e7ffbb22c2dee59cafe6084a59b5ffc27a8859a5f0d494b5d52b6be" [[package]] name = "utf8parse" version = "0.2.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "711b9620af191e0cdc7468a8d14e709c3dcdb115b36f838e601583af800a370a" [[package]] name = "vcpkg" version = "0.2.15" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "accd4ea62f7bb7a82fe23066fb0957d48ef677f6eeb8215f372f52e48bb32426" [[package]] name = "version_check" version = "0.9.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "49874b5167b65d7193b8aba1567f5c7d93d001cafc34600cee003eda787e483f" [[package]] name = "walkdir" version = "2.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d71d857dc86794ca4c280d616f7da00d2dbfd8cd788846559a6813e6aa4b54ee" dependencies = [ "same-file", "winapi-util", ] [[package]] name = "wasi" version = "0.11.0+wasi-snapshot-preview1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9c8d87e72b64a3b4db28d11ce29237c246188f4f51057d65a7eab63b7987e423" [[package]] name = "wasm-bindgen" version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "4be2531df63900aeb2bca0daaaddec08491ee64ceecbee5076636a3b026795a8" dependencies = [ "cfg-if", "wasm-bindgen-macro", ] [[package]] name = "wasm-bindgen-backend" version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "614d787b966d3989fa7bb98a654e369c762374fd3213d212cfc0251257e747da" dependencies = [ "bumpalo", "log", "once_cell", "proc-macro2", "quote", "syn", 
"wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-macro" version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a1f8823de937b71b9460c0c34e25f3da88250760bec0ebac694b49997550d726" dependencies = [ "quote", "wasm-bindgen-macro-support", ] [[package]] name = "wasm-bindgen-macro-support" version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "e94f17b526d0a461a191c78ea52bbce64071ed5c04c9ffe424dcb38f74171bb7" dependencies = [ "proc-macro2", "quote", "syn", "wasm-bindgen-backend", "wasm-bindgen-shared", ] [[package]] name = "wasm-bindgen-shared" version = "0.2.92" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "af190c94f2773fdb3729c55b007a722abb5384da03bc0986df4c289bf5567e96" [[package]] name = "web-sys" version = "0.3.65" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5db499c5f66323272151db0e666cd34f78617522fb0c1604d31a27c50c206a85" dependencies = [ "js-sys", "wasm-bindgen", ] [[package]] name = "winapi" version = "0.3.9" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5c839a674fcd7a98952e593242ea400abe93992746761e38641405d28b00f419" dependencies = [ "winapi-i686-pc-windows-gnu", "winapi-x86_64-pc-windows-gnu", ] [[package]] name = "winapi-i686-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ac3b87c63620426dd9b991e5ce0329eff545bccbbb34f3be09ff6fb6ab51b7b6" [[package]] name = "winapi-util" version = "0.1.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f29e6f9198ba0d26b4c9f07dbe6f9ed633e1f3d5b8b414090084349e46a52596" dependencies = [ "winapi", ] [[package]] name = "winapi-x86_64-pc-windows-gnu" version = "0.4.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "712e227841d057c1ee1cd2fb22fa7e5a5461ae8e48fa2ca79ec42cfc1931183f" [[package]] name = "windows-sys" version = "0.48.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "677d2418bec65e3338edb076e806bc1ec15693c5d0104683f2efe857f61056a9" dependencies = [ "windows-targets 0.48.5", ] [[package]] name = "windows-sys" version = "0.52.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "282be5f36a8ce781fad8c8ae18fa3f9beff57ec1b52cb3de0789201425d9a33d" dependencies = [ "windows-targets 0.52.6", ] [[package]] name = "windows-sys" version = "0.59.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" dependencies = [ "windows-targets 0.52.6", ] [[package]] name = "windows-targets" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9a2fa6e2155d7247be68c096456083145c183cbbbc2764150dda45a87197940c" dependencies = [ "windows_aarch64_gnullvm 0.48.5", "windows_aarch64_msvc 0.48.5", "windows_i686_gnu 0.48.5", "windows_i686_msvc 0.48.5", "windows_x86_64_gnu 0.48.5", "windows_x86_64_gnullvm 0.48.5", "windows_x86_64_msvc 0.48.5", ] [[package]] name = "windows-targets" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" dependencies = [ "windows_aarch64_gnullvm 0.52.6", "windows_aarch64_msvc 0.52.6", "windows_i686_gnu 0.52.6", "windows_i686_gnullvm", "windows_i686_msvc 0.52.6", "windows_x86_64_gnu 0.52.6", "windows_x86_64_gnullvm 0.52.6", "windows_x86_64_msvc 0.52.6", ] 
[[package]] name = "windows_aarch64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2b38e32f0abccf9987a4e3079dfb67dcd799fb61361e53e2882c3cbaf0d905d8" [[package]] name = "windows_aarch64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" [[package]] name = "windows_aarch64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "dc35310971f3b2dbbf3f0690a219f40e2d9afcf64f9ab7cc1be722937c26b4bc" [[package]] name = "windows_aarch64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" [[package]] name = "windows_i686_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a75915e7def60c94dcef72200b9a8e58e5091744960da64ec734a6c6e9b3743e" [[package]] name = "windows_i686_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" [[package]] name = "windows_i686_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" [[package]] name = "windows_i686_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8f55c233f70c4b27f66c523580f78f1004e8b5a8b659e05a4eb49d4166cca406" [[package]] name = "windows_i686_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" [[package]] name = "windows_x86_64_gnu" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "53d40abd2583d23e4718fddf1ebec84dbff8381c07cae67ff7768bbf19c6718e" [[package]] name = "windows_x86_64_gnu" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" [[package]] name = "windows_x86_64_gnullvm" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "0b7b52767868a23d5bab768e390dc5f5c55825b6d30b86c844ff2dc7414044cc" [[package]] name = "windows_x86_64_gnullvm" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" [[package]] name = "windows_x86_64_msvc" version = "0.48.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ed94fce61571a4006852b7389a063ab983c02eb1bb37b47f8272ce92d06d9538" [[package]] name = "windows_x86_64_msvc" version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" [[package]] name = "winnow" version = "0.6.20" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "36c1fec1a2bb5866f07c25f68c26e565c4c200aebb96d7e55710c19d3e8ac49b" dependencies = [ "memchr", ] [[package]] name = "write16" version = "1.0.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d1890f4022759daae28ed4fe62859b1236caebfc61ede2f63ed4e695f3f6d936" [[package]] name = "writeable" version = "0.5.5" source = 
"registry+https://github.com/rust-lang/crates.io-index" checksum = "1e9df38ee2d2c3c5948ea468a8406ff0db0b29ae1ffde1bcf20ef305bcc95c51" [[package]] name = "yaml-rust2" version = "0.8.1" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8902160c4e6f2fb145dbe9d6760a75e3c9522d8bf796ed7047c85919ac7115f8" dependencies = [ "arraydeque", "encoding_rs", "hashlink", ] [[package]] name = "yoke" version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "120e6aef9aa629e3d4f52dc8cc43a015c7724194c97dfaf45180d2daf2b77f40" dependencies = [ "serde", "stable_deref_trait", "yoke-derive", "zerofrom", ] [[package]] name = "yoke-derive" version = "0.7.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2380878cad4ac9aac1e2435f3eb4020e8374b5f13c296cb75b4620ff8e229154" dependencies = [ "proc-macro2", "quote", "syn", "synstructure", ] [[package]] name = "zerocopy" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "1b9b4fd18abc82b8136838da5d50bae7bdea537c574d8dc1a34ed098d6c166f0" dependencies = [ "zerocopy-derive", ] [[package]] name = "zerocopy-derive" version = "0.7.35" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "fa4f8080344d4671fb4e831a13ad1e68092748387dfc4f55e356242fae12ce3e" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "zerofrom" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "cff3ee08c995dee1859d998dea82f7374f2826091dd9cd47def953cae446cd2e" dependencies = [ "zerofrom-derive", ] [[package]] name = "zerofrom-derive" version = "0.1.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "595eed982f7d355beb85837f651fa22e90b3c044842dc7f2c2842c086f295808" dependencies = [ "proc-macro2", "quote", "syn", "synstructure", ] [[package]] name = "zerovec" version = "0.10.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aa2b893d79df23bfb12d5461018d408ea19dfafe76c2c7ef6d4eba614f8ff079" dependencies = [ "yoke", "zerofrom", "zerovec-derive", ] [[package]] name = "zerovec-derive" version = "0.10.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "6eafa6dfb17584ea3e2bd6e76e0cc15ad7af12b09abdd1ca55961bed9b1063c6" dependencies = [ "proc-macro2", "quote", "syn", ] [[package]] name = "zmq" version = "0.9.2" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "aad98a7a617d608cd9e1127147f630d24af07c7cd95ba1533246d96cbdd76c66" dependencies = [ "bitflags 1.3.2", "libc", "log", "zmq-sys", ] [[package]] name = "zmq-sys" version = "0.11.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d33a2c51dde24d5b451a2ed4b488266df221a5eaee2ee519933dc46b9a9b3648" dependencies = [ "libc", "metadeps", ] pushpin-1.41.0/Cargo.toml000066400000000000000000000036211504671364300152260ustar00rootroot00000000000000[package] name = "pushpin" version = "1.41.0-dev" authors = ["Justin Karneges "] description = "Reverse proxy for realtime web services" repository = "https://github.com/fastly/pushpin" readme = "README.md" license = "Apache-2.0" edition = "2018" rust-version = "1.75" default-run = "pushpin" [profile.dev] panic = "abort" [profile.release] panic = "abort" [lib] crate-type = ["rlib", "staticlib"] [dependencies] arrayvec = "0.7" base64 = "0.13" clap = { version = "=4.3.24", features = ["cargo", "string", "wrap_help", "derive"] } config = "0.14" httparse = "1.7" ipnet = "2" jsonwebtoken = 
"9" libc = "0.2" log = "0.4" miniz_oxide = "0.6" mio = { version = "1", features = ["os-poll", "os-ext", "net"] } notify = "7" openssl = "=0.10.72" paste = "1.0" rustls = "0.21" rustls-native-certs = "0.6" serde = { version = "1.0", features = ["derive"] } serde_json = "1.0" sha1 = "0.10" signal-hook = "0.3" slab = "0.4" socket2 = "0.4" thiserror = "1.0" time = { version = "0.3.36", features = ["formatting", "local-offset", "macros"] } url = "2.3" zmq = "0.9" [dev-dependencies] criterion = "0.5" env_logger = { version = "0.9", default-features = false } test-log = "0.2" [build-dependencies] pkg-config = "0.3" time = { version = "0.3.36", features = ["formatting", "local-offset", "macros"] } cbindgen = "0.27" [[bench]] name = "server" harness = false [[bench]] name = "client" harness = false [[bin]] name = "pushpin-connmgr" test = false bench = false [[bin]] name = "m2adapter" test = false bench = false [[bin]] name = "pushpin-proxy" test = false bench = false [[bin]] name = "pushpin-handler" test = false bench = false [[bin]] name = "pushpin" test = false bench = false [[bin]] name = "pushpin-publish" test = false bench = false [lints.rust] unexpected_cfgs = { level = "warn", check-cfg = ['cfg(qt_lib_prefix, values("Qt", "Qt6", "Qt5"))'] } [lints.clippy] uninlined_format_args = "allow" pushpin-1.41.0/LICENSE000066400000000000000000000261351504671364300143100ustar00rootroot00000000000000 Apache License Version 2.0, January 2004 http://www.apache.org/licenses/ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION 1. Definitions. "License" shall mean the terms and conditions for use, reproduction, and distribution as defined by Sections 1 through 9 of this document. "Licensor" shall mean the copyright owner or entity authorized by the copyright owner that is granting the License. "Legal Entity" shall mean the union of the acting entity and all other entities that control, are controlled by, or are under common control with that entity. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity. "You" (or "Your") shall mean an individual or Legal Entity exercising permissions granted by this License. "Source" form shall mean the preferred form for making modifications, including but not limited to software source code, documentation source, and configuration files. "Object" form shall mean any form resulting from mechanical transformation or translation of a Source form, including but not limited to compiled object code, generated documentation, and conversions to other media types. "Work" shall mean the work of authorship, whether in Source or Object form, made available under the License, as indicated by a copyright notice that is included in or attached to the work (an example is provided in the Appendix below). "Derivative Works" shall mean any work, whether in Source or Object form, that is based on (or derived from) the Work and for which the editorial revisions, annotations, elaborations, or other modifications represent, as a whole, an original work of authorship. For the purposes of this License, Derivative Works shall not include works that remain separable from, or merely link (or bind by name) to the interfaces of, the Work and Derivative Works thereof. 
"Contribution" shall mean any work of authorship, including the original version of the Work and any modifications or additions to that Work or Derivative Works thereof, that is intentionally submitted to Licensor for inclusion in the Work by the copyright owner or by an individual or Legal Entity authorized to submit on behalf of the copyright owner. For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication sent to the Licensor or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, the Licensor for the purpose of discussing and improving the Work, but excluding communication that is conspicuously marked or otherwise designated in writing by the copyright owner as "Not a Contribution." "Contributor" shall mean Licensor and any individual or Legal Entity on behalf of whom a Contribution has been received by Licensor and subsequently incorporated within the Work. 2. Grant of Copyright License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable copyright license to reproduce, prepare Derivative Works of, publicly display, publicly perform, sublicense, and distribute the Work and such Derivative Works in Source or Object form. 3. Grant of Patent License. Subject to the terms and conditions of this License, each Contributor hereby grants to You a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by such Contributor that are necessarily infringed by their Contribution(s) alone or by combination of their Contribution(s) with the Work to which such Contribution(s) was submitted. If You institute patent litigation against any entity (including a cross-claim or counterclaim in a lawsuit) alleging that the Work or a Contribution incorporated within the Work constitutes direct or contributory patent infringement, then any patent licenses granted to You under this License for that Work shall terminate as of the date such litigation is filed. 4. Redistribution. 
You may reproduce and distribute copies of the Work or Derivative Works thereof in any medium, with or without modifications, and in Source or Object form, provided that You meet the following conditions: (a) You must give any other recipients of the Work or Derivative Works a copy of this License; and (b) You must cause any modified files to carry prominent notices stating that You changed the files; and (c) You must retain, in the Source form of any Derivative Works that You distribute, all copyright, patent, trademark, and attribution notices from the Source form of the Work, excluding those notices that do not pertain to any part of the Derivative Works; and (d) If the Work includes a "NOTICE" text file as part of its distribution, then any Derivative Works that You distribute must include a readable copy of the attribution notices contained within such NOTICE file, excluding those notices that do not pertain to any part of the Derivative Works, in at least one of the following places: within a NOTICE text file distributed as part of the Derivative Works; within the Source form or documentation, if provided along with the Derivative Works; or, within a display generated by the Derivative Works, if and wherever such third-party notices normally appear. The contents of the NOTICE file are for informational purposes only and do not modify the License. You may add Your own attribution notices within Derivative Works that You distribute, alongside or as an addendum to the NOTICE text from the Work, provided that such additional attribution notices cannot be construed as modifying the License. You may add Your own copyright statement to Your modifications and may provide additional or different license terms and conditions for use, reproduction, or distribution of Your modifications, or for any such Derivative Works as a whole, provided Your use, reproduction, and distribution of the Work otherwise complies with the conditions stated in this License. 5. Submission of Contributions. Unless You explicitly state otherwise, any Contribution intentionally submitted for inclusion in the Work by You to the Licensor shall be under the terms and conditions of this License, without any additional terms or conditions. Notwithstanding the above, nothing herein shall supersede or modify the terms of any separate license agreement you may have executed with Licensor regarding such Contributions. 6. Trademarks. This License does not grant permission to use the trade names, trademarks, service marks, or product names of the Licensor, except as required for reasonable and customary use in describing the origin of the Work and reproducing the content of the NOTICE file. 7. Disclaimer of Warranty. Unless required by applicable law or agreed to in writing, Licensor provides the Work (and each Contributor provides its Contributions) on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied, including, without limitation, any warranties or conditions of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE. You are solely responsible for determining the appropriateness of using or redistributing the Work and assume any risks associated with Your exercise of permissions under this License. 8. Limitation of Liability. 
In no event and under no legal theory, whether in tort (including negligence), contract, or otherwise, unless required by applicable law (such as deliberate and grossly negligent acts) or agreed to in writing, shall any Contributor be liable to You for damages, including any direct, indirect, special, incidental, or consequential damages of any character arising as a result of this License or out of the use or inability to use the Work (including but not limited to damages for loss of goodwill, work stoppage, computer failure or malfunction, or any and all other commercial damages or losses), even if such Contributor has been advised of the possibility of such damages. 9. Accepting Warranty or Additional Liability. While redistributing the Work or Derivative Works thereof, You may choose to offer, and charge a fee for, acceptance of support, warranty, indemnity, or other liability obligations and/or rights consistent with this License. However, in accepting such obligations, You may act only on Your own behalf and on Your sole responsibility, not on behalf of any other Contributor, and only if You agree to indemnify, defend, and hold each Contributor harmless for any liability incurred by, or claims asserted against, such Contributor by reason of your accepting any such warranty or additional liability. END OF TERMS AND CONDITIONS APPENDIX: How to apply the Apache License to your work. To apply the Apache License to your work, attach the following boilerplate notice, with the fields enclosed by brackets "[]" replaced with your own identifying information. (Don't include the brackets!) The text should be enclosed in the appropriate comment syntax for the file format. We also recommend that a file or class name and description of purpose be included on the same "printed page" as the copyright notice for easier identification within third-party archives. Copyright [yyyy] [name of copyright owner] Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. You may obtain a copy of the License at http://www.apache.org/licenses/LICENSE-2.0 Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. 
pushpin-1.41.0/Makefile000066400000000000000000000015121504671364300147330ustar00rootroot00000000000000# controlled leading whitespace, per the GNU make manual nullstring := space := $(nullstring) # end of the line ifdef RELEASE cargo_flags = $(space)--offline --locked --release endif ifdef TOOLCHAIN cargo_toolchain = $(space)+$(TOOLCHAIN) endif all: postbuild build: FORCE cargo$(cargo_toolchain) build$(cargo_flags) cargo-test: FORCE cargo$(cargo_toolchain) test$(cargo_flags) --all-features cargo-clean: FORCE cargo clean postbuild: build FORCE cd postbuild && $(MAKE) -f Makefile postbuild-install: FORCE cd postbuild && $(MAKE) -f Makefile install postbuild-clean: FORCE cd postbuild && $(MAKE) -f Makefile clean postbuild-distclean: FORCE cd postbuild && $(MAKE) -f Makefile distclean check: cargo-test install: postbuild-install clean: cargo-clean postbuild-clean distclean: cargo-clean postbuild-distclean FORCE: pushpin-1.41.0/README.md000066400000000000000000000316231504671364300145600ustar00rootroot00000000000000# Pushpin Website: https://pushpin.org/ Forum: https://community.fastly.com/c/pushpin/12 Pushpin is a reverse proxy server written in Rust & C++ that makes it easy to implement WebSocket, HTTP streaming, and HTTP long-polling services. The project is unique among realtime push solutions in that it is designed to address the needs of API creators. Pushpin is transparent to clients and integrates easily into an API stack. ## How it works Pushpin is placed in the network path between the backend and any clients:

*[diagram: pushpin-abstract]*

Pushpin communicates with backend web applications using regular, short-lived HTTP requests. This allows backend applications to be written in any language and use any webserver. There are two main integration points: 1. The backend must handle proxied requests. For HTTP, each incoming request is proxied to the backend. For WebSockets, the activity of each connection is translated into a series of HTTP requests[1](#proxy-modes) sent to the backend. Pushpin's behavior is determined by how the backend responds to these requests. 2. The backend must tell Pushpin to push data. Regardless of how clients are connected, data may be pushed to them by making an HTTP POST request to Pushpin's private control API (`http://localhost:5561/publish/` by default). Pushpin will inject this data into any client connections as necessary. To assist with integration, there are [libraries](https://pushpin.org/docs/usage/#libraries) for many backend languages and frameworks. Pushpin has no libraries on the client side because it is transparent to clients. ## Example To create an HTTP streaming connection, respond to a proxied request with special headers `Grip-Hold` and `Grip-Channel`[2](#grip): ```http HTTP/1.1 200 OK Content-Type: text/plain Content-Length: 22 Grip-Hold: stream Grip-Channel: test welcome to the stream ``` When Pushpin receives the above response from the backend, it will process it and send an initial response to the client that instead looks like this: ```http HTTP/1.1 200 OK Content-Type: text/plain Transfer-Encoding: chunked Connection: Transfer-Encoding welcome to the stream ``` Pushpin eats the special headers and switches to chunked encoding (notice there's no `Content-Length`). The request between Pushpin and the backend is now complete, but the request between the client and Pushpin remains held open. The request is subscribed to a channel called `test`. Data can then be pushed to the client by publishing data on the `test` channel: ```bash curl -d '{ "items": [ { "channel": "test", "formats": { "http-stream": \ { "content": "hello there\n" } } } ] }' \ http://localhost:5561/publish ``` The client would then see the line "hello there" appended to the response stream. Ta-da, transparent realtime push! For more details, see the [HTTP streaming](https://pushpin.org/docs/usage/#http-streaming) section of the documentation. Pushpin also supports [HTTP long-polling](https://pushpin.org/docs/usage/#http-long-polling) and [WebSockets](https://pushpin.org/docs/usage/#websockets). ## Example using a library Using a library on the backend makes integration even easier. Here's another HTTP streaming example, similar to the one shown above, except using Pushpin's [Django library](https://github.com/fanout/django-grip). Please note that Pushpin is not Python/Django-specific and there are backend libraries for [other languages/frameworks, too](https://pushpin.org/docs/usage/#libraries). The Django library requires configuration in `settings.py`: ```python MIDDLEWARE_CLASSES = ( 'django_grip.GripMiddleware', ... ) GRIP_PROXIES = [{'control_uri': 'http://localhost:5561'}] ``` Here's a simple view: ```python from django.http import HttpResponse from django_grip import set_hold_stream def myendpoint(request): if request.method == 'GET': # subscribe every incoming request to a channel in stream mode set_hold_stream(request, 'test') return HttpResponse('welcome to the stream\n', content_type='text/plain') ... 
``` What happens here is the `set_hold_stream()` method flags the request as needing to turn into a stream, bound to channel `test`. The middleware will see this and add the necessary `Grip-Hold` and `Grip-Channel` headers to the response. Publishing data is easy: ```python from gripcontrol import HttpStreamFormat from django_grip import publish publish('test', HttpStreamFormat('hello there\n')) ``` ## Example using WebSockets Pushpin supports WebSockets by converting connection activity/messages into HTTP requests and sending them to the backend. For this example, we'll use Pushpin's [Express library](https://github.com/fanout/js-serve-grip). As before, please note that Pushpin is not Node/Express-specific and there are backend libraries for [other languages/frameworks, too](https://pushpin.org/docs/usage/#libraries). The Express library requires configuration and setting up a middleware handler: ```javascript const express = require('express'); const { ServeGrip } = require('@fanoutio/serve-grip'); var app = express(); // Instantiate the middleware and register it with Express const serveGrip = new ServeGrip({ grip: { 'control_uri': 'http://localhost:5561', 'key': 'changeme' } }); app.use(serveGrip); // Instantiate the publisher to use from your code to publish messages const publisher = serveGrip.getPublisher(); app.get('/hello', (req, res) => { res.send('hello world\n'); }); ``` With that structure in place, here's an example of a WebSocket endpoint: ```javascript const { WebSocketMessageFormat } = require( '@fanoutio/grip' ); app.post('/websocket', async (req, res) => { const { wsContext } = req.grip; // If this is a new connection, accept it and subscribe it to a channel if (wsContext.isOpening()) { wsContext.accept(); wsContext.subscribe('all'); } while (wsContext.canRecv()) { var message = wsContext.recv(); // If return value is null then connection is closed if (message == null) { wsContext.close(); break; } // broadcast the message to everyone connected await publisher.publishFormats('all', WebSocketMessageFormat(message)); } res.end(); }); ``` The above code binds all incoming connections to a channel called `all`. Any received messages are published out to all connected clients. What's particularly noteworthy is that the above endpoint is stateless. The app doesn't keep track of connections, and the handler code only runs whenever messages arrive. Restarting the app won't disconnect clients. The `while` loop is deceptive. It looks like it's looping for the lifetime of the WebSocket connection, but what it's really doing is looping through a batch of WebSocket messages that was just received via HTTP. Often this will be one message, and so the loop performs one iteration and then exits. Similarly, the `wsContext` object only exists for the duration of the handler invocation, rather than for the lifetime of the connection as you might expect. It may look like socket code, but it's all an illusion. :tophat: For details on the underlying protocol conversion, see the [WebSocket-Over-HTTP Protocol spec](https://pushpin.org/docs/protocols/websocket-over-http/). ## Example without a webserver Pushpin can also connect to backend servers via ZeroMQ instead of HTTP. This may be preferred for writing lower-level services where a real webserver isn't needed. The messages exchanged over the ZeroMQ connection contain the same information as HTTP, encoded as TNetStrings. 
To use a ZeroMQ backend, first make sure there's an appropriate route in Pushpin's `routes` file: ``` * zhttpreq/tcp://127.0.0.1:10000 ``` The above line tells Pushpin to bind a REQ-compatible socket on port 10000 that handlers can connect to. Activating an HTTP stream is as easy as responding on a REP socket: ```python import zmq import tnetstring zmq_context = zmq.Context() sock = zmq_context.socket(zmq.REP) sock.connect('tcp://127.0.0.1:10000') while True: req = tnetstring.loads(sock.recv()[1:]) resp = { 'id': req['id'], 'code': 200, 'reason': 'OK', 'headers': [ ['Grip-Hold', 'stream'], ['Grip-Channel', 'test'], ['Content-Type', 'text/plain'] ], 'body': 'welcome to the stream\n' } sock.send('T' + tnetstring.dumps(resp)) ``` ## Why another realtime solution? Pushpin is an ambitious project with two primary goals: * Make realtime API development easier. There are many other solutions out there that are excellent for building realtime apps, but few are useful within the context of *APIs*. For example, you can't use Socket.io to build Twitter's streaming API. A new kind of project is needed in this case. * Make realtime push behavior delegable. The reason there isn't a realtime push CDN yet is because the standards and practices necessary for delegating to a third party in a transparent way are not yet established. Pushpin is more than just another realtime push solution; it represents the next logical step in the evolution of realtime web architectures. To really understand Pushpin, you need to think of it as more like a gateway than a message queue. Pushpin does not persist data and it is agnostic to your application's data model. Your backend provides the mapping to whatever that data model is. Tools like Kafka and RabbitMQ are complementary. Pushpin is also agnostic to your API definition. Clients don't necessarily subscribe to "channels" or receive "messages". Clients make HTTP requests or send WebSocket frames, and your backend decides the meaning of those inputs. Pushpin could perhaps be awkwardly described as "a proxy server that enables web services to delegate the handling of realtime push primitives". On a practical level, there are many benefits to Pushpin that you don't see anywhere else: * The proxy design allows Pushpin to fit nicely within an API stack. This means it can inherit other facilities from your REST API, such as authentication, logging, throttling, etc. It can be combined with an API management system. * As your API scales, a multi-tiered architecture will become inevitable. With Pushpin you can easily do this from the start. * It works well with microservices. Each microservice can have its own Pushpin instance. No central bus needed. * Hot reload. Restarting the backend doesn't disconnect clients. * In the case of WebSocket messages being proxied out as HTTP requests, the messages may be handled statelessly by the backend. Messages from a single connection can even be load balanced across a set of backend instances. ## Install Check out the [the Install guide](https://pushpin.org/docs/install/), which covers how to install and run. There are packages available for Linux (Debian, Ubuntu, CentOS, Red Hat), Mac (Homebrew), or you can build from source. By default, Pushpin listens on port 7999 and requests are handled by its internal test handler. You can confirm the server is working by browsing to `http://localhost:7999/`. Next, you should modify the `routes` config file to route requests to your backend webserver. 
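For example, if your backend webserver happens to be listening on port 8000 of the same machine (a placeholder host and port for illustration; substitute whatever your backend actually uses), a minimal `routes` file could look like this:

```
* localhost:8000
```

Here `*` matches any incoming domain, and `localhost:8000` is the backend target that matching requests are proxied to.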
See [Configuration](https://pushpin.org/docs/configuration/). ## Scalability Pushpin is horizontally scalable. Instances don’t talk to each other, and sticky routing is not needed. Backends must publish data to all instances to ensure clients connected to any instance will receive the data. Most of the backend libraries support configuring more than one Pushpin instance, so that a single publish call will send data to multiple instances at once. Optionally, ZeroMQ PUB/SUB can be used to send data to Pushpin instead of using HTTP POST. When this method is used, subscription information is forwarded to each publisher, such that data will only be published to instances that have listeners. As for vertical scalability, Pushpin has been tested with up to [1 million concurrent connections](https://github.com/fanout/pushpin-c1m) running on a single DigitalOcean droplet with 8 CPU cores. In practice, you may want to plan for fewer connections per instance, depending on your throughput. The new connection accept rate is about 800/sec (though this also depends on the speed of your backend), and the message throughput is about 8,000/sec. The important thing is that Pushpin is horizontally scalable which is effectively limitless. ## What does the name mean? Pushpin means to "pin" connections open for "pushing". ## License Pushpin is offered under the Apache License, Version 2.0. See the LICENSE file. ## Footnotes 1: Pushpin can communicate WebSocket activity to the backend using either HTTP or WebSockets. Conversion to HTTP is generally recommended as it makes the backend easier to reason about. 2: GRIP (Generic Realtime Intermediary Protocol) is the name of Pushpin's backend protocol. More about that [here](https://pushpin.org/docs/protocols/grip/). pushpin-1.41.0/SECURITY.md000066400000000000000000000015741504671364300150740ustar00rootroot00000000000000## Report a security issue The project team welcomes security reports and is committed to providing prompt attention to security issues. Security issues should be reported privately via [Fastly’s security issue reporting process](https://www.fastly.com/security/report-security-issue). ## Security advisories Remediation of security vulnerabilities is prioritized by the project team. The project team endeavors to coordinate remediation with third-party stakeholders, and is committed to transparency in the disclosure process. The team announces security issues via [GitHub](https://github.com/fastly/pushpin/releases) as well as [RustSec](https://rustsec.org/advisories/) on a best-effort basis. Note that communications related to security issues in Fastly-maintained OSS as described here are distinct from [Fastly Security Advisories](https://www.fastly.com/security-advisories). pushpin-1.41.0/benches/000077500000000000000000000000001504671364300147035ustar00rootroot00000000000000pushpin-1.41.0/benches/client.rs000066400000000000000000000126231504671364300165330ustar00rootroot00000000000000/* * Copyright (C) 2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use criterion::{criterion_group, criterion_main, Criterion}; use mio::net::TcpListener; use pushpin::connmgr::client::TestClient; use pushpin::core::channel; use pushpin::core::executor::Executor; use pushpin::core::io::{AsyncReadExt, AsyncWriteExt}; use pushpin::core::net::{AsyncTcpListener, AsyncTcpStream}; use pushpin::core::reactor::Reactor; use std::net::SocketAddr; use std::rc::Rc; use std::str; const REQS_PER_ITER: usize = 10; fn req<F1, F2>(listener: TcpListener, start: F1, wait: F2) -> TcpListener where F1: Fn(SocketAddr) + 'static, F2: Fn() + 'static, { let executor = Executor::new(REQS_PER_ITER + 1); let addr = listener.local_addr().unwrap(); let (s, r) = channel::channel(1); for _ in 0..REQS_PER_ITER { start(addr); } let spawner = executor.spawner(); executor .spawn(async move { let s = channel::AsyncSender::new(s); let listener = AsyncTcpListener::new(listener); for _ in 0..REQS_PER_ITER { let (stream, _) = listener.accept().await.unwrap(); let mut stream = AsyncTcpStream::new(stream); spawner .spawn(async move { let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).await.unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ).await .unwrap(); }) .unwrap(); } s.send(listener.into_inner()).await.unwrap(); }) .unwrap(); executor .run(|timeout| Reactor::current().unwrap().poll(timeout)) .unwrap(); for _ in 0..REQS_PER_ITER { wait(); } let listener = r.recv().unwrap(); listener } fn criterion_benchmark(c: &mut Criterion) { let mut req_listener = Some(TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap()); let mut stream_listener = Some(TcpListener::bind("127.0.0.1:0".parse().unwrap()).unwrap()); let _reactor = Reactor::new(REQS_PER_ITER * 10); { let client = Rc::new(TestClient::new(1)); c.bench_function("req_client workers=1", |b| { b.iter(|| { let c1 = Rc::clone(&client); let c2 = Rc::clone(&client); req_listener = Some(req( req_listener.take().unwrap(), move |addr| c1.do_req(addr), move || c2.wait_req(), )) }) }); c.bench_function("stream_client workers=1", |b| { b.iter(|| { let c1 = Rc::clone(&client); let c2 = Rc::clone(&client); stream_listener = Some(req( stream_listener.take().unwrap(), move |addr| c1.do_stream_http(addr), move || c2.wait_stream(), )) }) }); } { let client = Rc::new(TestClient::new(2)); c.bench_function("req_client workers=2", |b| { b.iter(|| { let c1 = Rc::clone(&client); let c2 = Rc::clone(&client); req_listener = Some(req( req_listener.take().unwrap(), move |addr| c1.do_req(addr), move || c2.wait_req(), )) }) }); c.bench_function("stream_client workers=2", |b| { b.iter(|| { let c1 = Rc::clone(&client); let c2 = Rc::clone(&client); stream_listener = Some(req( stream_listener.take().unwrap(), move |addr| c1.do_stream_http(addr), move || c2.wait_stream(), )) }) }); } } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); pushpin-1.41.0/benches/server.rs000066400000000000000000000113671504671364300165670ustar00rootroot00000000000000/* * Copyright (C) 2020 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License.
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use criterion::{criterion_group, criterion_main, Criterion}; use pushpin::connmgr::connection::testutil::{ BenchServerReqConnection, BenchServerReqHandler, BenchServerStreamConnection, BenchServerStreamHandler, }; use pushpin::connmgr::server::TestServer; use pushpin::connmgr::websocket::testutil::{BenchRecvMessage, BenchSendMessage}; use pushpin::core::executor::Executor; use pushpin::core::io::{AsyncReadExt, AsyncWriteExt}; use pushpin::core::net::AsyncTcpStream; use pushpin::core::reactor::Reactor; use std::io::{self, Write}; use std::net::SocketAddr; use std::str; const REQS_PER_ITER: usize = 10; fn req(addr: SocketAddr) { let reactor = Reactor::new(REQS_PER_ITER * 10); let executor = Executor::new(REQS_PER_ITER); for _ in 0..REQS_PER_ITER { executor .spawn(async move { let mut client = AsyncTcpStream::connect(&[addr]).await.unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .await .unwrap(); let mut resp = [0u8; 1024]; let mut resp = io::Cursor::new(&mut resp[..]); loop { let mut buf = [0; 1024]; let size = client.read(&mut buf).await.unwrap(); if size == 0 { break; } resp.write(&buf[..size]).unwrap(); } let size = resp.position() as usize; let resp = str::from_utf8(&resp.get_ref()[..size]).unwrap(); assert_eq!(resp, "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n"); }) .unwrap(); } executor.run(|timeout| reactor.poll(timeout)).unwrap(); } fn criterion_benchmark(c: &mut Criterion) { { let t = BenchServerReqHandler::new(); c.bench_function("req_handler", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchServerStreamHandler::new(); c.bench_function("stream_handler", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchServerReqConnection::new(); c.bench_function("req_connection", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchServerStreamConnection::new(); c.bench_function("stream_connection", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchSendMessage::new(false); c.bench_function("ws_send", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchSendMessage::new(true); c.bench_function("ws_send_deflate", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchRecvMessage::new(false); c.bench_function("ws_recv", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let t = BenchRecvMessage::new(true); c.bench_function("ws_recv_deflate", |b| { b.iter_batched_ref(|| t.init(), |i| t.run(i), criterion::BatchSize::SmallInput) }); } { let server = TestServer::new(1); let req_addr = server.req_addr(); let stream_addr = server.stream_addr(); c.bench_function("req_server workers=1", |b| b.iter(|| req(req_addr))); c.bench_function("stream_server workers=1", |b| b.iter(|| req(stream_addr))); } { let server = TestServer::new(2); let req_addr = server.req_addr(); let stream_addr = 
server.stream_addr(); c.bench_function("req_server workers=2", |b| b.iter(|| req(req_addr))); c.bench_function("stream_server workers=2", |b| b.iter(|| req(stream_addr))); } } criterion_group!(benches, criterion_benchmark); criterion_main!(benches); pushpin-1.41.0/build.rs000066400000000000000000000376531504671364300147550ustar00rootroot00000000000000use std::collections::HashMap; use std::env; use std::error::Error; use std::ffi::OsStr; use std::fmt; use std::fs::{self, File}; use std::io::{self, BufRead, Write}; use std::os::unix::ffi::OsStrExt; use std::path::{Path, PathBuf}; use std::process::{Command, ExitStatus, Output, Stdio}; use std::str::FromStr; use time::macros::format_description; use time::OffsetDateTime; const DEFAULT_PREFIX: &str = "/usr/local"; fn get_version() -> String { let mut version = env!("CARGO_PKG_VERSION").to_string(); if version.ends_with("-dev") { let format = format_description!("[year][month][day]"); let date_str = OffsetDateTime::now_utc().format(&format).unwrap(); version.push_str(&format!("-{}", date_str)); } version } #[derive(Clone)] struct LibVersion { maj: u16, min: u16, orig: String, } #[derive(Debug)] struct ParseVersionError; impl fmt::Display for ParseVersionError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { Ok(write!(f, "failed to parse version")?) } } impl Error for ParseVersionError {} impl FromStr for LibVersion { type Err = ParseVersionError; fn from_str(s: &str) -> Result<Self, Self::Err> { let parts: Vec<&str> = s.split('.').collect(); if parts.len() < 2 { return Err(ParseVersionError); } let (maj, min): (u16, u16) = match (parts[0].parse(), parts[1].parse()) { (Ok(maj), Ok(min)) => (maj, min), _ => return Err(ParseVersionError), }; Ok(LibVersion { maj, min, orig: s.to_string(), }) } } fn check_version( pkg: &str, found: LibVersion, expect_maj: u16, expect_min: u16, ) -> Result<(), Box<dyn Error>> { if found.maj < expect_maj || (found.maj == expect_maj && found.min < expect_min) { return Err(format!( "{} version >={}.{} required, found: {}", pkg, expect_maj, expect_min, found.orig, ) .into()); } Ok(()) } fn prefixed_vars(prefix: &str) -> HashMap<String, String> { let mut out = HashMap::new(); out.insert("BINDIR".into(), format!("{}/bin", prefix)); out.insert("CONFIGDIR".into(), format!("{}/etc", prefix)); out.insert("LIBDIR".into(), format!("{}/lib", prefix)); out.insert("LOGDIR".into(), "/var/log".into()); out.insert("RUNDIR".into(), "/var/run".into()); out } fn env_or_default(name: &str, defaults: &HashMap<String, String>) -> String { match env::var(name) { Ok(s) => s, Err(_) => defaults.get(name).unwrap().to_string(), } } fn write_if_different(dest: &Path, content: &[u8]) -> Result<(), Box<dyn Error>> { let do_write = match fs::read(dest) { Ok(v) => v != content, Err(e) if e.kind() == io::ErrorKind::NotFound => true, Err(e) => return Err(e.into()), }; if do_write { fs::write(dest, content)?; } Ok(()) } fn write_cpp_conf_pri( dest: &Path, release: bool, include_paths: &[&Path], deny_warnings: bool, ) -> Result<(), Box<dyn Error>> { let mut out = Vec::new(); writeln!(&mut out, "CONFIG -= debug_and_release")?; if release { writeln!(&mut out, "CONFIG += release")?; } else { writeln!(&mut out, "CONFIG += debug")?; } writeln!(&mut out)?; for path in include_paths { writeln!(&mut out, "INCLUDEPATH += {}", path.display())?; } writeln!(&mut out)?; if deny_warnings { writeln!(&mut out, "QMAKE_CXXFLAGS += \"-Werror\"")?; } write_if_different(dest, &out) } fn write_postbuild_conf_pri( dest: &Path, bin_dir: &str, lib_dir: &str, config_dir: &str, run_dir: &str, log_dir: &str, ) -> Result<(), Box<dyn Error>> { let mut out =
Vec::new(); writeln!(&mut out, "BINDIR = {}", bin_dir)?; writeln!(&mut out, "LIBDIR = {}/pushpin", lib_dir)?; writeln!(&mut out, "CONFIGDIR = {}/pushpin", config_dir)?; writeln!(&mut out, "RUNDIR = {}/pushpin", run_dir)?; writeln!(&mut out, "LOGDIR = {}/pushpin", log_dir)?; write_if_different(dest, &out) } // returned vec size guaranteed >= 1 fn get_args_lossy(command: &mut Command) -> Vec<String> { let mut args = vec![command.get_program().to_string_lossy().into_owned()]; for s in command.get_args() { args.push(s.to_string_lossy().into_owned()); } args } // convert Result<Output> to Result<ExitStatus>, separating stdout fn take_stdout(result: io::Result<Output>) -> (io::Result<ExitStatus>, Vec<u8>) { match result { Ok(output) => (Ok(output.status), output.stdout), Err(e) => (Err(e), Vec::new()), } } fn check_command_result( program: &str, result: io::Result<ExitStatus>, ) -> Result<(), Box<dyn Error>> { let status = match result { Ok(status) => status, Err(e) => return Err(format!("{} failed: {}", program, e).into()), }; if !status.success() { return Err(format!("{} failed, {}", program, status).into()); } Ok(()) } fn check_command(command: &mut Command) -> Result<(), Box<dyn Error>> { let args = get_args_lossy(command); println!("{}", args.join(" ")); check_command_result(&args[0], command.status()) } fn check_command_capture_stdout(command: &mut Command) -> Result<Vec<u8>, Box<dyn Error>> { let args = get_args_lossy(command); println!("{}", args.join(" ")); // don't capture stderr let command = command.stderr(Stdio::inherit()); let (result, output) = take_stdout(command.output()); check_command_result(&args[0], result)?; Ok(output) } fn check_qmake(qmake_path: &Path) -> Result<LibVersion, Box<dyn Error>> { let version: LibVersion = { let output = check_command_capture_stdout(Command::new(qmake_path).args(["-query", "QT_VERSION"]))?; let s = String::from_utf8(output)?; let s = s.trim(); match s.parse() { Ok(v) => v, Err(_) => return Err(format!("unexpected qt version string: [{}]", s).into()), } }; check_version("qt", version.clone(), 5, 12)?; Ok(version) } fn find_in_path(name: &str) -> Option<PathBuf> { for d in env::var("PATH").unwrap_or_default().split(':') { if d.is_empty() { continue; } let path = Path::new(d).join(name); if path.exists() { return Some(path); } } None } fn find_qmake() -> Result<(PathBuf, LibVersion), Box<dyn Error>> { let mut errors = Vec::new(); // check for a usable qmake in PATH let names = &["qmake", "qmake6", "qmake5"]; for name in names { if let Some(p) = find_in_path(name) { match check_qmake(&p) { Ok(version) => return Ok((p, version)), Err(e) => errors.push(format!("skipping {}: {}", p.display(), e)), } } } if errors.is_empty() { errors.push(format!("none of ({}) found in PATH", names.join(", "))); } // check pkg-config let pkg = "Qt5Core"; match pkg_config::get_variable(pkg, "host_bins") { Ok(host_bins) if !host_bins.is_empty() => { let host_bins = PathBuf::from(host_bins); match fs::canonicalize(host_bins.join("qmake")) { Ok(p) => match check_qmake(&p) { Ok(version) => return Ok((p, version)), Err(e) => errors.push(format!("skipping {}: {}", p.display(), e)), }, Err(e) => errors.push(format!("qmake not found in {}: {}", host_bins.display(), e)), } } Ok(_) => errors.push(format!( "pkg-config variable host_bins does not exist for {}", pkg )), Err(e) => errors.push(format!("pkg-config error for {}: {}", pkg, e)), } Err(format!("unable to find a usable qmake: {}", errors.join(", ")).into()) } fn get_qmake() -> Result<(PathBuf, LibVersion), Box<dyn Error>> { match env::var("QMAKE") { Ok(s) => { let path = PathBuf::from(s); let version = check_qmake(&path)?; Ok((path, version)) } Err(env::VarError::NotPresent) =>
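// QMAKE not set in the environment: fall back to searching PATH and then pkg-config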
find_qmake(), Err(env::VarError::NotUnicode(_)) => Err("QMAKE not unicode".into()), } } fn contains_file_prefix(dir: &Path, prefix: &str) -> Result { for entry in fs::read_dir(dir)? { let entry = entry?; if entry.file_name().as_bytes().starts_with(prefix.as_bytes()) { return Ok(true); } } Ok(false) } fn get_qt_lib_prefix(lib_dir: &Path, version_maj: u16) -> Result> { let prefixes = if cfg!(target_os = "macos") { [format!("Qt{}", version_maj), "Qt".to_string()] } else { [format!("libQt{}", version_maj), "libQt".to_string()] }; for prefix in &prefixes { if contains_file_prefix(lib_dir, prefix)? { return Ok(prefix.strip_prefix("lib").unwrap_or(prefix).to_string()); } } Err(format!( "no files in {} beginning with any of: {}", lib_dir.display(), prefixes.join(", ") ) .into()) } fn find_boost_include_dir() -> Result> { let paths = [ "/opt/homebrew/include", "/usr/local/include", "/usr/include", ]; let version_filename = "boost/version.hpp"; for path in paths { let path = PathBuf::from(path); let full_path = path.join(version_filename); if !full_path.exists() { continue; } let file = File::open(&full_path)?; let reader = io::BufReader::new(file); let mut version_line = None; for line in reader.lines() { match line { Ok(s) if s.contains("#define BOOST_LIB_VERSION") => version_line = Some(s), Ok(_) => continue, Err(e) => { return Err(format!("failed to read {}: {}", full_path.display(), e).into()) } } } let version_line = match version_line { Some(s) => s, None => return Err(format!("version line not found in {}", full_path.display()).into()), }; let parts: Vec<&str> = version_line.split('"').collect(); if parts.len() < 2 { return Err(format!("failed to parse version line in {}", full_path.display()).into()); } let version = parts[1].replace('_', "."); let version = match version.parse() { Ok(v) => v, Err(_) => return Err(format!("unexpected boost version string: {}", version).into()), }; check_version("boost", version, 1, 71)?; return Ok(path); } Err(format!( "{} not found in any of: {}", version_filename, paths.join(", ") ) .into()) } fn contains_subslice(haystack: &[T], needle: &[T]) -> bool { haystack.windows(needle.len()).any(|w| w == needle) } fn main() -> Result<(), Box> { let crate_dir = env::var("CARGO_MANIFEST_DIR").unwrap(); cbindgen::generate(crate_dir).map_or_else( |error| match error { cbindgen::Error::ParseSyntaxError { .. } => {} e => panic!("{:?}", e), }, |bindings| { bindings.write_to_file("target/include/rust/bindings.h"); }, ); let (qmake_path, qt_version) = get_qmake()?; let qt_install_libs = { let output = check_command_capture_stdout( Command::new(&qmake_path).args(["-query", "QT_INSTALL_LIBS"]), )?; let libs_dir = PathBuf::from(String::from_utf8(output)?.trim()); fs::canonicalize(&libs_dir) .map_err(|_| format!("QT_INSTALL_LIBS dir {} not found", libs_dir.display()))? 
}; let qt_lib_prefix = get_qt_lib_prefix(&qt_install_libs, qt_version.maj)?; let boost_include_dir = match env::var("BOOST_INCLUDE_DIR") { Ok(s) => PathBuf::from(s), Err(env::VarError::NotPresent) => find_boost_include_dir()?, Err(env::VarError::NotUnicode(_)) => return Err("BOOST_INCLUDE_DIR not unicode".into()), }; let default_vars = { let prefix = match env::var("PREFIX") { Ok(s) => Some(s), Err(env::VarError::NotPresent) => None, Err(env::VarError::NotUnicode(_)) => return Err("PREFIX not unicode".into()), }; if let Some(prefix) = prefix { prefixed_vars(&prefix) } else { prefixed_vars(DEFAULT_PREFIX) } }; let bin_dir = env_or_default("BINDIR", &default_vars); let config_dir = env_or_default("CONFIGDIR", &default_vars); let lib_dir = env_or_default("LIBDIR", &default_vars); let log_dir = env_or_default("LOGDIR", &default_vars); let run_dir = env_or_default("RUNDIR", &default_vars); let root_dir = PathBuf::from(env::var("CARGO_MANIFEST_DIR")?); let out_dir = PathBuf::from(env::var("OUT_DIR")?); let profile = env::var("PROFILE")?; let cpp_pro = root_dir.join("src/cpp.pro"); let cpp_tests_pro = root_dir.join("src/cpptests.pro"); for dir in ["moc", "obj", "test-moc", "test-obj", "test-work"] { fs::create_dir_all(out_dir.join(dir))?; } let mut include_paths = Vec::new(); include_paths.push(out_dir.as_ref()); if boost_include_dir != Path::new("/usr/include") { include_paths.push(boost_include_dir.as_ref()); } let deny_warnings = match env::var("CARGO_ENCODED_RUSTFLAGS") { Ok(s) => { let flags: Vec<&str> = s.split('\x1f').collect(); contains_subslice(&flags, &["-D", "warnings"]) } Err(env::VarError::NotPresent) => false, Err(env::VarError::NotUnicode(_)) => { return Err("CARGO_ENCODED_RUSTFLAGS not unicode".into()) } }; write_cpp_conf_pri( &out_dir.join("conf.pri"), profile == "release", &include_paths, deny_warnings, )?; write_postbuild_conf_pri( &Path::new("postbuild").join("conf.pri"), &bin_dir, &lib_dir, &config_dir, &run_dir, &log_dir, )?; check_command(Command::new(&qmake_path).args([ OsStr::new("-o"), out_dir.join("Makefile").as_os_str(), cpp_pro.as_os_str(), ]))?; check_command(Command::new(&qmake_path).args([ OsStr::new("-o"), out_dir.join("Makefile.test").as_os_str(), cpp_tests_pro.as_os_str(), ]))?; check_command( Command::new(&qmake_path) .args(["-o", "Makefile", "postbuild.pro"]) .current_dir("postbuild"), )?; check_command( Command::new("make") .env("MAKEFLAGS", env::var("CARGO_MAKEFLAGS")?) .args(["-f", "Makefile"]) .current_dir(&out_dir), )?; check_command( Command::new("make") .env("MAKEFLAGS", env::var("CARGO_MAKEFLAGS")?) 
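// second pass: build the C++ test artifacts from the qmake-generated Makefile.test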
.args(["-f", "Makefile.test"]) .current_dir(&out_dir), )?; println!("cargo:rustc-env=APP_VERSION={}", get_version()); println!("cargo:rustc-env=CONFIG_DIR={}/pushpin", config_dir); println!("cargo:rustc-env=LIB_DIR={}/pushpin", lib_dir); println!("cargo:rustc-cfg=qt_lib_prefix=\"{}\"", qt_lib_prefix); println!("cargo:rustc-link-search={}", out_dir.display()); if cfg!(target_os = "macos") { println!( "cargo:rustc-link-search=framework={}", qt_install_libs.display() ); } else { println!("cargo:rustc-link-search={}", qt_install_libs.display()); } println!("cargo:rerun-if-env-changed=RELEASE"); println!("cargo:rerun-if-env-changed=PREFIX"); println!("cargo:rerun-if-env-changed=BINDIR"); println!("cargo:rerun-if-env-changed=CONFIGDIR"); println!("cargo:rerun-if-env-changed=LIBDIR"); println!("cargo:rerun-if-env-changed=LOGDIR"); println!("cargo:rerun-if-env-changed=RUNDIR"); println!("cargo:rerun-if-changed=src"); println!("cargo:rerun-if-changed=cbindgen.toml"); Ok(()) } pushpin-1.41.0/cbindgen.toml000066400000000000000000000003621504671364300157430ustar00rootroot00000000000000include_guard = "RUST_BINDINGS_H" namespace = "ffi" after_includes = """ // zmq crate version 0.10 bundles zeromq 4.3 #define WZMQ_VERSION_MAJOR 4 #define WZMQ_VERSION_MINOR 3 """ [parse] parse_deps = true include = ["jsonwebtoken", "zmq"] pushpin-1.41.0/examples/000077500000000000000000000000001504671364300151125ustar00rootroot00000000000000pushpin-1.41.0/examples/config/000077500000000000000000000000001504671364300163575ustar00rootroot00000000000000pushpin-1.41.0/examples/config/pushpin.conf000066400000000000000000000103601504671364300207140ustar00rootroot00000000000000[global] include={libdir}/internal.conf # directory to save runtime files rundir=run # prefix for zmq ipc specs ipc_prefix=pushpin- # port offset for zmq tcp specs and http control server port_offset=0 # TTL (seconds) for connection stats stats_connection_ttl=120 # whether to send individual connection stats stats_connection_send=true [runner] # services to start services=connmgr,proxy,handler # plain HTTP port to listen on for client connections http_port=7999 # list of HTTPS ports to listen on for client connections (you must have certs set) #https_ports=443 # list of unix socket paths to listen on for client connections #local_ports={rundir}/{ipc_prefix}server # directory to save log files logdir=log # logging level. 2 = info, >2 = verbose log_level=2 # client full request header must fit in this buffer client_buffer_size=8192 # maximum number of client connections client_maxconn=50000 # whether connections can use compression allow_compression=false # paths mongrel2_bin=mongrel2 m2sh_bin=m2sh zurl_bin=zurl [proxy] # routes config file (path relative to location of this file) routesfile=routes # enable debug mode to get informative error responses debug=false # whether to use automatic CORS and JSON-P wrapping auto_cross_origin=false # whether to accept x-forwarded-proto accept_x_forwarded_protocol=false # whether to assert x-forwarded-proto set_x_forwarded_protocol=proto-only # how to treat x-forwarded-for. 
example: "truncate:0,append" x_forwarded_for= # how to treat x-forwarded-for if grip-signed x_forwarded_for_trusted= # the following headers must be marked in order to qualify as orig orig_headers_need_mark= # whether to accept Pushpin-Route header accept_pushpin_route=false # value to append to the CDN-Loop header cdn_loop= # include client IP address in logs log_from=false # include client user agent in logs log_user_agent=false # for signing proxied requests sig_iss=pushpin # for signing proxied requests. use "base64:" prefix for binary key sig_key=changeme # use this to allow grip to be forwarded upstream (e.g. to fanout.io) upstream_key= # for the sockjs iframe transport sockjs_url=http://cdn.jsdelivr.net/sockjs/0.3.4/sockjs.min.js # updates check has three modes: # report: check for new pushpin version and report anonymous usage info to # the pushpin developers # check: check for new pushpin version only, don't report anything # off: don't do any reporting or checking # pushpin will output a log message when a new version is available. report # mode helps the pushpin project build credibility, so please enable it if you # enjoy this software :) updates_check=report # use this field to identify your organization in updates requests. if left # blank, updates requests will be anonymous organization_name= [handler] # ipc permissions (octal) #ipc_file_mode=777 # bind PULL for receiving publish commands push_in_spec=tcp://127.0.0.1:5560 # list of bind SUB for receiving published messages push_in_sub_specs=tcp://127.0.0.1:5562 # whether the above SUB socket should connect instead of bind push_in_sub_connect=false # addr/port to listen on for receiving publish commands via HTTP push_in_http_addr=127.0.0.1 push_in_http_port=5561 # maximum headers and body size in bytes when receiving publish commands via HTTP push_in_http_max_headers_size=10000 push_in_http_max_body_size=1000000 # bind PUB for sending stats (metrics, subscription info, etc) stats_spec=ipc://{rundir}/{ipc_prefix}stats # bind REP for responding to commands command_spec=tcp://127.0.0.1:5563 # max messages per second message_rate=2500 # max rate-limited messages message_hwm=25000 # set to report blocks counts in stats (content size / block size) #message_block_size= # max time (milliseconds) for out-of-order messages to wait message_wait=5000 # time (seconds) to cache message ids id_cache_ttl=60 # retry/recover sessions soon after the first subscription to a channel update_on_first_subscription=true # max subscriptions per connection connection_subscription_max=20 # time (seconds) to linger response mode subscriptions subscription_linger=60 # TTL (seconds) for subscription stats stats_subscription_ttl=60 # interval (seconds) to send report stats stats_report_interval=10 # stats output format stats_format=tnetstring pushpin-1.41.0/examples/config/routes000066400000000000000000000000071504671364300176200ustar00rootroot00000000000000* test pushpin-1.41.0/examples/config/runner/000077500000000000000000000000001504671364300176705ustar00rootroot00000000000000pushpin-1.41.0/examples/config/runner/certs/000077500000000000000000000000001504671364300210105ustar00rootroot00000000000000pushpin-1.41.0/examples/config/runner/certs/README000066400000000000000000000000211504671364300216610ustar00rootroot00000000000000Empty directory. 
pushpin-1.41.0/header.APACHE2000066400000000000000000000010611504671364300154470ustar00rootroot00000000000000 * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * pushpin-1.41.0/package.sh000077500000000000000000000012051504671364300152240ustar00rootroot00000000000000#!/bin/sh set -e if [ $# -lt 1 ]; then echo "usage: $0 [version]" exit 1 fi VERSION=$1 DESTDIR=build/pushpin-$VERSION mkdir -p $DESTDIR cp -a .gitignore benches build.rs Cargo.lock Cargo.toml cbindgen.toml CHANGELOG.md examples LICENSE Makefile postbuild SECURITY.md README.md src tools $DESTDIR sed -i.orig -e "s/^version = .*/version = \"$VERSION\"/g" $DESTDIR/Cargo.toml rm $DESTDIR/Cargo.toml.orig cd $DESTDIR mkdir -p .cargo cat >.cargo/config.toml < pushpin.conf.inst pushpin_conf_inst.depends = ../examples/config/pushpin.conf conf.pri QMAKE_EXTRA_TARGETS += pushpin_conf_inst PRE_TARGETDEPS += pushpin.conf.inst # install bin files unix:!isEmpty(BINDIR) { binfiles.path = $$BINDIR binfiles.files = \ $$bin_dir/pushpin-connmgr \ $$bin_dir/m2adapter \ $$bin_dir/pushpin-proxy \ $$bin_dir/pushpin-handler \ $$root_dir/pushpin \ $$bin_dir/pushpin-publish binfiles.CONFIG += no_check_exist executable symlinks.path = $$BINDIR symlinks.extra = ln -sf pushpin-connmgr $(INSTALL_ROOT)$$symlinks.path/pushpin-condure INSTALLS += binfiles symlinks } # install lib files libfiles.path = $$LIBDIR libfiles.files = $$PWD/../src/internal.conf runnerlibfiles.path = $$LIBDIR/runner runnerlibfiles.files = $$PWD/../src/runner/*.template # install config files runnerconfigfiles.path = $$CONFIGDIR/runner runnerconfigfiles.files = $$PWD/../examples/config/runner/certs routes.path = $$CONFIGDIR routes.extra = test -e $(INSTALL_ROOT)$$routes.path/routes || cp -f ../examples/config/routes $(INSTALL_ROOT)$$routes.path/routes pushpinconf.path = $$CONFIGDIR pushpinconf.extra = test -e $(INSTALL_ROOT)$$pushpinconf.path/pushpin.conf || cp -f pushpin.conf.inst $(INSTALL_ROOT)$$pushpinconf.path/pushpin.conf INSTALLS += libfiles runnerlibfiles runnerconfigfiles routes pushpinconf pushpin-1.41.0/src/000077500000000000000000000000001504671364300140635ustar00rootroot00000000000000pushpin-1.41.0/src/bin/000077500000000000000000000000001504671364300146335ustar00rootroot00000000000000pushpin-1.41.0/src/bin/m2adapter.rs000066400000000000000000000016161504671364300170640ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use pushpin::core::call_c_main; use pushpin::import_cpp; use std::env; use std::process::ExitCode; import_cpp! 
{ fn m2adapter_main(argc: libc::c_int, argv: *const *const libc::c_char) -> libc::c_int; } fn main() -> ExitCode { unsafe { ExitCode::from(call_c_main(m2adapter_main, env::args_os())) } } pushpin-1.41.0/src/bin/pushpin-connmgr.rs000066400000000000000000000451131504671364300203340ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use clap::{Arg, ArgAction, Command}; use log::{error, LevelFilter}; use pushpin::connmgr::{run, App, Config, ListenConfig, ListenSpec}; use pushpin::core::log::{get_simple_logger, local_offset_check}; use pushpin::core::version; use std::error::Error; use std::path::PathBuf; use std::process; use std::time::Duration; // safety values const WORKERS_MAX: usize = 1024; const CONNS_MAX: usize = 10_000_000; const PRIVATE_SUBNETS: &[&str] = &[ "127.0.0.0/8", "10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "169.254.0.0/16", "::1/128", "fc00::/7", "fe80::/10", ]; struct Args { id: String, workers: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, blocks_max: usize, connection_blocks_max: usize, messages_max: usize, req_timeout: usize, stream_timeout: usize, listen: Vec, zclient_req_specs: Vec, zclient_stream_specs: Vec, zclient_connect: bool, zserver_req_specs: Vec, zserver_stream_specs: Vec, zserver_connect: bool, ipc_file_mode: u32, tls_identities_dir: String, allow_compression: bool, deny_out_internal: bool, } fn process_args_and_run(args: Args) -> Result<(), Box> { if args.id.is_empty() || args.id.contains(' ') { return Err("failed to parse id: value cannot be empty or contain a space".into()); } if args.workers > WORKERS_MAX { return Err("failed to parse workers: value too large".into()); } if args.req_maxconn + args.stream_maxconn > CONNS_MAX { return Err("total maxconn is too large".into()); } if args.blocks_max < args.stream_maxconn * 2 { return Err("blocks-max is too small".into()); } if args.connection_blocks_max < 2 { return Err("connection-blocks-max is too small".into()); } let mut config = Config { instance_id: args.id, workers: args.workers, req_maxconn: args.req_maxconn, stream_maxconn: args.stream_maxconn, buffer_size: args.buffer_size, body_buffer_size: args.body_buffer_size, blocks_max: args.blocks_max, connection_blocks_max: args.connection_blocks_max, messages_max: args.messages_max, req_timeout: Duration::from_secs(args.req_timeout as u64), stream_timeout: Duration::from_secs(args.stream_timeout as u64), listen: Vec::new(), zclient_req: args.zclient_req_specs, zclient_stream: args.zclient_stream_specs, zclient_connect: args.zclient_connect, zserver_req: args.zserver_req_specs, zserver_stream: args.zserver_stream_specs, zserver_connect: args.zserver_connect, ipc_file_mode: args.ipc_file_mode, certs_dir: PathBuf::from(args.tls_identities_dir), allow_compression: args.allow_compression, deny: Vec::new(), }; for v in args.listen.iter() { let mut parts = v.split(','); // there's always a first part let part1 = 
parts.next().unwrap(); let mut stream = true; let mut tls = false; let mut default_cert = None; let mut local = false; let mut mode = None; let mut user = None; let mut group = None; for part in parts { let (k, v) = match part.find('=') { Some(pos) => (&part[..pos], &part[(pos + 1)..]), None => (part, ""), }; match k { "req" => stream = false, "stream" => stream = true, "tls" => tls = true, "default-cert" => default_cert = Some(String::from(v)), "local" => local = true, "mode" => match u32::from_str_radix(v, 8) { Ok(x) => mode = Some(x), Err(e) => return Err(format!("failed to parse mode: {}", e).into()), }, "user" => user = Some(String::from(v)), "group" => group = Some(String::from(v)), _ => return Err(format!("failed to parse listen: invalid param: {}", part).into()), } } let spec = if local { ListenSpec::Local { path: PathBuf::from(part1), mode, user, group, } } else { let port_pos = match part1.rfind(':') { Some(pos) => pos + 1, None => 0, }; let port = &part1[port_pos..]; if port.parse::().is_err() { return Err(format!("failed to parse listen: invalid port {}", port).into()); } let addr = if port_pos > 0 { String::from(part1) } else { format!("0.0.0.0:{}", part1) }; let addr = match addr.parse() { Ok(addr) => addr, Err(e) => { return Err(format!("failed to parse listen: {}", e).into()); } }; ListenSpec::Tcp { addr, tls, default_cert, } }; config.listen.push(ListenConfig { spec, stream }); } if args.deny_out_internal { for s in PRIVATE_SUBNETS.iter() { config.deny.push(s.parse().unwrap()); } } run(&config) } fn main() { let matches = Command::new("pushpin-connmgr") .version(version()) .about("HTTP/WebSocket connection manager") .arg( Arg::new("log-level") .long("log-level") .num_args(1) .value_name("N") .help("Log level") .default_value("2"), ) .arg( Arg::new("id") .long("id") .num_args(1) .value_name("ID") .help("Instance ID") .default_value("connmgr"), ) .arg( Arg::new("workers") .long("workers") .num_args(1) .value_name("N") .help("Number of worker threads") .default_value("2"), ) .arg( Arg::new("req-maxconn") .long("req-maxconn") .num_args(1) .value_name("N") .help("Maximum number of concurrent connections in req mode") .default_value("100"), ) .arg( Arg::new("stream-maxconn") .long("stream-maxconn") .num_args(1) .value_name("N") .help("Maximum number of concurrent connections in stream mode") .default_value("10000"), ) .arg( Arg::new("buffer-size") .long("buffer-size") .num_args(1) .value_name("N") .help("Connection buffer size (two buffers per connection)") .default_value("8192"), ) .arg( Arg::new("body-buffer-size") .long("body-buffer-size") .num_args(1) .value_name("N") .help("Body buffer size for connections in req mode") .default_value("100000"), ) .arg( Arg::new("blocks-max") .long("blocks-max") .num_args(1) .value_name("N") .help("Maximum number of buffer blocks in stream mode (minimum 2*maxconn)"), ) .arg( Arg::new("connection-blocks-max") .long("connection-blocks-max") .num_args(1) .value_name("N") .help("Maximum number of buffer blocks per connection in stream mode (minimum 2)") .default_value("2"), ) .arg( Arg::new("messages-max") .long("messages-max") .num_args(1) .value_name("N") .help("Maximum number of queued WebSocket messages per connection") .default_value("100"), ) .arg( Arg::new("req-timeout") .long("req-timeout") .num_args(1) .value_name("N") .help("Connection timeout in req mode (seconds)") .default_value("30"), ) .arg( Arg::new("stream-timeout") .long("stream-timeout") .num_args(1) .value_name("N") .help("Connection timeout in stream mode (seconds)") 
.default_value("1800"), ) .arg( Arg::new("listen") .long("listen") .num_args(1) .value_name("[addr:]port[,params...]") .action(ArgAction::Append) .help("Port to listen on"), ) .arg( Arg::new("zclient-req") .long("zclient-req") .num_args(1) .value_name("spec") .action(ArgAction::Append) .help("ZeroMQ client REQ spec") .default_value("ipc://client"), ) .arg( Arg::new("zclient-stream") .long("zclient-stream") .num_args(1) .value_name("spec-base") .action(ArgAction::Append) .help("ZeroMQ client PUSH/ROUTER/SUB spec base") .default_value("ipc://client"), ) .arg( Arg::new("zclient-connect") .long("zclient-connect") .action(ArgAction::SetTrue) .help("ZeroMQ client sockets should connect instead of bind"), ) .arg( Arg::new("zserver-req") .long("zserver-req") .num_args(1) .value_name("spec") .action(ArgAction::Append) .help("ZeroMQ server REQ spec"), ) .arg( Arg::new("zserver-stream") .long("zserver-stream") .num_args(1) .value_name("spec-base") .action(ArgAction::Append) .help("ZeroMQ server PULL/ROUTER/PUB spec base"), ) .arg( Arg::new("zserver-connect") .long("zserver-connect") .action(ArgAction::SetTrue) .help("ZeroMQ server sockets should connect instead of bind"), ) .arg( Arg::new("ipc-file-mode") .long("ipc-file-mode") .num_args(1) .value_name("octal") .help("Permissions for ZeroMQ IPC binds"), ) .arg( Arg::new("tls-identities-dir") .long("tls-identities-dir") .num_args(1) .value_name("directory") .help("Directory containing certificates and private keys") .default_value("."), ) .arg( Arg::new("compression") .long("compression") .action(ArgAction::SetTrue) .help("Allow compression to be used"), ) .arg( Arg::new("deny-out-internal") .long("deny-out-internal") .action(ArgAction::SetTrue) .help("Block outbound connections to local/internal IP address ranges"), ) .arg( Arg::new("sizes") .long("sizes") .action(ArgAction::SetTrue) .help("Prints sizes of tasks and other objects"), ) .get_matches(); log::set_logger(get_simple_logger()).unwrap(); log::set_max_level(LevelFilter::Info); let level = matches.get_one::("log-level").unwrap(); let level: usize = match level.parse() { Ok(x) => x, Err(e) => { error!("failed to parse log-level: {}", e); process::exit(1); } }; let level = match level { 0 => LevelFilter::Error, 1 => LevelFilter::Warn, 2 => LevelFilter::Info, 3 => LevelFilter::Debug, 4..=usize::MAX => LevelFilter::Trace, _ => unreachable!(), }; log::set_max_level(level); local_offset_check(); if *matches.get_one("sizes").unwrap() { for (name, size) in App::sizes() { println!("{}: {} bytes", name, size); } process::exit(0); } let id = matches.get_one::("id").unwrap(); let workers = matches.get_one::("workers").unwrap(); let workers: usize = match workers.parse() { Ok(x) => x, Err(e) => { error!("failed to parse workers: {}", e); process::exit(1); } }; let req_maxconn = matches.get_one::("req-maxconn").unwrap(); let req_maxconn: usize = match req_maxconn.parse() { Ok(x) => x, Err(e) => { error!("failed to parse req-maxconn: {}", e); process::exit(1); } }; let stream_maxconn = matches.get_one::("stream-maxconn").unwrap(); let stream_maxconn: usize = match stream_maxconn.parse() { Ok(x) => x, Err(e) => { error!("failed to parse stream-maxconn: {}", e); process::exit(1); } }; let buffer_size = matches.get_one::("buffer-size").unwrap(); let buffer_size: usize = match buffer_size.parse() { Ok(x) => x, Err(e) => { error!("failed to parse buffer-size: {}", e); process::exit(1); } }; let body_buffer_size = matches.get_one::("body-buffer-size").unwrap(); let body_buffer_size: usize = match 
body_buffer_size.parse() { Ok(x) => x, Err(e) => { error!("failed to parse body-buffer-size: {}", e); process::exit(1); } }; let blocks_max: usize = match matches.get_one::("blocks-max") { Some(v) => match v.parse() { Ok(x) => x, Err(e) => { error!("failed to parse blocks-max: {}", e); process::exit(1); } }, None => stream_maxconn * 2, }; let connection_blocks_max = matches.get_one::("connection-blocks-max").unwrap(); let connection_blocks_max: usize = match connection_blocks_max.parse() { Ok(x) => x, Err(e) => { error!("failed to parse connection-blocks-max: {}", e); process::exit(1); } }; let messages_max = matches.get_one::("messages-max").unwrap(); let messages_max: usize = match messages_max.parse() { Ok(x) => x, Err(e) => { error!("failed to parse messages-max: {}", e); process::exit(1); } }; let req_timeout = matches.get_one::("req-timeout").unwrap(); let req_timeout: usize = match req_timeout.parse() { Ok(x) => x, Err(e) => { error!("failed to parse req-timeout: {}", e); process::exit(1); } }; let stream_timeout = matches.get_one::("stream-timeout").unwrap(); let stream_timeout: usize = match stream_timeout.parse() { Ok(x) => x, Err(e) => { error!("failed to parse stream-timeout: {}", e); process::exit(1); } }; let mut listen: Vec = matches .get_many::("listen") .unwrap_or_default() .map(|v| v.to_owned()) .collect(); let zclient_req_specs: Vec = matches .get_many::("zclient-req") .unwrap() .map(|v| v.to_owned()) .collect(); let zclient_stream_specs: Vec = matches .get_many::("zclient-stream") .unwrap() .map(|v| v.to_owned()) .collect(); let zclient_connect = *matches.get_one("zclient-connect").unwrap(); let zserver_req_specs: Vec = matches .get_many::("zserver-req") .unwrap_or_default() .map(|v| v.to_owned()) .collect(); let zserver_stream_specs: Vec = matches .get_many::("zserver-stream") .unwrap_or_default() .map(|v| v.to_owned()) .collect(); let zserver_connect = *matches.get_one("zserver-connect").unwrap(); let ipc_file_mode = matches .get_one::("ipc-file-mode") .cloned() .unwrap_or_else(|| String::from("0")); let ipc_file_mode = match u32::from_str_radix(&ipc_file_mode, 8) { Ok(x) => x, Err(e) => { error!("failed to parse ipc-file-mode: {}", e); process::exit(1); } }; let tls_identities_dir = matches.get_one::("tls-identities-dir").unwrap(); let allow_compression = *matches.get_one("compression").unwrap(); let deny_out_internal = *matches.get_one("deny-out-internal").unwrap(); // if no zmq server specs are set (needed by client mode), specify // default listen configuration in order to enable server mode. this // means if zmq server specs are set, then server mode won't be enabled // by default if listen.is_empty() && zserver_req_specs.is_empty() && zserver_stream_specs.is_empty() { listen.push("0.0.0.0:8000,stream".to_string()); } let args = Args { id: id.to_string(), workers, req_maxconn, stream_maxconn, buffer_size, body_buffer_size, blocks_max, connection_blocks_max, messages_max, req_timeout, stream_timeout, listen, zclient_req_specs, zclient_stream_specs, zclient_connect, zserver_req_specs, zserver_stream_specs, zserver_connect, ipc_file_mode, tls_identities_dir: tls_identities_dir.to_string(), allow_compression, deny_out_internal, }; if let Err(e) = process_args_and_run(args) { error!("{}", e); process::exit(1); } } pushpin-1.41.0/src/bin/pushpin-handler.rs000066400000000000000000000016121504671364300203020ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use pushpin::core::call_c_main; use pushpin::import_cpp; use std::env; use std::process::ExitCode; import_cpp! { fn handler_main(argc: libc::c_int, argv: *const *const libc::c_char) -> libc::c_int; } fn main() -> ExitCode { unsafe { ExitCode::from(call_c_main(handler_main, env::args_os())) } } pushpin-1.41.0/src/bin/pushpin-proxy.rs000066400000000000000000000016061504671364300200510ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use pushpin::core::call_c_main; use pushpin::import_cpp; use std::env; use std::process::ExitCode; import_cpp! { fn proxy_main(argc: libc::c_int, argv: *const *const libc::c_char) -> libc::c_int; } fn main() -> ExitCode { unsafe { ExitCode::from(call_c_main(proxy_main, env::args_os())) } } pushpin-1.41.0/src/bin/pushpin-publish.rs000066400000000000000000000204371504671364300203410ustar00rootroot00000000000000/* * Copyright (C) 2021-2023 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ use clap::{Arg, ArgAction, Command}; use pushpin::core::version; use pushpin::publish::{run, Action, Config, Content, Message}; use std::env; use std::error::Error; use std::process; const PROGRAM_NAME: &str = "pushpin-publish"; const DEFAULT_SPEC: &str = "http://localhost:5561"; struct Args { channel: String, content: Option, id: String, prev_id: String, sender: String, code: u16, headers: Vec, meta: Vec, hint: bool, close: bool, patch: bool, no_seq: bool, no_eol: bool, spec: String, user: Option, } fn process_args_and_run(args: Args) -> Result<(), Box> { let action = if args.hint { Action::Hint } else if args.close { Action::Close } else { if args.code > 999 { return Err("code must be an integer between 0 and 999".into()); } let content = match args.content { Some(s) => s, None => return Err("must specify content".into()), }; let content = if args.patch { let v: serde_json::Value = serde_json::from_str(&content)?; let arr = match v { serde_json::Value::Array(arr) => arr, _ => return Err("patch content must be a JSON array".into()), }; Content::Patch(arr) } else { Content::Value(content) }; Action::Send(Message { code: args.code, content, }) }; let mut headers = Vec::new(); for v in args.headers { let pos = match v.find(':') { Some(pos) => pos, None => return Err("header must be in the form \"name: value\"".into()), }; let name = &v[..pos]; let val = &v[(pos + 1)..].trim(); headers.push((name.to_string(), val.to_string())); } let mut meta = Vec::new(); for v in args.meta { let pos = match v.find('=') { Some(pos) => pos, None => return Err("meta must be in the form \"name=value\"".into()), }; let name = &v[..pos]; let val = &v[(pos + 1)..].trim(); meta.push((name.to_string(), val.to_string())); } let config = Config { spec: args.spec, basic_auth: args.user, channel: args.channel, id: args.id, prev_id: args.prev_id, sender: args.sender, action, headers, meta, no_seq: args.no_seq, eol: !args.no_eol, }; run(&config) } fn main() { let default_spec = match env::var("GRIP_URL") { Ok(s) => s, Err(_) => DEFAULT_SPEC.to_string(), }; let matches = Command::new(PROGRAM_NAME) .version(version()) .about("Publish messages to Pushpin") .arg( Arg::new("channel") .required(true) .num_args(1) .value_name("channel") .help("Channel to send to"), ) .arg( Arg::new("content") .num_args(1) .value_name("content") .help("Content to use for HTTP body and WebSocket message"), ) .arg( Arg::new("id") .long("id") .num_args(1) .value_name("id") .help("Payload ID"), ) .arg( Arg::new("prev-id") .long("prev-id") .num_args(1) .value_name("id") .help("Previous payload ID"), ) .arg( Arg::new("sender") .long("sender") .num_args(1) .value_name("sender") .help("Sender meta value"), ) .arg( Arg::new("code") .long("code") .num_args(1) .value_name("code") .help("HTTP response code to use") .default_value("200"), ) .arg( Arg::new("header") .short('H') .long("header") .num_args(1) .value_name("\"K: V\"") .action(ArgAction::Append) .help("Add HTTP response header"), ) .arg( Arg::new("meta") .short('M') .long("meta") .num_args(1) .value_name("\"K=V\"") .action(ArgAction::Append) .help("Add meta variable"), ) .arg( Arg::new("hint") .long("hint") .action(ArgAction::SetTrue) .help("Send hint instead of content"), ) .arg( Arg::new("close") .long("close") .action(ArgAction::SetTrue) .help("Close streaming and WebSocket connections"), ) .arg( Arg::new("patch") .long("patch") .action(ArgAction::SetTrue) .help("Content is JSON patch"), ) .arg( Arg::new("no-seq") .long("no-seq") .action(ArgAction::SetTrue) 
.help("Bypass sequencing buffer"), ) .arg( Arg::new("no-eol") .long("no-eol") .action(ArgAction::SetTrue) .help("Don't add newline to HTTP payloads"), ) .arg( Arg::new("spec") .long("spec") .num_args(1) .value_name("spec") .help("GRIP URL or ZeroMQ PUSH spec") .default_value(default_spec), ) .arg( Arg::new("user") .short('u') .long("user") .num_args(1) .value_name("user:pass") .help("Authenticate using basic auth"), ) .get_matches(); let channel = matches.get_one::("channel").unwrap().clone(); let content = matches.get_one::("content").cloned(); let id = matches.get_one::("id").cloned().unwrap_or_default(); let prev_id = matches .get_one::("prev-id") .cloned() .unwrap_or_default(); let sender = matches .get_one::("sender") .cloned() .unwrap_or_default(); let code = matches.get_one::("code").unwrap(); let code: u16 = match code.parse() { Ok(x) => x, Err(e) => { eprintln!("Error: failed to parse code: {}", e); process::exit(1); } }; let headers = matches .get_many::("header") .unwrap_or_default() .map(|v| v.to_owned()) .collect(); let meta = matches .get_many::("meta") .unwrap_or_default() .map(|v| v.to_owned()) .collect(); let hint = *matches.get_one("hint").unwrap(); let close = *matches.get_one("close").unwrap(); let patch = *matches.get_one("patch").unwrap(); let no_seq = *matches.get_one("no-seq").unwrap(); let no_eol = *matches.get_one("no-eol").unwrap(); let spec = matches.get_one::("spec").unwrap().clone(); let user = matches.get_one::("user").cloned(); let args = Args { channel, content, id, prev_id, sender, code, headers, meta, hint, close, patch, no_seq, no_eol, spec, user, }; if let Err(e) = process_args_and_run(args) { eprintln!("Error: {}", e); process::exit(1); } } pushpin-1.41.0/src/bin/pushpin.rs000066400000000000000000000016101504671364300166650ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use pushpin::core::call_c_main; use pushpin::import_cpp; use std::env; use std::process::ExitCode; import_cpp! { fn runner_main(argc: libc::c_int, argv: *const *const libc::c_char) -> libc::c_int; } fn main() -> ExitCode { unsafe { ExitCode::from(call_c_main(runner_main, env::args_os())) } } pushpin-1.41.0/src/connmgr/000077500000000000000000000000001504671364300155265ustar00rootroot00000000000000pushpin-1.41.0/src/connmgr/batch.rs000066400000000000000000000232561504671364300171650ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::connmgr::zhttppacket; use crate::connmgr::zhttpsocket::FROM_MAX; use crate::core::arena; use crate::core::list; use arrayvec::ArrayVec; use slab::Slab; use std::convert::TryFrom; pub struct BatchKey { addr_index: usize, nkey: usize, } pub struct BatchGroup<'a, 'b> { addr: &'b [u8], use_router: bool, ids: arena::ReusableVecHandle<'b, zhttppacket::Id<'a>>, } impl<'a> BatchGroup<'a, '_> { pub fn addr(&self) -> &[u8] { self.addr } pub fn use_router(&self) -> bool { self.use_router } pub fn ids(&self) -> &[zhttppacket::Id<'a>] { &self.ids } } struct AddrItem { addr: ArrayVec, use_router: bool, keys: list::List, } pub struct Batch { nodes: Slab>, addrs: Vec, addr_index: usize, group_ids: arena::ReusableVec, last_group_ckeys: Vec, } impl Batch { pub fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), addrs: Vec::with_capacity(capacity), addr_index: 0, group_ids: arena::ReusableVec::new::(capacity), last_group_ckeys: Vec::with_capacity(capacity), } } pub fn len(&self) -> usize { self.nodes.len() } pub fn capacity(&self) -> usize { self.nodes.capacity() } pub fn is_empty(&self) -> bool { self.nodes.is_empty() } pub fn clear(&mut self) { self.addrs.clear(); self.nodes.clear(); self.addr_index = 0; } pub fn add(&mut self, to_addr: &[u8], use_router: bool, ckey: usize) -> Result { if self.nodes.len() == self.nodes.capacity() { return Err(()); } // if all existing nodes have been removed via remove() or take_group(), // such that is_empty() returns true, start clean if self.nodes.is_empty() { self.addrs.clear(); self.addr_index = 0; } let mut pos = self.addrs.len(); for (n, ai) in self.addrs.iter().enumerate() { if ai.addr.as_slice() == to_addr && ai.use_router == use_router { pos = n; } } if pos == self.addrs.len() { if self.addrs.len() == self.addrs.capacity() { return Err(()); } // connection limits to_addr to FROM_MAX so this is guaranteed to succeed let addr = ArrayVec::try_from(to_addr).unwrap(); self.addrs.push(AddrItem { addr, use_router, keys: list::List::default(), }); } else { // adding not allowed if take_group() has already moved past the index if pos < self.addr_index { return Err(()); } } let nkey = self.nodes.insert(list::Node::new(ckey)); self.addrs[pos].keys.push_back(&mut self.nodes, nkey); Ok(BatchKey { addr_index: pos, nkey, }) } pub fn remove(&mut self, key: BatchKey) { self.addrs[key.addr_index] .keys .remove(&mut self.nodes, key.nkey); self.nodes.remove(key.nkey); } pub fn take_group<'a, 'b: 'a, F>(&'a mut self, get_id: F) -> Option> where F: Fn(usize) -> Option<(&'b [u8], u32)>, { let addrs = &mut self.addrs; let mut ids = self.group_ids.get_as_new(); while ids.is_empty() { // find the next addr with items while self.addr_index < addrs.len() && addrs[self.addr_index].keys.is_empty() { self.addr_index += 1; } // if all are empty, we're done if self.addr_index == addrs.len() { assert!(self.nodes.is_empty()); return None; } let keys = &mut addrs[self.addr_index].keys; self.last_group_ckeys.clear(); ids.clear(); // get ids/seqs while ids.len() < zhttppacket::IDS_MAX { let nkey = match keys.pop_front(&mut self.nodes) { Some(nkey) => nkey, None => break, }; let ckey = self.nodes[nkey].value; self.nodes.remove(nkey); if let Some((id, seq)) = get_id(ckey) { self.last_group_ckeys.push(ckey); ids.push(zhttppacket::Id { id, seq: Some(seq) }); } } } let ai = &addrs[self.addr_index]; Some(BatchGroup { addr: &ai.addr, use_router: ai.use_router, ids, }) } pub fn last_group_ckeys(&self) -> &[usize] { &self.last_group_ckeys } } #[cfg(test)] mod tests { 
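    // The cases below exercise the Batch API end to end: keys added with the
    // same address and router flag are grouped together by take_group(), a key
    // removed before take_group() is never returned, and keys whose id callback
    // returns None are omitted from the group.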
use super::*; #[test] fn add_take() { let ids = ["id-1", "id-2", "id-3", "id-4"]; let mut batch = Batch::new(4); assert_eq!(batch.capacity(), 4); assert_eq!(batch.len(), 0); assert!(batch.last_group_ckeys().is_empty()); assert!(batch.add(b"addr-a", false, 1).is_ok()); assert!(batch.add(b"addr-a", false, 2).is_ok()); assert!(batch.add(b"addr-b", false, 3).is_ok()); assert!(batch.add(b"addr-b", true, 4).is_ok()); assert_eq!(batch.len(), 4); assert!(batch.add(b"addr-c", false, 5).is_err()); assert_eq!(batch.len(), 4); assert_eq!(batch.is_empty(), false); let group = batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .unwrap(); assert_eq!(group.ids().len(), 2); assert_eq!(group.ids()[0].id, b"id-1"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.ids()[1].id, b"id-2"); assert_eq!(group.ids()[1].seq, Some(0)); assert_eq!(group.addr(), b"addr-a"); assert!(!group.use_router()); drop(group); assert_eq!(batch.is_empty(), false); assert_eq!(batch.last_group_ckeys(), &[1, 2]); let group = batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-3"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); assert!(!group.use_router()); drop(group); assert_eq!(batch.is_empty(), false); assert_eq!(batch.last_group_ckeys(), &[3]); let group = batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-4"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); assert!(group.use_router()); drop(group); assert_eq!(batch.is_empty(), true); assert_eq!(batch.last_group_ckeys(), &[4]); assert!(batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .is_none()); assert_eq!(batch.last_group_ckeys(), &[4]); } #[test] fn add_remove_take() { let ids = ["id-1", "id-2", "id-3"]; let mut batch = Batch::new(3); let bkey = batch.add(b"addr-a", false, 1).unwrap(); assert!(batch.add(b"addr-b", false, 2).is_ok()); assert_eq!(batch.len(), 2); batch.remove(bkey); assert_eq!(batch.len(), 1); let group = batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-2"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); drop(group); assert_eq!(batch.is_empty(), true); assert!(batch.add(b"addr-a", false, 3).is_ok()); assert_eq!(batch.len(), 1); assert!(!batch.is_empty()); let group = batch .take_group(|ckey| Some((ids[ckey - 1].as_bytes(), 0))) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-3"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-a"); drop(group); assert_eq!(batch.is_empty(), true); } #[test] fn add_take_omit() { let ids = ["id-1", "id-2", "id-3"]; let mut batch = Batch::new(3); assert!(batch.add(b"addr-a", false, 1).is_ok()); assert!(batch.add(b"addr-b", false, 2).is_ok()); assert!(batch.add(b"addr-b", false, 3).is_ok()); let group = batch .take_group(|ckey| { if ckey < 3 { None } else { Some((ids[ckey - 1].as_bytes(), 0)) } }) .unwrap(); assert_eq!(group.ids().len(), 1); assert_eq!(group.ids()[0].id, b"id-3"); assert_eq!(group.ids()[0].seq, Some(0)); assert_eq!(group.addr(), b"addr-b"); drop(group); assert_eq!(batch.is_empty(), true); } } pushpin-1.41.0/src/connmgr/client.rs000066400000000000000000002757231504671364300173720ustar00rootroot00000000000000/* * Copyright (C) 2023 Fanout, Inc. 
* Copyright (C) 2023-2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::connmgr::batch::{Batch, BatchKey}; use crate::connmgr::connection::{ client_req_connection, client_stream_connection, make_zhttp_response, ConnectionPool, StreamSharedData, }; use crate::connmgr::counter::Counter; use crate::connmgr::resolver::Resolver; use crate::connmgr::tls::TlsConfigCache; use crate::connmgr::zhttppacket; use crate::connmgr::zhttpsocket::{self, SessionKey, FROM_MAX, REQ_ID_MAX}; use crate::core::arena; use crate::core::buffer::TmpBuffer; use crate::core::channel::{self, AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver}; use crate::core::event; use crate::core::executor::{Executor, Spawner}; use crate::core::list; use crate::core::reactor::Reactor; use crate::core::select::{select_2, select_5, select_6, select_option, Select2, Select5, Select6}; use crate::core::task::{self, yield_to_local_events, CancellationSender, CancellationToken}; use crate::core::time::Timeout; use crate::core::tnetstring; use crate::core::zmq::{MultipartHeader, SpecInfo}; use arrayvec::ArrayVec; use ipnet::IpNet; use log::{debug, error, warn}; use mio::unix::SourceFd; use slab::Slab; use std::cell::Cell; use std::cell::RefCell; use std::collections::{HashMap, VecDeque}; use std::convert::TryFrom; use std::io::{self, Write}; use std::mem; use std::pin::pin; use std::rc::Rc; use std::str; use std::sync::{mpsc, Arc}; use std::thread; use std::time::Duration; const REQ_SENDER_BOUND: usize = 1; // we read and process each request message one at a time, wrapping it in an // rc, and sending it to connections via channels. on the other side of each // channel, the message is received and processed immediately, except for the // first message. 
this means the max number of messages retained per // connection is the channel bound per connection plus one pub const MSG_RETAINED_PER_CONNECTION_MAX: usize = REQ_SENDER_BOUND + 1; // the max number of messages retained outside of connections is one per // handle we read from (req and stream), in preparation for sending to any // connections pub const MSG_RETAINED_PER_WORKER_MAX: usize = 2; // run x1 // req_handle_task x1 // stream_handle_task x1 // keep_alives_task x1 const WORKER_NON_CONNECTION_TASKS_MAX: usize = 10; // this is meant to be an average max of registrations per task, in order // to determine the total number of registrations sufficient for all tasks, // however it is not enforced per task const REGISTRATIONS_PER_TASK_MAX: usize = 32; const REACTOR_BUDGET: u32 = 100; const KEEP_ALIVE_TIMEOUT_MS: usize = 45_000; const KEEP_ALIVE_BATCH_MS: usize = 100; const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(KEEP_ALIVE_BATCH_MS as u64); const KEEP_ALIVE_BATCHES: usize = KEEP_ALIVE_TIMEOUT_MS / KEEP_ALIVE_BATCH_MS; const BULK_PACKET_SIZE_MAX: usize = 65_000; const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000); const RESOLVER_THREADS: usize = 10; fn local_channel( bound: usize, max_senders: usize, ) -> (channel::LocalSender, channel::LocalReceiver) { let (s, r) = channel::local_channel( bound, max_senders, &Reactor::current().unwrap().local_registration_memory(), ); (s, r) } fn async_local_channel( bound: usize, max_senders: usize, ) -> (AsyncLocalSender, AsyncLocalReceiver) { let (s, r) = local_channel(bound, max_senders); let s = AsyncLocalSender::new(s); let r = AsyncLocalReceiver::new(r); (s, r) } enum BatchType { KeepAlive, Cancel, } struct ChannelPool { items: RefCell, channel::LocalReceiver)>>, } impl ChannelPool { fn new(capacity: usize) -> Self { Self { items: RefCell::new(VecDeque::with_capacity(capacity)), } } fn take(&self) -> Option<(channel::LocalSender, channel::LocalReceiver)> { let p = &mut *self.items.borrow_mut(); p.pop_back() } fn push(&self, pair: (channel::LocalSender, channel::LocalReceiver)) { let p = &mut *self.items.borrow_mut(); p.push_back(pair); } } struct ConnectionDone { ckey: usize, } struct ConnectionItem { id: Option, stop: Option, zreceiver_sender: Option, usize)>>, shared: Option>, batch_key: Option, } struct ConnectionItems { nodes: Slab>, nodes_by_id: HashMap, batch: Batch, } impl ConnectionItems { fn new(capacity: usize, batch: Batch) -> Self { Self { nodes: Slab::with_capacity(capacity), nodes_by_id: HashMap::with_capacity(capacity), batch, } } } struct ConnectionsInner { active: list::List, count: usize, max: usize, } struct Connections { items: Rc>, inner: RefCell, } impl Connections { fn new(items: Rc>, max: usize) -> Self { Self { items, inner: RefCell::new(ConnectionsInner { active: list::List::default(), count: 0, max, }), } } fn count(&self) -> usize { self.inner.borrow().count } fn max(&self) -> usize { self.inner.borrow().max } fn add( &self, stop: CancellationSender, zreceiver_sender: Option< channel::LocalSender<(arena::Rc, usize)>, >, shared: Option>, ) -> Result { let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); if items.nodes.len() == items.nodes.capacity() { return Err(()); } let nkey = items.nodes.insert(list::Node::new(ConnectionItem { id: None, stop: Some(stop), zreceiver_sender, shared, batch_key: None, })); c.active.push_back(&mut items.nodes, nkey); c.count += 1; Ok(nkey) } // return zreceiver_sender fn remove( &self, ckey: usize, ) -> Option, usize)>> { let nkey = 
ckey; let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); let ci = &mut items.nodes[nkey].value; // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } c.active.remove(&mut items.nodes, nkey); c.count -= 1; let ci = items.nodes.remove(nkey).value; if let Some(id) = &ci.id { items.nodes_by_id.remove(id); } ci.zreceiver_sender } fn set_id(&self, ckey: usize, id: Option<&SessionKey>) { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[nkey].value; // unset current id, if any if let Some(cur_id) = &ci.id { items.nodes_by_id.remove(cur_id); ci.id = None; } if let Some(id) = id.cloned() { ci.id = Some(id.clone()); items.nodes_by_id.insert(id, nkey); } else { // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } } } fn find_key(&self, id: &SessionKey) -> Option { let items = &*self.items.borrow(); items.nodes_by_id.get(id).copied() } fn try_send( &self, ckey: usize, value: (arena::Rc, usize), ) -> Result<(), mpsc::TrySendError<(arena::Rc, usize)>> { let nkey = ckey; let items = &*self.items.borrow(); let ci = &items.nodes[nkey].value; let sender = match &ci.zreceiver_sender { Some(s) => s, None => return Err(mpsc::TrySendError::Disconnected(value)), }; sender.try_send(value) } fn stop_all(&self, about_to_stop: F) where F: Fn(usize), { let items = &mut *self.items.borrow_mut(); let cinner = &*self.inner.borrow_mut(); let mut next = cinner.active.head; while let Some(nkey) = next { let n = &mut items.nodes[nkey]; let ci = &mut n.value; about_to_stop(nkey); ci.stop = None; next = n.next; } } fn items_capacity(&self) -> usize { self.items.borrow().nodes.capacity() } fn can_stream(&self, ckey: usize) -> bool { let items = &*self.items.borrow(); match items.nodes.get(ckey) { Some(n) => { let ci = &n.value; // is stream mode with an id ci.shared.is_some() && ci.id.is_some() } None => false, } } fn batch_is_empty(&self) -> bool { let items = &*self.items.borrow(); items.batch.is_empty() } fn batch_len(&self) -> usize { let items = &*self.items.borrow(); items.batch.len() } fn batch_capacity(&self) -> usize { let items = &*self.items.borrow(); items.batch.capacity() } fn batch_clear(&self) { let items = &mut *self.items.borrow_mut(); items.batch.clear(); } fn batch_add(&self, ckey: usize) -> Result<(), ()> { let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); // only batch connections with known handler addresses let addr_ref = cshared.to_addr(); let addr = match addr_ref.get() { Some(addr) => addr, None => return Err(()), }; let bkey = items.batch.add(addr, cshared.router_resp(), ckey)?; ci.batch_key = Some(bkey); Ok(()) } fn next_batch_message( &self, from: &str, btype: BatchType, ) -> Option<(usize, Option>, zmq::Message)> { let items = &mut *self.items.borrow_mut(); let nodes = &mut items.nodes; let batch = &mut items.batch; while !batch.is_empty() { let group = { let group = batch.take_group(|ckey| { let ci = &nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); // addr could have been removed after adding to the batch cshared.to_addr().get()?; // item is guaranteed to have an id. 
only items with an // id are added to a batch, and if an item's id is // removed then the item is removed from the batch let id = ci.id.as_ref().unwrap(); Some((&id.1, cshared.out_seq())) }); match group { Some(group) => group, None => continue, } }; let count = group.ids().len(); assert!(count <= zhttppacket::IDS_MAX); let zresp = zhttppacket::Response { from: from.as_bytes(), ids: group.ids(), multi: true, ptype: match btype { BatchType::KeepAlive => zhttppacket::ResponsePacket::KeepAlive, BatchType::Cancel => zhttppacket::ResponsePacket::Cancel, }, ptype_str: "", }; let mut scratch = [0; BULK_PACKET_SIZE_MAX]; let (addr, msg) = match make_zhttp_response(group.addr(), group.use_router(), zresp, &mut scratch) { Ok(resp) => resp, Err(e) => { error!( "failed to serialize keep-alive packet with {} ids: {}", count, e ); continue; } }; drop(group); for &ckey in batch.last_group_ckeys() { let ci = &mut nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); cshared.inc_out_seq(); ci.batch_key = None; } return Some((count, addr, msg)); } None } } #[derive(Clone)] struct ConnectionOpts { instance_id: Rc, buffer_size: usize, timeout: Duration, rb_tmp: Rc, packet_buf: Rc>>, tmp_buf: Rc>>, } struct ConnectionReqOpts { body_buffer_size: usize, sender: channel::LocalSender<(MultipartHeader, zmq::Message)>, } struct ConnectionStreamOpts { blocks_max: usize, blocks_avail: Arc, messages_max: usize, allow_compression: bool, sender: channel::LocalSender<(Option>, zmq::Message)>, } struct Worker { thread: Option>, stop: Option>, } impl Worker { #[allow(clippy::too_many_arguments)] fn new( instance_id: &str, id: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, connection_blocks_max: usize, blocks_avail: &Arc, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, deny: &[IpNet], resolver: &Arc, tls_config_cache: &Arc, pool: &Arc, zsockman: &Arc, handle_bound: usize, ) -> Self { debug!("client worker {}: starting", id); let (stop, r_stop) = channel::channel(1); let (s_ready, ready) = channel::channel(1); let instance_id = String::from(instance_id); let blocks_avail = Arc::clone(blocks_avail); let deny = deny.to_vec(); let resolver = Arc::clone(resolver); let tls_config_cache = Arc::clone(tls_config_cache); let pool = Arc::clone(pool); let zsockman = Arc::clone(zsockman); let thread = thread::Builder::new() .name(format!("client-worker-{}", id)) .spawn(move || { let maxconn = req_maxconn + stream_maxconn; // 1 task per connection, plus a handful of supporting tasks let tasks_max = maxconn + WORKER_NON_CONNECTION_TASKS_MAX; let registrations_max = REGISTRATIONS_PER_TASK_MAX * tasks_max; let reactor = Reactor::new(registrations_max); let executor = Executor::new(tasks_max); { let reactor = reactor.clone(); executor.set_pre_poll(move || { reactor.set_budget(Some(REACTOR_BUDGET)); }); } executor .spawn(Self::run( r_stop, s_ready, instance_id, id, req_maxconn, stream_maxconn, buffer_size, body_buffer_size, connection_blocks_max, blocks_avail, messages_max, req_timeout, stream_timeout, allow_compression, deny, resolver, tls_config_cache, pool, zsockman, handle_bound, )) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); debug!("client worker {}: stopped", id); }) .unwrap(); ready.recv().unwrap(); Self { thread: Some(thread), stop: Some(stop), } } fn stop(&mut self) { self.stop = None; } #[allow(clippy::too_many_arguments)] async fn run( stop: channel::Receiver<()>, ready: channel::Sender<()>, 
instance_id: String, id: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, connection_blocks_max: usize, blocks_avail: Arc, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, deny: Vec, resolver: Arc, tls_config_cache: Arc, pool: Arc, zsockman: Arc, handle_bound: usize, ) { let executor = Executor::current().unwrap(); let reactor = Reactor::current().unwrap(); let stop = AsyncReceiver::new(stop); debug!("client-worker {}: allocating buffers", id); let rb_tmp = Rc::new(TmpBuffer::new(buffer_size * connection_blocks_max)); // large enough to fit anything let packet_buf = Rc::new(RefCell::new(vec![0; buffer_size + body_buffer_size + 4096])); // same size as working buffers let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size])); let instance_id = Rc::new(instance_id); let ka_batch = stream_maxconn.div_ceil(KEEP_ALIVE_BATCHES); let batch = Batch::new(ka_batch); let maxconn = req_maxconn + stream_maxconn; let conn_items = Rc::new(RefCell::new(ConnectionItems::new(maxconn, batch))); let req_conns = Rc::new(Connections::new(conn_items.clone(), req_maxconn)); let stream_conns = Rc::new(Connections::new(conn_items.clone(), stream_maxconn)); let (req_handle_stop, r_req_handle_stop) = async_local_channel(1, 1); let (stream_handle_stop, r_stream_handle_stop) = async_local_channel(1, 1); let (keep_alives_stop, r_keep_alives_stop) = async_local_channel(1, 1); let (s_req_handle_done, req_handle_done) = async_local_channel(1, 1); let (s_stream_handle_done, stream_handle_done) = async_local_channel(1, 1); let (s_keep_alives_done, keep_alives_done) = async_local_channel(1, 1); // max_senders is 1 per connection + 1 for the handle task + 1 for the keep alive task let (zstream_out_sender, zstream_out_receiver) = local_channel(handle_bound, stream_maxconn + 2); let zstream_out_receiver = AsyncLocalReceiver::new(zstream_out_receiver); let req_handle = zhttpsocket::AsyncServerReqHandle::new(zsockman.server_req_handle()); let stream_handle = zhttpsocket::AsyncServerStreamHandle::new(zsockman.server_stream_handle()); let deny = Rc::new(deny); executor .spawn(Self::req_handle_task( id, r_req_handle_stop, s_req_handle_done, executor.spawner(), Arc::clone(&resolver), Arc::clone(&tls_config_cache), Arc::clone(&pool), req_handle, req_maxconn, req_conns, body_buffer_size, Rc::clone(&deny), handle_bound, ConnectionOpts { instance_id: instance_id.clone(), buffer_size, timeout: req_timeout, rb_tmp: rb_tmp.clone(), packet_buf: packet_buf.clone(), tmp_buf: tmp_buf.clone(), }, )) .unwrap(); { let zstream_out_sender = zstream_out_sender .try_clone(&reactor.local_registration_memory()) .unwrap(); executor .spawn(Self::stream_handle_task( id, r_stream_handle_stop, s_stream_handle_done, zstream_out_receiver, zstream_out_sender, executor.spawner(), Arc::clone(&resolver), Arc::clone(&tls_config_cache), Arc::clone(&pool), stream_handle, stream_maxconn, stream_conns.clone(), connection_blocks_max, blocks_avail, messages_max, allow_compression, Rc::clone(&deny), ConnectionOpts { instance_id: instance_id.clone(), buffer_size, timeout: stream_timeout, rb_tmp: rb_tmp.clone(), packet_buf: packet_buf.clone(), tmp_buf: tmp_buf.clone(), }, )) .unwrap(); } executor .spawn(Self::keep_alives_task( id, r_keep_alives_stop, s_keep_alives_done, instance_id.clone(), zstream_out_sender, stream_conns.clone(), )) .unwrap(); debug!("client-worker {}: started", id); ready.send(()).unwrap(); drop(ready); // wait for stop let _ = stop.recv().await; // stop keep 
alives drop(keep_alives_stop); let _ = keep_alives_done.recv().await; // stop remaining tasks drop(req_handle_stop); drop(stream_handle_stop); let _ = req_handle_done.recv().await; let stream_handle = stream_handle_done.recv().await.unwrap(); // send cancels stream_conns.batch_clear(); let now = reactor.now(); let shutdown_timeout = Timeout::new(now + SHUTDOWN_TIMEOUT); let mut next_cancel_index = 0; 'outer: while next_cancel_index < stream_conns.items_capacity() { while stream_conns.batch_len() < stream_conns.batch_capacity() && next_cancel_index < stream_conns.items_capacity() { let key = next_cancel_index; next_cancel_index += 1; if stream_conns.can_stream(key) { // ignore errors let _ = stream_conns.batch_add(key); } } while let Some((count, addr, msg)) = stream_conns.next_batch_message(&instance_id, BatchType::Cancel) { debug!( "client-worker {}: sending cancels for {} sessions", id, count ); match select_2( pin!(stream_handle.send(addr, msg)), shutdown_timeout.elapsed(), ) .await { Select2::R1(r) => r.unwrap(), Select2::R2(_) => break 'outer, } } } } #[allow(clippy::too_many_arguments)] async fn req_handle_task( id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, spawner: Spawner, resolver: Arc, tls_config_cache: Arc, conn_pool: Arc, req_handle: zhttpsocket::AsyncServerReqHandle, req_maxconn: usize, conns: Rc, body_buffer_size: usize, deny: Rc>, handle_bound: usize, opts: ConnectionOpts, ) { let reactor = Reactor::current().unwrap(); let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * req_maxconn); let req_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let req_req_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); // max_senders is 1 per connection + 1 for this task let (zreq_sender, zreq_receiver) = local_channel(handle_bound, req_maxconn + 1); let zreq_receiver = AsyncLocalReceiver::new(zreq_receiver); // bound is 1 per connection, so all connections can indicate done at once // max_senders is 1 per connection + 1 for this task let (s_cdone, r_cdone) = channel::local_channel::( conns.max(), conns.max() + 1, &reactor.local_registration_memory(), ); let r_cdone = AsyncLocalReceiver::new(r_cdone); debug!("client-worker {}: task started: req_handle", id); let mut handle_send = pin!(None); loop { let receiver_recv = if handle_send.is_none() { Some(zreq_receiver.recv()) } else { None }; let req_handle_recv = if conns.count() < conns.max() { Some(req_handle.recv()) } else { None }; match select_5( stop.recv(), select_option(receiver_recv), select_option(handle_send.as_mut().as_pin_mut()), r_cdone.recv(), select_option(pin!(req_handle_recv).as_pin_mut()), ) .await { // stop.recv Select5::R1(_) => break, // receiver_recv Select5::R2(result) => match result { Ok((header, msg)) => handle_send.set(Some(req_handle.send(header, msg))), Err(e) => panic!("zreq_receiver channel error: {}", e), }, // handle_send Select5::R3(result) => { handle_send.set(None); if let Err(e) = result { error!("req send error: {}", e); } } // r_cdone.recv Select5::R4(result) => match result { Ok(done) => { let ret = conns.remove(done.ckey); // req mode doesn't have a sender assert!(ret.is_none()); } Err(e) => panic!("r_cdone channel error: {}", e), }, // req_handle_recv Select5::R5(result) => match result { Ok((header, msg)) => { let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) { Ok(zreq) => zreq, Err(e) => { warn!("client-worker {}: zhttp parse 
error: {}", id, e); continue; } }; let zreq_ref = zreq.get(); let ids = zreq_ref.ids; if ids.len() > 1 { warn!( "client-worker {}: request contained more than one id, skipping", id ); continue; } let from: ArrayVec = match ArrayVec::try_from(zreq_ref.from) { Ok(v) => v, Err(_) => { warn!("client-worker {}: from address too long, skipping", id); continue; } }; let cid: Option> = if !ids.is_empty() { match ArrayVec::try_from(ids[0].id) { Ok(v) => Some(v), Err(_) => { warn!("client-worker {}: request id too long, skipping", id); continue; } } } else { None }; let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap(); let (cstop, r_cstop) = CancellationToken::new(&reactor.local_registration_memory()); let s_cdone = s_cdone .try_clone(&reactor.local_registration_memory()) .unwrap(); let zreq_sender = zreq_sender .try_clone(&reactor.local_registration_memory()) .unwrap(); let ckey = conns.add(cstop, None, None).unwrap(); if let Some(cid) = &cid { let cid = (from, cid.clone()); conns.set_id(ckey, Some(&cid)); } debug!( "client-worker {}: req conn starting {} {}/{}", id, ckey, conns.count(), conns.max(), ); if spawner .spawn(Self::req_connection_task( r_cstop, s_cdone, id, ckey, cid, (header, zreq), Arc::clone(&resolver), Arc::clone(&tls_config_cache), Arc::clone(&conn_pool), Rc::clone(&deny), opts.clone(), ConnectionReqOpts { body_buffer_size, sender: zreq_sender, }, )) .is_err() { // this should never happen. we only read a message // if we know we can spawn panic!("failed to spawn req_connection_task"); } } Err(e) => panic!("client-worker {}: handle read error {}", id, e), }, } } drop(s_cdone); conns.stop_all(|ckey| debug!("client-worker {}: stopping {}", id, ckey)); while r_cdone.recv().await.is_ok() {} debug!("client-worker {}: task stopped: req_handle", id); } #[allow(clippy::too_many_arguments)] async fn stream_handle_task( id: usize, stop: AsyncLocalReceiver<()>, done: AsyncLocalSender, zstream_out_receiver: AsyncLocalReceiver<(Option>, zmq::Message)>, zstream_out_sender: channel::LocalSender<(Option>, zmq::Message)>, spawner: Spawner, resolver: Arc, tls_config_cache: Arc, conn_pool: Arc, stream_handle: zhttpsocket::AsyncServerStreamHandle, stream_maxconn: usize, conns: Rc, connection_blocks_max: usize, blocks_avail: Arc, messages_max: usize, allow_compression: bool, deny: Rc>, opts: ConnectionOpts, ) { let reactor = Reactor::current().unwrap(); let stream_shared_mem = Rc::new(arena::RcMemory::new(stream_maxconn)); let zreceiver_pool = Rc::new(ChannelPool::new(stream_maxconn)); for _ in 0..stream_maxconn { zreceiver_pool.push(local_channel(REQ_SENDER_BOUND, 1)); } let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * stream_maxconn); let stream_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let stream_req_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); // bound is 1 per connection, so all connections can indicate done at once // max_senders is 1 per connection + 1 for this task let (s_cdone, r_cdone) = channel::local_channel::( conns.max(), conns.max() + 1, &reactor.local_registration_memory(), ); let r_cdone = AsyncLocalReceiver::new(r_cdone); let resume_waker = task::create_resume_waker(); debug!("client-worker {}: task started: stream_handle", id); { let mut handle_send = pin!(None); loop { let receiver_recv = if handle_send.is_none() { Some(zstream_out_receiver.recv()) } else { None }; let stream_handle_recv_from_any = if conns.count() < conns.max() { Some(stream_handle.recv_from_any()) } else { None }; match select_6( stop.recv(), 
select_option(receiver_recv), select_option(handle_send.as_mut().as_pin_mut()), r_cdone.recv(), select_option(pin!(stream_handle_recv_from_any).as_pin_mut()), pin!(stream_handle.recv_directed()), ) .await { // stop.recv Select6::R1(_) => break, // receiver_recv Select6::R2(result) => match result { Ok((addr, msg)) => handle_send.set(Some(stream_handle.send(addr, msg))), Err(e) => panic!("zstream_out_receiver channel error: {}", e), }, // handle_send Select6::R3(result) => { handle_send.set(None); if let Err(e) = result { error!("stream send error: {}", e); } } // r_cdone.recv Select6::R4(result) => match result { Ok(done) => { let zreceiver_sender = conns.remove(done.ckey).unwrap(); let zreceiver = zreceiver_sender .make_receiver(&reactor.local_registration_memory()) .unwrap(); zreceiver.clear(); zreceiver_pool.push((zreceiver_sender, zreceiver)); } Err(e) => panic!("r_cdone channel error: {}", e), }, // stream_handle_recv_from_any Select6::R5(result) => match result { Ok(ret) => { let (msg, session) = ret; let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &stream_scratch_mem, ) .unwrap(); let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) { Ok(zreq) => zreq, Err(e) => { warn!("client-worker {}: zhttp parse error: {}", id, e); continue; } }; let zreq_ref = zreq.get(); let ids = zreq_ref.ids; if ids.len() != 1 { warn!("client-worker {}: packet did not contain exactly one id, skipping", id); continue; } if ids[0].seq != Some(0) { warn!("client-worker {}: received message with seq != 0 as first message, skipping", id); continue; } if !zreq_ref.ptype_str.is_empty() { warn!("client-worker {}: received non-data message as first message, skipping", id); continue; } if zreq_ref.from.len() > FROM_MAX { warn!("client-worker {}: from address too long, skipping", id); continue; } let cid: ArrayVec = match ArrayVec::try_from(ids[0].id) { Ok(v) => v, Err(_) => { warn!("client-worker {}: request id too long, skipping", id); continue; } }; let zreq = arena::Rc::new(zreq, &stream_req_mem).unwrap(); let (cstop, r_cstop) = CancellationToken::new(&reactor.local_registration_memory()); let s_cdone = s_cdone .try_clone(&reactor.local_registration_memory()) .unwrap(); let zstream_out_sender = zstream_out_sender .try_clone(&reactor.local_registration_memory()) .unwrap(); let (zstream_receiver_sender, zstream_receiver) = zreceiver_pool.take().unwrap(); let shared = arena::Rc::new(StreamSharedData::new(), &stream_shared_mem) .unwrap(); let ckey = conns .add( cstop, Some(zstream_receiver_sender), Some(arena::Rc::clone(&shared)), ) .unwrap(); debug!( "client-worker {}: stream conn starting {} {}/{}", id, ckey, conns.count(), conns.max(), ); if spawner .spawn(Self::stream_connection_task( r_cstop, s_cdone, id, ckey, cid, arena::Rc::clone(&zreq), Arc::clone(&resolver), Arc::clone(&tls_config_cache), Arc::clone(&conn_pool), zstream_receiver, Rc::clone(&deny), Rc::clone(&conns), opts.clone(), ConnectionStreamOpts { blocks_max: connection_blocks_max, blocks_avail: Arc::clone(&blocks_avail), messages_max, allow_compression, sender: zstream_out_sender, }, shared, Some(session), )) .is_err() { // this should never happen. 
we only read a message // if we know we can spawn panic!("failed to spawn stream_connection_task"); } } Err(e) => panic!("client-worker {}: handle read error {}", id, e), }, // stream_handle.recv_directed Select6::R6(result) => match result { Ok(msg) => { let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &stream_scratch_mem, ) .unwrap(); let zreq = match zhttppacket::OwnedRequest::parse(msg, 0, scratch) { Ok(zreq) => zreq, Err(e) => { warn!("client-worker {}: zhttp parse error: {}", id, e); continue; } }; let zreq = arena::Rc::new(zreq, &stream_req_mem).unwrap(); let zreq_ref = zreq.get().get(); let ids = zreq_ref.ids; if ids.is_empty() { warn!("client-worker {}: packet contained no ids, skipping", id); continue; } let from: ArrayVec = match ArrayVec::try_from(zreq_ref.from) { Ok(v) => v, Err(_) => { warn!( "client-worker {}: from address too long, skipping", id ); continue; } }; let mut count = 0; for (i, rid) in ids.iter().enumerate() { let cid: ArrayVec = match ArrayVec::try_from(rid.id) { Ok(v) => v, Err(_) => { warn!( "client-worker {}: request id too long, skipping", id ); continue; } }; let cid = (from.clone(), cid); let key = match conns.find_key(&cid) { Some(key) => key, None => continue, }; // this should always succeed, since afterwards we yield // to let the connection receive the message match conns.try_send(key, (arena::Rc::clone(&zreq), i)) { Ok(()) => count += 1, Err(mpsc::TrySendError::Full(_)) => error!( "client-worker {}: connection-{} cannot receive message", id, key ), Err(mpsc::TrySendError::Disconnected(_)) => {} // conn task ended } } debug!( "client-worker {}: queued zmq message for {} conns", id, count ); if count > 0 { yield_to_local_events(&resume_waker).await; } } Err(e) => panic!("client-worker {}: handle read error {}", id, e), }, } } } drop(s_cdone); conns.stop_all(|ckey| debug!("client-worker {}: stopping {}", id, ckey)); while r_cdone.recv().await.is_ok() {} // give the handle back done.send(stream_handle).await.unwrap(); debug!("client-worker {}: task stopped: stream_handle", id); } #[allow(clippy::too_many_arguments)] async fn req_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: Option>, zreq: (MultipartHeader, arena::Rc), resolver: Arc, tls_config_cache: Arc, pool: Arc, deny: Rc>, opts: ConnectionOpts, req_opts: ConnectionReqOpts, ) { let done = AsyncLocalSender::new(done); debug!( "client-worker {}: task started: connection-{}", worker_id, ckey ); let log_id = if let Some(cid) = &cid { // zhttp ids are pretty much always valid strings, but we'll // do a lossy conversion just in case let cid_str = String::from_utf8_lossy(cid); format!("{}-{}-{}", worker_id, ckey, cid_str) } else { format!("{}-{}", worker_id, ckey) }; client_req_connection( token, &log_id, cid.as_deref(), zreq, opts.buffer_size, req_opts.body_buffer_size, &opts.rb_tmp, opts.packet_buf, opts.timeout, &deny, &resolver, &tls_config_cache, &pool, AsyncLocalSender::new(req_opts.sender), ) .await; done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "client-worker {}: task stopped: connection-{}", worker_id, ckey ); } #[allow(clippy::too_many_arguments)] async fn stream_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: ArrayVec, zreq: arena::Rc, resolver: Arc, tls_config_cache: Arc, pool: Arc, zreceiver: channel::LocalReceiver<(arena::Rc, usize)>, deny: Rc>, conns: Rc, opts: ConnectionOpts, stream_opts: ConnectionStreamOpts, shared: arena::Rc, 
session: Option, ) { let done = AsyncLocalSender::new(done); let zreceiver = AsyncLocalReceiver::new(zreceiver); debug!( "client-worker {}: task started: connection-{}", worker_id, ckey ); let log_id = { // zhttp ids are pretty much always valid strings, but we'll // do a lossy conversion just in case let cid_str = String::from_utf8_lossy(&cid); format!("{}-{}-{}", worker_id, ckey, cid_str) }; client_stream_connection( token, &log_id, &cid, arena::Rc::clone(&zreq), opts.buffer_size, stream_opts.blocks_max, &stream_opts.blocks_avail, stream_opts.messages_max, &opts.rb_tmp, opts.packet_buf, opts.tmp_buf, opts.timeout, stream_opts.allow_compression, &deny, &opts.instance_id, &resolver, &tls_config_cache, &pool, zreceiver, AsyncLocalSender::new(stream_opts.sender), shared, &|| { // handle task limits addr to FROM_MAX so this is guaranteed to succeed let from: ArrayVec = ArrayVec::try_from(zreq.get().get().from).unwrap(); let cid = (from, cid.clone()); conns.set_id(ckey, Some(&cid)) }, ) .await; drop(session); done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "client-worker {}: task stopped: connection-{}", worker_id, ckey ); } async fn keep_alives_task( id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, instance_id: Rc, sender: channel::LocalSender<(Option>, zmq::Message)>, conns: Rc, ) { debug!("client-worker {}: task started: keep_alives", id); let reactor = Reactor::current().unwrap(); let mut keep_alive_count = 0; let mut next_keep_alive_time = reactor.now() + KEEP_ALIVE_INTERVAL; let next_keep_alive_timeout = Timeout::new(next_keep_alive_time); let mut next_keep_alive_index = 0; let sender = AsyncLocalSender::new(sender); 'main: loop { // wait for next keep alive time match select_2(stop.recv(), next_keep_alive_timeout.elapsed()).await { Select2::R1(_) => break, Select2::R2(_) => {} } for _ in 0..conns.batch_capacity() { if next_keep_alive_index >= conns.items_capacity() { break; } let key = next_keep_alive_index; next_keep_alive_index += 1; if conns.can_stream(key) { // ignore errors let _ = conns.batch_add(key); } } keep_alive_count += 1; if keep_alive_count >= KEEP_ALIVE_BATCHES { keep_alive_count = 0; next_keep_alive_index = 0; } // keep steady pace next_keep_alive_time += KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); while !conns.batch_is_empty() { let send = match select_2(stop.recv(), sender.wait_sendable()).await { Select2::R1(_) => break 'main, Select2::R2(send) => send, }; // there could be no message if items removed or message construction failed let (count, addr, msg) = match conns.next_batch_message(&instance_id, BatchType::KeepAlive) { Some(ret) => ret, None => continue, }; debug!( "client-worker {}: sending keep alives for {} sessions", id, count ); if let Err(e) = send.try_send((addr, msg)) { error!("zhttp write error: {}", e); } } let now = reactor.now(); if now >= next_keep_alive_time + KEEP_ALIVE_INTERVAL { // got really behind somehow. 
just skip ahead next_keep_alive_time = now + KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); } } debug!("client-worker {}: task stopped: keep_alives", id); } } impl Drop for Worker { fn drop(&mut self) { self.stop(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } pub struct Client { workers: Vec, } impl Client { #[allow(clippy::too_many_arguments)] pub fn new( instance_id: &str, worker_count: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, blocks_max: usize, connection_blocks_max: usize, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, deny: &[IpNet], zsockman: Arc, handle_bound: usize, ) -> Result { assert!(blocks_max >= stream_maxconn * 2); // 1 active query per connection let queries_max = req_maxconn + stream_maxconn; let resolver = Arc::new(Resolver::new(RESOLVER_THREADS, queries_max)); let tls_config_cache = Arc::new(TlsConfigCache::new()); let pool_max = if event::can_move_mio_sockets_between_threads() { (req_maxconn + stream_maxconn) / 10 } else { // disable persistent connections 0 }; let pool = Arc::new(ConnectionPool::new(pool_max)); if !deny.is_empty() { debug!("default policy: block outgoing connections to {:?}", deny); } let blocks_avail = Arc::new(Counter::new(blocks_max - (stream_maxconn * 2))); let mut workers = Vec::new(); for i in 0..worker_count { let w = Worker::new( instance_id, i, req_maxconn / worker_count, stream_maxconn / worker_count, buffer_size, body_buffer_size, connection_blocks_max, &blocks_avail, messages_max, req_timeout, stream_timeout, allow_compression, deny, &resolver, &tls_config_cache, &pool, &zsockman, handle_bound, ); workers.push(w); } Ok(Self { workers }) } pub fn task_sizes() -> Vec<(String, usize)> { let req_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (sender, _) = local_channel(1, 1); let req_scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_req_mem = Rc::new(arena::RcMemory::new(1)); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let msg = concat!( "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri", ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1", "0:text/plain,]]4:body,5:hello,4:more,4:true!}", ); let msg = arena::Arc::new(zmq::Message::from(msg.as_bytes()), &msg_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap(); let resolver = Arc::new(Resolver::new(1, 1)); let tls_config_cache = Arc::new(TlsConfigCache::new()); let pool = Arc::new(ConnectionPool::new(0)); let fut = Worker::req_connection_task( stop, done, 0, 0, None, (MultipartHeader::new(), zreq), resolver, tls_config_cache, pool, Rc::new(Vec::new()), ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionReqOpts { body_buffer_size: 0, sender, }, ); mem::size_of_val(&fut) }; let stream_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (_, zreceiver) = local_channel(1, 1); let (sender, _) = 
local_channel(1, 1); let batch = Batch::new(1); let conn_items = Rc::new(RefCell::new(ConnectionItems::new(1, batch))); let conns = Rc::new(Connections::new(conn_items, 1)); let req_scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_req_mem = Rc::new(arena::RcMemory::new(1)); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let msg = concat!( "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri", ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1", "0:text/plain,]]4:body,5:hello,4:more,4:true!}", ); let msg = arena::Arc::new(zmq::Message::from(msg.as_bytes()), &msg_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_req_mem).unwrap(); let resolver = Arc::new(Resolver::new(1, 1)); let tls_config_cache = Arc::new(TlsConfigCache::new()); let pool = Arc::new(ConnectionPool::new(0)); let stream_shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &stream_shared_mem).unwrap(); let fut = Worker::stream_connection_task( stop, done, 0, 0, ArrayVec::new(), zreq, resolver, tls_config_cache, pool, zreceiver, Rc::new(Vec::new()), conns, ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionStreamOpts { blocks_max: 2, blocks_avail: Arc::new(Counter::new(0)), messages_max: 0, allow_compression: false, sender, }, shared, None, ); mem::size_of_val(&fut) }; vec![ ("client_req_connection_task".to_string(), req_task_size), ( "client_stream_connection_task".to_string(), stream_task_size, ), ] } } impl Drop for Client { fn drop(&mut self) { for w in self.workers.iter_mut() { w.stop(); } } } #[derive(Debug, Eq, PartialEq)] enum StatusMessage { Started, ReqFinished, StreamFinished, } enum ControlMessage { Stop, Req(zmq::Message), Stream(zmq::Message), } pub struct TestClient { _client: Client, thread: Option>, status: channel::Receiver, control: channel::Sender, next_id: Cell, } impl TestClient { pub fn new(workers: usize) -> Self { let zmq_context = Arc::new(zmq::Context::new()); let req_maxconn = 100; let stream_maxconn = 100; let maxconn = req_maxconn + stream_maxconn; let mut zsockman = zhttpsocket::ServerSocketManager::new( Arc::clone(&zmq_context), "test", (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers), 100, 100, 100, stream_maxconn, ); zsockman .set_server_req_specs(&[SpecInfo { spec: String::from("inproc://client-test"), bind: true, ipc_file_mode: 0, }]) .unwrap(); let zsockman = Arc::new(zsockman); let client = Client::new( "test", workers, req_maxconn, stream_maxconn, 1024, 1024, stream_maxconn * 2, 2, 10, Duration::from_secs(5), Duration::from_secs(5), false, &[], zsockman.clone(), 100, ) .unwrap(); zsockman .set_server_stream_specs( &[SpecInfo { spec: String::from("inproc://client-test-out"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://client-test-out-stream"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://client-test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let (status_s, status_r) = channel::channel(1000); let (control_s, control_r) = channel::channel(1000); let thread = thread::Builder::new() .name("test-client".to_string()) .spawn(move || { 
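// run the handler-side emulation on a dedicated thread: it creates the
// test zmq sockets and answers every req/stream message, so the client
// under test always has a responsive peer to talk to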
Self::run(status_s, control_r, zmq_context); }) .unwrap(); // wait for handler thread to start assert_eq!(status_r.recv().unwrap(), StatusMessage::Started); Self { _client: client, thread: Some(thread), status: status_r, control: control_s, next_id: Cell::new(0), } } pub fn do_req(&self, addr: std::net::SocketAddr) { let msg = self.make_req_message(addr).unwrap(); self.control.send(ControlMessage::Req(msg)).unwrap(); } pub fn do_stream_http(&self, addr: std::net::SocketAddr) { let msg = self.make_stream_message(addr, false, false).unwrap(); self.control.send(ControlMessage::Stream(msg)).unwrap(); } pub fn do_stream_http_router_resp(&self, addr: std::net::SocketAddr) { let msg = self.make_stream_message(addr, true, false).unwrap(); self.control.send(ControlMessage::Stream(msg)).unwrap(); } pub fn do_stream_ws(&self, addr: std::net::SocketAddr) { let msg = self.make_stream_message(addr, false, true).unwrap(); self.control.send(ControlMessage::Stream(msg)).unwrap(); } pub fn wait_req(&self) { assert_eq!(self.status.recv().unwrap(), StatusMessage::ReqFinished); } pub fn wait_stream(&self) { assert_eq!(self.status.recv().unwrap(), StatusMessage::StreamFinished); } fn make_req_message(&self, addr: std::net::SocketAddr) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; let mut tmp = [0u8; 1024]; let id = { let id = self.next_id.get(); self.next_id.set(id + 1); let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "{}", id)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"method")?; w.write_string(b"GET")?; let mut tmp = [0u8; 1024]; let uri = { let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "http://{}/path", addr)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"uri")?; w.write_string(uri)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn make_stream_message( &self, addr: std::net::SocketAddr, router_resp: bool, ws: bool, ) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; let mut tmp = [0u8; 1024]; let id = { let id = self.next_id.get(); self.next_id.set(id + 1); let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "{}", id)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(0)?; let mut tmp = [0u8; 1024]; let uri = if ws { let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "ws://{}/path", addr)?; let pos = cursor.position() as usize; &tmp[..pos] } else { w.write_string(b"method")?; w.write_string(b"GET")?; let mut cursor = io::Cursor::new(&mut tmp[..]); write!(&mut cursor, "http://{}/path", addr)?; let pos = cursor.position() as usize; &tmp[..pos] }; w.write_string(b"uri")?; w.write_string(uri)?; w.write_string(b"credits")?; w.write_int(1024)?; if router_resp { w.write_string(b"router-resp")?; w.write_bool(true)?; } w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn respond_msg( id: &[u8], seq: u32, ptype: &str, content_type: &str, body: &[u8], code: Option, ) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut 
dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(seq as isize)?; if ptype.is_empty() { w.write_string(b"content-type")?; w.write_string(content_type.as_bytes())?; } else { w.write_string(b"type")?; w.write_string(ptype.as_bytes())?; } if let Some(x) = code { w.write_string(b"code")?; w.write_int(x as isize)?; } w.write_string(b"body")?; w.write_string(body)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn run( status: channel::Sender, control: channel::Receiver, zmq_context: Arc, ) { let req_sock = zmq_context.socket(zmq::DEALER).unwrap(); req_sock.connect("inproc://client-test").unwrap(); let out_sock = zmq_context.socket(zmq::PUSH).unwrap(); out_sock.connect("inproc://client-test-out").unwrap(); let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); out_stream_sock.set_identity(b"handler").unwrap(); out_stream_sock .connect("inproc://client-test-out-stream") .unwrap(); let in_sock = zmq_context.socket(zmq::SUB).unwrap(); in_sock.set_subscribe(b"handler ").unwrap(); in_sock.connect("inproc://client-test-in").unwrap(); // ensure zsockman is subscribed thread::sleep(Duration::from_millis(100)); status.send(StatusMessage::Started).unwrap(); let mut poller = event::Poller::new(1).unwrap(); poller .register_custom( control.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&req_sock.get_fd().unwrap()), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&out_stream_sock.get_fd().unwrap()), mio::Token(3), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_sock.get_fd().unwrap()), mio::Token(4), mio::Interest::READABLE, ) .unwrap(); let mut req_events = req_sock.get_events().unwrap(); let mut out_stream_events = out_stream_sock.get_events().unwrap(); let mut in_events = in_sock.get_events().unwrap(); 'main: loop { while req_events.contains(zmq::POLLIN) { let parts = match req_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { req_events = req_sock.get_events().unwrap(); break; } Err(e) => panic!("recv error: {:?}", e), }; req_events = req_sock.get_events().unwrap(); assert_eq!(parts.len(), 2); let msg = &parts[1]; assert_eq!(msg[0], b'T'); let mut ptype = ""; let mut code: u16 = 0; let mut reason = ""; let mut body = b"".as_slice(); for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "type" => { let s = tnetstring::parse_string(f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "code" => { let x = tnetstring::parse_int(f.data).unwrap(); code = x as u16; } "reason" => { let s = tnetstring::parse_string(f.data).unwrap(); reason = str::from_utf8(s).unwrap(); } "body" => { let s = tnetstring::parse_string(f.data).unwrap(); body = s; } _ => {} } } debug!("received req message"); assert_eq!(ptype, ""); assert_eq!(code, 200); assert_eq!(reason, "OK"); assert_eq!(str::from_utf8(body).unwrap(), "hello\n"); status.send(StatusMessage::ReqFinished).unwrap(); } while out_stream_events.contains(zmq::POLLIN) || in_events.contains(zmq::POLLIN) { let mut msg_and_pos = None; if out_stream_events.contains(zmq::POLLIN) { match out_stream_sock.recv_multipart(zmq::DONTWAIT) { Ok(mut parts) => { out_stream_events = out_stream_sock.get_events().unwrap(); assert_eq!(parts.len(), 
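// the ROUTER socket prepends the sending peer's identity, so a stream
// message arrives as [peer identity, empty delimiter, zhttp payload];
// only the payload frame (index 2) is kept and parsed below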
3); msg_and_pos = Some((parts.remove(2), 0)); } Err(zmq::Error::EAGAIN) => { out_stream_events = out_stream_sock.get_events().unwrap(); } Err(e) => panic!("recv error: {:?}", e), } } if msg_and_pos.is_none() && in_events.contains(zmq::POLLIN) { match in_sock.recv_multipart(zmq::DONTWAIT) { Ok(mut parts) => { in_events = in_sock.get_events().unwrap(); assert_eq!(parts.len(), 1); let buf = &parts[0]; let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } msg_and_pos = Some((parts.remove(0), pos.unwrap() + 1)); } Err(zmq::Error::EAGAIN) => { in_events = in_sock.get_events().unwrap(); } Err(e) => panic!("recv error: {:?}", e), }; } let (msg, from_router) = match &msg_and_pos { Some((msg, pos)) => (&msg[*pos..], *pos == 0), None => break, }; assert_eq!(msg[0], b'T'); let mut id = ""; let mut seq = None; let mut ptype = ""; let mut code = None; let mut reason = ""; let mut content_type = ""; let mut body = &b""[..]; let mut more = false; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "seq" => { let x = tnetstring::parse_int(f.data).unwrap(); seq = Some(x as u32); } "type" => { let s = tnetstring::parse_string(f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "code" => { let x = tnetstring::parse_int(f.data).unwrap(); code = Some(x as u16); } "reason" => { let s = tnetstring::parse_string(f.data).unwrap(); reason = str::from_utf8(s).unwrap(); } "content-type" => { let s = tnetstring::parse_string(f.data).unwrap(); content_type = str::from_utf8(s).unwrap(); } "body" => { let s = tnetstring::parse_string(f.data).unwrap(); body = s; } "more" => { let b = tnetstring::parse_bool(f.data).unwrap(); more = b; } _ => {} } } let seq = seq.unwrap(); debug!( "received stream message from_router={} id={} seq={}", from_router, id, seq ); let out_seq = seq + 1; // as a hack to make the test server stateless, respond to every message // using the received sequence number. 
for messages we don't care about, // respond with keep-alive in order to keep the sequencing going if ptype.is_empty() || ptype == "ping" || ptype == "pong" || ptype == "close" { if ptype.is_empty() && content_type.is_empty() { // assume http/ws accept, or http body if !reason.is_empty() { // http/ws accept let code = code.unwrap(); assert!(code == 200 || code == 101); if code == 200 { assert_eq!(reason, "OK"); assert_eq!(body.len(), 0); assert!(more); } else { // 101 assert_eq!(reason, "Switching Protocols"); assert_eq!(body.len(), 0); assert!(!more); } let msg = Self::respond_msg( id.as_bytes(), out_seq, "keep-alive", "", b"", None, ) .unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); out_stream_events = out_stream_sock.get_events().unwrap(); } else { // http body assert_eq!(str::from_utf8(body).unwrap(), "hello\n"); assert!(!more); status.send(StatusMessage::StreamFinished).unwrap(); } } else { // assume ws message if ptype == "ping" { ptype = "pong"; } // echo let msg = Self::respond_msg( id.as_bytes(), out_seq, ptype, content_type, body, code, ) .unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); out_stream_events = out_stream_sock.get_events().unwrap(); if ptype == "close" { status.send(StatusMessage::StreamFinished).unwrap(); } } } else { let msg = Self::respond_msg(id.as_bytes(), out_seq, "keep-alive", "", b"", None) .unwrap(); out_stream_sock .send_multipart( [ zmq::Message::from(b"test".as_slice()), zmq::Message::new(), msg, ], 0, ) .unwrap(); out_stream_events = out_stream_sock.get_events().unwrap(); } } poller.poll(None).unwrap(); for event in poller.iter_events() { match event.token() { mio::Token(1) => { while let Ok(msg) = control.try_recv() { match msg { ControlMessage::Stop => break 'main, ControlMessage::Req(msg) => { req_sock .send_multipart([zmq::Message::new(), msg], 0) .unwrap(); req_events = req_sock.get_events().unwrap(); } ControlMessage::Stream(msg) => out_sock.send(msg, 0).unwrap(), } } } mio::Token(2) => req_events = req_sock.get_events().unwrap(), mio::Token(3) => out_stream_events = out_stream_sock.get_events().unwrap(), mio::Token(4) => in_events = in_sock.get_events().unwrap(), _ => unreachable!(), } } } } } impl Drop for TestClient { fn drop(&mut self) { self.control.try_send(ControlMessage::Stop).unwrap(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } #[cfg(test)] pub mod tests { use super::*; use crate::connmgr::connection::calculate_ws_accept; use crate::connmgr::websocket; use std::io::Read; use test_log::test; fn recv_frame( stream: &mut R, buf: &mut Vec, ) -> Result<(bool, u8, Vec), io::Error> { loop { let fi = match websocket::read_header(buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk)?; if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => return Err(e), }; while buf.len() < fi.payload_offset + fi.payload_size { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk)?; if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } buf.extend_from_slice(&chunk[..size]); } let mut content = Vec::from(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); if let Some(mask) = fi.mask { websocket::apply_mask(&mut content, mask, 0); } *buf = buf.split_off(fi.payload_offset + 
fi.payload_size); return Ok((fi.fin, fi.opcode, content)); } } #[test] fn test_client() { let client = TestClient::new(1); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); // req client.do_req(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ) .unwrap(); drop(stream); client.wait_req(); // stream (http) client.do_stream_http(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ) .unwrap(); drop(stream); client.wait_stream(); // stream (http) with responses via router client.do_stream_http_router_resp(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let expected = format!( concat!("GET /path HTTP/1.1\r\n", "Host: {}\r\n", "\r\n"), addr ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); stream .write( b"HTTP/1.1 200 OK\r\nContent-Type: text/plain\r\nContent-Length: 6\r\n\r\nhello\n", ) .unwrap(); drop(stream); client.wait_stream(); // stream (ws) client.do_stream_ws(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let req_buf = &buf[..req_end]; // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; 32]; let mut req = httparse::Request::new(&mut headers); match req.parse(req_buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: {}\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "\r\n" ), addr, ws_key, ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); buf = buf.split_off(req_end); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 
Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); stream.write(resp_data.as_bytes()).unwrap(); // send message let mut data = vec![0; 1024]; let body = &b"hello"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv message let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_TEXT); assert_eq!(str::from_utf8(&content).unwrap(), "hello"); } #[test] fn test_ws() { let client = TestClient::new(1); let listener = std::net::TcpListener::bind("127.0.0.1:0").unwrap(); let addr = listener.local_addr().unwrap(); client.do_stream_ws(addr); let (mut stream, _) = listener.accept().unwrap(); let mut buf = Vec::new(); let mut req_end = 0; while req_end == 0 { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { req_end = i + 4; break; } } } let req_buf = &buf[..req_end]; // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; 32]; let mut req = httparse::Request::new(&mut headers); match req.parse(req_buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: {}\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "\r\n" ), addr, ws_key, ); assert_eq!(str::from_utf8(&buf[..req_end]).unwrap(), expected); buf = buf.split_off(req_end); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); stream.write(resp_data.as_bytes()).unwrap(); // send binary let mut data = vec![0; 1024]; let body = &[1, 2, 3][..]; let size = websocket::write_header( true, false, websocket::OPCODE_BINARY, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv binary let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_BINARY); assert_eq!(content, &[1, 2, 3][..]); buf.clear(); // send ping let mut data = vec![0; 1024]; let body = &b""[..]; let size = websocket::write_header( true, false, websocket::OPCODE_PING, body.len(), None, &mut data, ) .unwrap(); stream.write(&data[..size]).unwrap(); // recv pong let (fin, opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_PONG); assert_eq!(str::from_utf8(&content).unwrap(), ""); buf.clear(); // send close let mut data = vec![0; 1024]; let body = &b"\x03\xf0gone"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_CLOSE, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); stream.write(&data[..(size + body.len())]).unwrap(); // recv close let (fin, 
opcode, content) = recv_frame(&mut stream, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_CLOSE); assert_eq!(&content, &b"\x03\xf0gone"[..]); // expect tcp close let mut chunk = [0; 1024]; let size = stream.read(&mut chunk).unwrap(); assert_eq!(size, 0); client.wait_stream(); } #[cfg(target_arch = "x86_64")] #[cfg(debug_assertions)] #[test] fn test_task_sizes() { // sizes in debug mode at commit c0e4d161997e5c2880ba3409efe13afa3ec26fd7 const REQ_TASK_SIZE_BASE: usize = 6888; const STREAM_TASK_SIZE_BASE: usize = 12152; // cause tests to fail if sizes grow too much const GROWTH_LIMIT: usize = 1000; const REQ_TASK_SIZE_MAX: usize = REQ_TASK_SIZE_BASE + GROWTH_LIMIT; const STREAM_TASK_SIZE_MAX: usize = STREAM_TASK_SIZE_BASE + GROWTH_LIMIT; let sizes = Client::task_sizes(); assert_eq!(sizes[0].0, "client_req_connection_task"); assert!(sizes[0].1 <= REQ_TASK_SIZE_MAX); assert_eq!(sizes[1].0, "client_stream_connection_task"); assert!(sizes[1].1 <= STREAM_TASK_SIZE_MAX); } } pushpin-1.41.0/src/connmgr/connection.rs000066400000000000000000011775141504671364300202530ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ // Note: Always Be Receiving (ABR) // // Connection handlers are expected to read ZHTTP messages as fast as // possible. If they don't, the whole thread could stall. This is by design, // to limit the number of to-be-processed messages in memory. They either // need to do something immediately with the messages, or discard them. // // Every await point must ensure messages keep getting read/processed, by // doing one of: // // - Directly awaiting a message. // - Awaiting a select that is awaiting a message. // - Wrapping other activity with discard_while(). // - Calling handle_other(), which itself will read messages. // - Awaiting something known to not block. 
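//
// As a concrete illustration of the "wrapping other activity" bullet above,
// here is a minimal sketch of the idea. The names `abr_discard_while` and
// `zreceiver` are placeholders for this sketch only and are not this
// module's actual discard_while() signature; the real helper may do more
// than silently drop messages. The point is that the wrapped future is
// raced against the message receiver, so incoming ZHTTP messages keep
// getting drained (and here discarded) instead of piling up while the
// other activity is pending:
//
//     async fn abr_discard_while<F, T>(
//         zreceiver: &AsyncLocalReceiver<T>,
//         fut: F,
//     ) -> F::Output
//     where
//         F: Future,
//     {
//         let mut fut = pin!(fut);
//
//         loop {
//             match select_2(fut.as_mut(), zreceiver.recv()).await {
//                 // the wrapped activity completed
//                 Select2::R1(output) => break output,
//                 // a message arrived in the meantime; drop it so the
//                 // thread never stalls behind unread messages
//                 Select2::R2(_msg) => {}
//             }
//         }
//     }
//
// Discarding is acceptable here because, per the note above, the goal is to
// bound the number of to-be-processed messages held in memory at any time.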
#![allow(clippy::collapsible_if)] #![allow(clippy::collapsible_else_if)] use crate::connmgr::counter::{Counter, CounterDec}; use crate::connmgr::pool::Pool; use crate::connmgr::resolver; use crate::connmgr::tls::{AsyncTlsStream, TlsConfigCache, TlsStream, TlsWaker, VerifyMode}; use crate::connmgr::track::{ self, track_future, Track, TrackFlag, TrackedAsyncLocalReceiver, ValueActiveError, }; use crate::connmgr::websocket; use crate::connmgr::zhttppacket; use crate::core::arena; use crate::core::buffer::{ Buffer, ContiguousBuffer, LimitBufsMut, TmpBuffer, VecRingBuffer, VECTORED_MAX, }; use crate::core::channel::{AsyncLocalReceiver, AsyncLocalSender}; use crate::core::defer::Defer; use crate::core::http1::Error as CoreHttpError; use crate::core::http1::{self, client, server, RecvStatus, SendStatus}; use crate::core::io::{ io_split, AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt, ReadHalf, StdWriteWrapper, WriteHalf, }; use crate::core::net::{AsyncTcpStream, SocketAddr}; use crate::core::reactor::Reactor; use crate::core::select::{select_2, select_3, select_4, select_option, Select2, Select3, Select4}; use crate::core::shuffle::random; use crate::core::task::{poll_async, CancellationToken}; use crate::core::time::Timeout; use crate::core::waker::RefWakerData; use crate::core::zmq::MultipartHeader; use arrayvec::{ArrayString, ArrayVec}; use ipnet::IpNet; use log::{debug, log, warn, Level}; use sha1::{Digest, Sha1}; use std::cell::{Ref, RefCell}; use std::cmp; use std::collections::VecDeque; use std::convert::TryFrom; use std::future::Future; use std::io::{self, Read, Write}; use std::mem; use std::net::IpAddr; use std::pin::pin; use std::pin::Pin; use std::rc::Rc; use std::str; use std::str::FromStr; use std::sync::{mpsc, Arc, Mutex}; use std::task::Context; use std::task::Poll; use std::thread; use std::time::{Duration, Instant}; const URI_SIZE_MAX: usize = 4096; const HEADERS_MAX: usize = 64; const WS_HASH_INPUT_MAX: usize = 256; const WS_KEY_MAX: usize = 24; // base64_encode([16 bytes]) = 24 bytes const WS_ACCEPT_MAX: usize = 28; // base64_encode(sha1_hash) = 28 bytes const REDIRECTS_MAX: usize = 8; const ZHTTP_SESSION_TIMEOUT: Duration = Duration::from_secs(60); const CONNECTION_POOL_TTL: Duration = Duration::from_secs(55); pub trait CidProvider { fn get_new_assigned_cid(&mut self) -> ArrayString<32>; } pub trait Identify { fn set_id(&mut self, id: &str); } #[derive(PartialEq)] enum Mode { HttpReq, HttpStream, WebSocket, } fn get_host<'a>(headers: &'a [httparse::Header]) -> &'a str { for h in headers.iter() { if h.name.eq_ignore_ascii_case("Host") { match str::from_utf8(h.value) { Ok(s) => return s, Err(_) => break, } } } "localhost" } fn gen_ws_key() -> ArrayString { let mut nonce = [0; 16]; for b in nonce.iter_mut() { *b = (random() % 256) as u8; } let mut output = [0; WS_KEY_MAX]; let size = base64::encode_config_slice(nonce, base64::STANDARD, &mut output); let output = str::from_utf8(&output[..size]).unwrap(); ArrayString::from_str(output).unwrap() } #[allow(clippy::result_unit_err)] pub fn calculate_ws_accept(key: &[u8]) -> Result, ()> { let input_len = key.len() + websocket::WS_GUID.len(); if input_len > WS_HASH_INPUT_MAX { return Err(()); } let mut input = [0; WS_HASH_INPUT_MAX]; input[..key.len()].copy_from_slice(key); input[key.len()..input_len].copy_from_slice(websocket::WS_GUID.as_bytes()); let input = &input[..input_len]; let mut hasher = Sha1::new(); hasher.update(input); let digest = hasher.finalize(); let mut output = [0; WS_ACCEPT_MAX]; let size = 
base64::encode_config_slice(digest, base64::STANDARD, &mut output); let output = match str::from_utf8(&output[..size]) { Ok(s) => s, Err(_) => return Err(()), }; Ok(ArrayString::from_str(output).unwrap()) } fn validate_ws_request( req: &http1::Request, ws_version: Option<&[u8]>, ws_key: Option<&[u8]>, ) -> Result, ()> { // a websocket request must not have a body. // some clients send "Content-Length: 0", which we'll allow. // chunked encoding will be rejected. if req.method == "GET" && (req.body_size == http1::BodySize::NoBody || req.body_size == http1::BodySize::Known(0)) && ws_version == Some(b"13") { if let Some(ws_key) = ws_key { return calculate_ws_accept(ws_key); } } Err(()) } fn validate_ws_response(ws_key: &[u8], ws_accept: Option<&[u8]>) -> Result<(), ()> { if let Some(ws_accept) = ws_accept { if calculate_ws_accept(ws_key)?.as_bytes() == ws_accept { return Ok(()); } } Err(()) } fn gen_mask() -> [u8; 4] { let mut out = [0; 4]; for b in out.iter_mut() { *b = (random() % 256) as u8; } out } fn write_ws_ext_header_value( config: &websocket::PerMessageDeflateConfig, dest: &mut W, ) -> Result<(), io::Error> { write!(dest, "permessage-deflate")?; config.serialize(dest) } #[allow(clippy::too_many_arguments)] fn make_zhttp_request( instance: &str, ids: &[zhttppacket::Id], method: &str, path: &str, headers: &[httparse::Header], body: &[u8], more: bool, mode: Mode, credits: u32, peer_addr: Option<&SocketAddr>, secure: bool, packet_buf: &mut [u8], ) -> Result { let mut data = zhttppacket::RequestData::new(); data.method = method; let host = get_host(headers); let mut zheaders = [zhttppacket::EMPTY_HEADER; HEADERS_MAX]; let mut zheaders_len = 0; for h in headers.iter() { zheaders[zheaders_len] = zhttppacket::Header { name: h.name, value: h.value, }; zheaders_len += 1; } data.headers = &zheaders[..zheaders_len]; let scheme = match mode { Mode::HttpReq | Mode::HttpStream => { if secure { "https" } else { "http" } } Mode::WebSocket => { if secure { "wss" } else { "ws" } } }; let mut uri = [0; URI_SIZE_MAX]; let mut c = io::Cursor::new(&mut uri[..]); write!(&mut c, "{}://{}{}", scheme, host, path)?; let size = c.position() as usize; data.uri = match str::from_utf8(&uri[..size]) { Ok(s) => s, Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidData)), }; data.body = body; data.more = more; if mode == Mode::HttpStream { data.stream = true; } if mode == Mode::HttpStream || mode == Mode::WebSocket { data.router_resp = true; } data.credits = credits; let mut addr = [0; 128]; if let Some(SocketAddr::Ip(peer_addr)) = peer_addr { let mut c = io::Cursor::new(&mut addr[..]); write!(&mut c, "{}", peer_addr.ip()).unwrap(); let size = c.position() as usize; data.peer_address = str::from_utf8(&addr[..size]).unwrap(); data.peer_port = peer_addr.port(); } let mut zreq = zhttppacket::Request::new_data(instance.as_bytes(), ids, data); zreq.multi = true; let size = zreq.serialize(packet_buf)?; Ok(zmq::Message::from(&packet_buf[..size])) } // return the capacity increase fn resize_write_buffer_if_full( buf: &mut VecRingBuffer, block_size: usize, blocks_max: usize, blocks_avail: &mut CounterDec, ) -> usize { assert!(blocks_max >= 2); // all but one block can be used for writing let allowed = blocks_max - 1; if buf.remaining_capacity() == 0 && buf.capacity() < block_size * allowed && blocks_avail.dec(1).is_ok() { buf.resize(buf.capacity() + block_size); block_size } else { 0 } } #[derive(Debug)] enum Error { Io(io::Error), #[allow(dead_code)] Utf8(str::Utf8Error), #[allow(dead_code)] CoreHttp(CoreHttpError), 
#[allow(dead_code)] WebSocket(websocket::Error), ReqModeWebSocket, InvalidWebSocketRequest, InvalidWebSocketResponse, #[allow(dead_code)] WebSocketRejectionTooLarge(usize), Compression, BadMessage, Handler, HandlerCancel, BufferExceeded, BadFrame, BadRequest, Tls, PolicyViolation, TooManyRedirects, ValueActive, StreamTimeout, SessionTimeout, Stopped, } impl Error { // returns true if the error represents a logic error (a bug in the code) // that could warrant a panic or high severity log level fn is_logical(&self) -> bool { matches!(self, Error::ValueActive) } fn log_level(&self) -> Level { if self.is_logical() { Level::Error } else { Level::Debug } } fn to_condition(&self) -> &'static str { match self { Error::Io(e) if e.kind() == io::ErrorKind::ConnectionRefused => { "remote-connection-failed" } Error::Io(e) if e.kind() == io::ErrorKind::TimedOut => "connection-timeout", Error::BadRequest => "bad-request", Error::StreamTimeout => "connection-timeout", Error::Tls => "tls-error", Error::PolicyViolation => "policy-violation", Error::TooManyRedirects => "too-many-redirects", Error::WebSocketRejectionTooLarge(_) => "rejection-too-large", _ => "undefined-condition", } } } impl From for Error { fn from(e: io::Error) -> Self { Self::Io(e) } } impl From for Error { fn from(e: str::Utf8Error) -> Self { Self::Utf8(e) } } impl From> for Error { fn from(_e: mpsc::SendError) -> Self { Self::Io(io::Error::from(io::ErrorKind::BrokenPipe)) } } impl From> for Error { fn from(e: mpsc::TrySendError) -> Self { let kind = match e { mpsc::TrySendError::Full(_) => io::ErrorKind::WriteZero, mpsc::TrySendError::Disconnected(_) => io::ErrorKind::BrokenPipe, }; Self::Io(io::Error::from(kind)) } } impl From for Error { fn from(e: CoreHttpError) -> Self { Self::CoreHttp(e) } } impl From for Error { fn from(e: websocket::Error) -> Self { Self::WebSocket(e) } } impl From for Error { fn from(_e: ValueActiveError) -> Self { Self::ValueActive } } impl From for Error { fn from(e: track::RecvError) -> Self { match e { track::RecvError::Disconnected => { Self::Io(io::Error::from(io::ErrorKind::UnexpectedEof)) } track::RecvError::ValueActive => Self::ValueActive, } } } #[derive(Clone, Copy)] struct MessageItem { mtype: u8, avail: usize, } struct MessageTracker { items: VecDeque, last_partial: bool, } impl MessageTracker { fn new(max_messages: usize) -> Self { Self { items: VecDeque::with_capacity(max_messages), last_partial: false, } } fn in_progress(&self) -> bool { self.last_partial } fn start(&mut self, mtype: u8) -> Result<(), ()> { if self.last_partial || self.items.len() == self.items.capacity() { return Err(()); } self.items.push_back(MessageItem { mtype, avail: 0 }); self.last_partial = true; Ok(()) } fn extend(&mut self, amt: usize) { assert!(self.last_partial); self.items.back_mut().unwrap().avail += amt; } fn done(&mut self) { self.last_partial = false; } // type, avail, done fn current(&self) -> Option<(u8, usize, bool)> { #[allow(clippy::comparison_chain)] if self.items.len() > 1 { let m = self.items.front().unwrap(); Some((m.mtype, m.avail, true)) } else if self.items.len() == 1 { let m = self.items.front().unwrap(); Some((m.mtype, m.avail, !self.last_partial)) } else { None } } fn consumed(&mut self, amt: usize, done: bool) { assert!(amt <= self.items[0].avail); self.items[0].avail -= amt; if done { assert_eq!(self.items[0].avail, 0); self.items.pop_front().unwrap(); } } } pub struct AddrRef<'a> { s: Ref<'a, Option>>, } impl AddrRef<'_> { pub fn get(&self) -> Option<&[u8]> { match &*self.s { Some(s) => 
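            // Note on MessageTracker (defined above): it queues at most
            // max_messages outgoing websocket messages. An illustrative sequence
            // (not taken from the original source) looks like:
            //   tracker.start(OPCODE_TEXT)?; // begin a message
            //   tracker.extend(n);           // n bytes buffered for sending
            //   tracker.done();              // handler indicated no more data
            //   tracker.current();           // -> Some((OPCODE_TEXT, n, true))
            //   tracker.consumed(n, true);   // frame fully written to the socket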
Some(s.as_slice()), None => None, } } } struct StreamSharedDataInner { to_addr: Option>, out_seq: u32, router_resp: bool, } pub struct StreamSharedData { inner: RefCell, } #[allow(clippy::new_without_default)] impl StreamSharedData { pub fn new() -> Self { Self { inner: RefCell::new(StreamSharedDataInner { to_addr: None, out_seq: 0, router_resp: false, }), } } fn reset(&self) { let s = &mut *self.inner.borrow_mut(); s.to_addr = None; s.out_seq = 0; s.router_resp = false; } fn set_to_addr(&self, addr: Option>) { let s = &mut *self.inner.borrow_mut(); s.to_addr = addr; } pub fn to_addr(&self) -> AddrRef<'_> { AddrRef { s: Ref::map(self.inner.borrow(), |s| &s.to_addr), } } pub fn out_seq(&self) -> u32 { self.inner.borrow().out_seq } pub fn inc_out_seq(&self) { let s = &mut *self.inner.borrow_mut(); s.out_seq += 1; } pub fn router_resp(&self) -> bool { self.inner.borrow().router_resp } pub fn set_router_resp(&self, b: bool) { let s = &mut *self.inner.borrow_mut(); s.router_resp = b; } } fn make_zhttp_req_response( id: Option<&[u8]>, ptype: zhttppacket::ResponsePacket, scratch: &mut [u8], ) -> Result { let mut ids_mem = [zhttppacket::Id { id: b"", seq: None }]; let ids = if let Some(id) = id { ids_mem[0].id = id; ids_mem.as_slice() } else { &[] }; let zresp = zhttppacket::Response { from: b"", ids, multi: false, ptype, ptype_str: "", }; let size = zresp.serialize(scratch)?; let payload = &scratch[..size]; Ok(zmq::Message::from(payload)) } pub fn make_zhttp_response( addr: &[u8], use_router: bool, zresp: zhttppacket::Response, scratch: &mut [u8], ) -> Result<(Option>, zmq::Message), io::Error> { let size = zresp.serialize(scratch)?; let payload = &scratch[..size]; let (addr, v) = if use_router { // for router, use message as-is and return addr separately let v = Vec::from(payload); let addr = ArrayVec::try_from(addr).expect("addr has unexpected size"); (Some(addr), v) } else { // for pub, embed addr in message let mut v = vec![0; addr.len() + 1 + payload.len()]; v[..addr.len()].copy_from_slice(addr); v[addr.len()] = b' '; let pos = addr.len() + 1; v[pos..(pos + payload.len())].copy_from_slice(payload); (None, v) }; // this takes over the vec's memory without copying let msg = zmq::Message::from(v); Ok((addr, msg)) } async fn recv_nonzero(r: &mut R, buf: &mut VecRingBuffer) -> Result<(), io::Error> { if buf.remaining_capacity() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let size = match r.read(buf.write_buf()).await { Ok(size) => size, Err(e) => return Err(e), }; buf.write_commit(size); if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } Ok(()) } struct WebSocketRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut VecRingBuffer, } struct WebSocketWrite<'a, W: AsyncWrite> { stream: WriteHalf<'a, W>, buf: &'a mut VecRingBuffer, block_size: usize, } struct SendMessageContentFuture<'a, 'b, W: AsyncWrite, M> { w: &'a RefCell>, protocol: &'a websocket::Protocol, avail: usize, done: bool, } impl + AsMut<[u8]>> Future for SendMessageContentFuture<'_, '_, W, M> { type Output = Result<(usize, bool), Error>; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; let w = &mut *f.w.borrow_mut(); let stream = &mut w.stream; if !stream.is_writable() { return Poll::Pending; } // protocol.send_message_content may add 1 element to vector let mut buf_arr = mem::MaybeUninit::<[&mut [u8]; VECTORED_MAX - 1]>::uninit(); let mut bufs = w.buf.read_bufs_mut(&mut buf_arr).limit(f.avail); match f.protocol.send_message_content( &mut 
StdWriteWrapper::new(Pin::new(&mut w.stream), cx), bufs.as_slice(), f.done, ) { Ok(ret) => Poll::Ready(Ok(ret)), Err(websocket::Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e.into())), } } } impl Drop for SendMessageContentFuture<'_, '_, W, M> { fn drop(&mut self) { self.w.borrow_mut().stream.cancel(); } } struct WebSocketHandler<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell>, w: RefCell>, protocol: websocket::Protocol>, } impl<'a, R: AsyncRead, W: AsyncWrite> WebSocketHandler<'a, R, W> { fn new( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut VecRingBuffer, buf2: &'a mut VecRingBuffer, deflate_config: Option<(bool, VecRingBuffer)>, ) -> Self { buf2.clear(); let block_size = buf2.capacity(); Self { r: RefCell::new(WebSocketRead { stream: stream.0, buf: buf1, }), w: RefCell::new(WebSocketWrite { stream: stream.1, buf: buf2, block_size, }), protocol: websocket::Protocol::new(deflate_config), } } fn state(&self) -> websocket::State { self.protocol.state() } #[allow(clippy::await_holding_refcell_ref)] async fn add_to_recv_buffer(&self) -> Result<(), Error> { let r = &mut *self.r.borrow_mut(); if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } Ok(()) } fn try_recv_message_content( &self, dest: &mut [u8], ) -> Option> { let r = &mut *self.r.borrow_mut(); loop { match self.protocol.recv_message_content(r.buf, dest) { Some(Ok(ret)) => return Some(Ok(ret)), Some(Err(e)) => return Some(Err(e.into())), None => { if !r.buf.is_readable_contiguous() { r.buf.align(); continue; } return None; } } } } fn accept_avail(&self) -> usize { self.w.borrow().buf.remaining_capacity() } fn accept_body(&self, body: &[u8]) -> Result<(), Error> { let w = &mut *self.w.borrow_mut(); w.buf.write_all(body)?; Ok(()) } fn expand_write_buffer(&self, blocks_max: usize, blocks_avail: &mut CounterDec) -> usize { let w = &mut *self.w.borrow_mut(); resize_write_buffer_if_full(w.buf, w.block_size, blocks_max, blocks_avail) } fn is_sending_message(&self) -> bool { self.protocol.is_sending_message() } fn send_message_start(&self, opcode: u8, mask: Option<[u8; 4]>) { self.protocol.send_message_start(opcode, mask); } async fn send_message_content( &self, avail: usize, done: bool, bytes_sent: &F, ) -> Result<(usize, bool), Error> where F: Fn(), { loop { let (size, done) = SendMessageContentFuture { w: &self.w, protocol: &self.protocol, avail, done, } .await?; let w = &mut *self.w.borrow_mut(); if size == 0 && !done { continue; } w.buf.read_commit(size); bytes_sent(); return Ok((size, done)); } } } struct ZhttpStreamSessionOut<'a> { instance_id: &'a str, id: &'a str, packet_buf: &'a RefCell>, sender_stream: &'a AsyncLocalSender<(ArrayVec, zmq::Message)>, shared: &'a StreamSharedData, } impl<'a> ZhttpStreamSessionOut<'a> { fn new( instance_id: &'a str, id: &'a str, packet_buf: &'a RefCell>, sender_stream: &'a AsyncLocalSender<(ArrayVec, zmq::Message)>, shared: &'a StreamSharedData, ) -> Self { Self { instance_id, id, packet_buf, sender_stream, shared, } } async fn check_send(&self) { self.sender_stream.check_send().await } fn cancel_send(&self) { self.sender_stream.cancel(); } // this method is non-blocking, in order to increment the sequence number // and send the message in one shot, without concurrent activity // interfering with the sequencing. 
to send asynchronously, first await // on check_send and then call this method fn try_send_msg(&self, zreq: zhttppacket::Request) -> Result<(), Error> { let msg = { let mut zreq = zreq; let ids = [zhttppacket::Id { id: self.id.as_bytes(), seq: Some(self.shared.out_seq()), }]; zreq.from = self.instance_id.as_bytes(); zreq.ids = &ids; zreq.multi = true; let packet_buf = &mut *self.packet_buf.borrow_mut(); let size = zreq.serialize(packet_buf)?; zmq::Message::from(&packet_buf[..size]) }; let mut addr = ArrayVec::new(); if addr .try_extend_from_slice(self.shared.to_addr().get().unwrap()) .is_err() { return Err(io::Error::from(io::ErrorKind::InvalidInput).into()); } self.sender_stream.try_send((addr, msg))?; self.shared.inc_out_seq(); Ok(()) } } struct ZhttpServerStreamSessionOut<'a> { instance_id: &'a str, id: &'a [u8], packet_buf: &'a RefCell>, sender: &'a AsyncLocalSender<(Option>, zmq::Message)>, shared: &'a StreamSharedData, } impl<'a> ZhttpServerStreamSessionOut<'a> { fn new( instance_id: &'a str, id: &'a [u8], packet_buf: &'a RefCell>, sender: &'a AsyncLocalSender<(Option>, zmq::Message)>, shared: &'a StreamSharedData, ) -> Self { Self { instance_id, id, packet_buf, sender, shared, } } async fn check_send(&self) { self.sender.check_send().await } fn cancel_send(&self) { self.sender.cancel(); } // this method is non-blocking, in order to increment the sequence number // and send the message in one shot, without concurrent activity // interfering with the sequencing. to send asynchronously, first await // on check_send and then call this method fn try_send_msg(&self, zresp: zhttppacket::Response) -> Result<(), Error> { let (addr, msg) = { let mut zresp = zresp; let ids = [zhttppacket::Id { id: self.id, seq: Some(self.shared.out_seq()), }]; zresp.from = self.instance_id.as_bytes(); zresp.ids = &ids; zresp.multi = true; let addr = self.shared.to_addr(); let addr = addr.get().unwrap(); let packet_buf = &mut *self.packet_buf.borrow_mut(); make_zhttp_response(addr, self.shared.router_resp(), zresp, packet_buf)? 
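            // make_zhttp_response (above) covers both transports: when
            // router_resp is set, the handler address is returned separately for
            // a ZMQ ROUTER send; otherwise the address is prepended to the
            // payload with a space separator ("addr payload") for PUB delivery.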
}; self.sender.try_send((addr, msg))?; self.shared.inc_out_seq(); Ok(()) } } struct ZhttpStreamSessionIn<'a, 'b, R> { id: &'a str, send_buf_size: usize, websocket: bool, receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>, shared: &'a StreamSharedData, msg_read: &'a R, next: Option<(Track<'b, arena::Rc>, usize)>, seq: u32, credits: u32, first_data: bool, } impl<'a, 'b: 'a, R> ZhttpStreamSessionIn<'a, 'b, R> where R: Fn(), { fn new( id: &'a str, send_buf_size: usize, websocket: bool, receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>, shared: &'a StreamSharedData, msg_read: &'a R, ) -> Self { Self { id, send_buf_size, websocket, receiver, shared, msg_read, next: None, seq: 0, credits: 0, first_data: true, } } fn credits(&self) -> u32 { self.credits } fn subtract_credits(&mut self, amount: u32) { self.credits -= amount; } async fn peek_msg(&mut self) -> Result<&arena::Rc, Error> { if self.next.is_none() { let (r, id_index) = loop { let (r, id_index) = Track::map_first(self.receiver.recv().await?); let zresp = r.get().get(); if zresp.ids[id_index].id != self.id.as_bytes() { // skip messages addressed to old ids continue; } break (r, id_index); }; let zresp = r.get().get(); if !zresp.ptype_str.is_empty() { debug!( "server-conn {}: handle packet: {}", self.id, zresp.ptype_str ); } else { debug!("server-conn {}: handle packet: (data)", self.id); } if zresp.ids.is_empty() { return Err(Error::BadMessage); } if let Some(seq) = zresp.ids[id_index].seq { if seq != self.seq { warn!( "server-conn {}: bad seq (expected {}, got {}), skipping", self.id, self.seq, seq ); return Err(Error::BadMessage); } self.seq += 1; } let mut addr = ArrayVec::new(); if addr.try_extend_from_slice(zresp.from).is_err() { return Err(Error::BadMessage); } self.shared.set_to_addr(Some(addr)); (self.msg_read)(); match &zresp.ptype { zhttppacket::ResponsePacket::Data(rdata) => { let mut credits = rdata.credits; if self.first_data { self.first_data = false; if self.websocket && credits == 0 { // workaround for pushpin-proxy, which doesn't // send credits on websocket accept credits = self.send_buf_size as u32; debug!( "server-conn {}: no credits in websocket accept, assuming {}", self.id, credits ); } } self.credits += credits; } zhttppacket::ResponsePacket::Error(edata) => { debug!( "server-conn {}: zhttp error condition={}", self.id, edata.condition ); } zhttppacket::ResponsePacket::Credit(cdata) => { self.credits += cdata.credits; } zhttppacket::ResponsePacket::Ping(pdata) => { self.credits += pdata.credits; } zhttppacket::ResponsePacket::Pong(pdata) => { self.credits += pdata.credits; } _ => {} } self.next = Some((r, id_index)); } Ok(&self.next.as_ref().unwrap().0) } async fn recv_msg( &mut self, ) -> Result>, Error> { self.peek_msg().await?; Ok(self.next.take().unwrap().0) } } struct ZhttpServerStreamSessionIn<'a, 'b, R> { log_id: &'a str, id: &'a [u8], receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>, shared: &'a StreamSharedData, msg_read: &'a R, next: Option<(Track<'b, arena::Rc>, usize)>, seq: u32, credits: u32, } impl<'a, 'b: 'a, R> ZhttpServerStreamSessionIn<'a, 'b, R> where R: Fn(), { fn new( log_id: &'a str, id: &'a [u8], credits: u32, receiver: &'a TrackedAsyncLocalReceiver<'b, (arena::Rc, usize)>, shared: &'a StreamSharedData, msg_read: &'a R, ) -> Self { Self { log_id, id, receiver, shared, msg_read, next: None, seq: 1, credits, } } fn credits(&self) -> u32 { self.credits } fn subtract_credits(&mut self, amount: u32) { self.credits -= amount; } async fn peek_msg(&mut self) 
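    // peek_msg caches the next handler message: packets addressed to old
    // connection ids are skipped, the seq number is checked against the expected
    // value (a mismatch is treated as BadMessage), the sender address is recorded
    // for replies, and credits carried by Data/Credit/Ping/Pong packets are
    // accumulated before the message is handed to the caller.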
-> Result<&arena::Rc, Error> { if self.next.is_none() { let (r, id_index) = loop { let (r, id_index) = Track::map_first(self.receiver.recv().await?); let zreq = r.get().get(); if zreq.ids[id_index].id != self.id { // skip messages addressed to old ids continue; } break (r, id_index); }; let zreq = r.get().get(); if !zreq.ptype_str.is_empty() { debug!( "client-conn {}: handle packet: {}", self.log_id, zreq.ptype_str ); } else { debug!("client-conn {}: handle packet: (data)", self.log_id); } if zreq.ids.is_empty() { return Err(Error::BadMessage); } if let Some(seq) = zreq.ids[id_index].seq { if seq != self.seq { warn!( "client-conn {}: bad seq (expected {}, got {}), skipping", self.log_id, self.seq, seq ); return Err(Error::BadMessage); } self.seq += 1; } let mut addr = ArrayVec::new(); if addr.try_extend_from_slice(zreq.from).is_err() { return Err(Error::BadMessage); } self.shared.set_to_addr(Some(addr)); (self.msg_read)(); match &zreq.ptype { zhttppacket::RequestPacket::Data(rdata) => { self.credits += rdata.credits; } zhttppacket::RequestPacket::Error(edata) => { debug!( "client-conn {}: zhttp error condition={}", self.log_id, edata.condition ); } zhttppacket::RequestPacket::Credit(cdata) => { self.credits += cdata.credits; } zhttppacket::RequestPacket::Ping(pdata) => { self.credits += pdata.credits; } zhttppacket::RequestPacket::Pong(pdata) => { self.credits += pdata.credits; } _ => {} } self.next = Some((r, id_index)); } Ok(&self.next.as_ref().unwrap().0) } async fn recv_msg(&mut self) -> Result>, Error> { self.peek_msg().await?; Ok(self.next.take().unwrap().0) } } async fn send_msg(sender: &AsyncLocalSender, msg: zmq::Message) -> Result<(), Error> { Ok(sender.send(msg).await?) } async fn discard_while( receiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, fut: F, ) -> Result where F: Future> + Unpin, Error: From, { match select_2(fut, pin!(receiver.recv())).await { Select2::R1(v) => Ok(v?), Select2::R2(ret) => { ret?; // unexpected message in current state Err(Error::BadMessage) } } } async fn server_discard_while( receiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, fut: F, ) -> F::Output where F: Future> + Unpin, { match select_2(fut, pin!(receiver.recv())).await { Select2::R1(v) => v, Select2::R2(_) => Err(Error::BadMessage), // unexpected message in current state } } async fn send_error_response( mut resp: server::Response<'_, R, W>, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, e: &Error, ) -> Result<(), Error> { let headers = &[http1::Header { name: "Content-Type", value: b"text/plain", }]; let mut body: ArrayVec = ArrayVec::new(); let code = match e { Error::CoreHttp(CoreHttpError::Protocol(e)) => { writeln!(&mut body, "Failed to parse request: {}", e)?; 400 } Error::CoreHttp(CoreHttpError::RequestTooLarge(limit)) => { writeln!( &mut body, "Request header size exceeded limit of {} bytes.", limit )?; 400 } Error::CoreHttp(CoreHttpError::ResponseTooLarge(limit)) => { writeln!( &mut body, "Response header size exceeded limit of {} bytes.", limit )?; 500 } Error::ReqModeWebSocket => { writeln!(&mut body, "WebSockets not supported on req mode interface.")?; 400 } Error::InvalidWebSocketRequest => { writeln!(&mut body, "Request contained an Upgrade header with value \"websocket\" but the request was not a valid WebSocket request.")?; 400 } Error::WebSocketRejectionTooLarge(limit) => { writeln!( &mut body, "Non-101 response body size exceeded limit of {} bytes.", limit )?; 500 } _ => { writeln!(&mut body, "Failed to process request.")?; 500 } }; let 
reason = match code { 400 => "Bad Request", _ => "Internal Server Error", }; let mut state = server::ResponseState::default(); let (header, prepare_body) = resp.prepare_header( code, reason, headers, http1::BodySize::Known(body.len()), &mut state, )?; // ABR: discard_while let header_sent = discard_while(zreceiver, pin!(header.send())).await?; let resp_body = header_sent.start_body(prepare_body); resp_body.prepare(&body, true)?; loop { // send the buffer let send = pin!(async { match resp_body.send().await { SendStatus::Complete(finished) => Ok(Some(finished)), SendStatus::EarlyResponse(_) => unreachable!(), // for requests only SendStatus::Partial((), _) => Ok(None), SendStatus::Error((), e) => Err(e), } }); // ABR: discard_while if let Some(_finished) = discard_while(zreceiver, send).await? { break; } } Ok(()) } // read request body and prepare outgoing zmq message #[allow(clippy::too_many_arguments)] async fn server_req_read_body( id: &str, req: &http1::Request<'_, '_>, req_body: &mut server::RequestBodyKeepHeader<'_, '_, R, W>, peer_addr: Option<&SocketAddr>, secure: bool, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, ) -> Result { // receive request body loop { match req_body.try_recv(body_buf.write_buf())? { RecvStatus::Complete((), size) => { body_buf.write_commit(size); break; } RecvStatus::Read((), size) => { body_buf.write_commit(size); if size == 0 { return Err(Error::BufferExceeded); } } RecvStatus::NeedBytes(()) => { // ABR: discard_while discard_while(zreceiver, pin!(req_body.add_to_buffer())).await?; } } } // determine how to respond let mut websocket = false; for h in req.headers.iter() { if h.name.eq_ignore_ascii_case("Upgrade") && h.value == b"websocket" { websocket = true; break; } } if websocket { // websocket requests are not supported in req mode // toss the request body body_buf.clear(); return Err(Error::ReqModeWebSocket); } // regular http requests we can handle // prepare zmq message let ids = [zhttppacket::Id { id: id.as_bytes(), seq: None, }]; let msg = make_zhttp_request( "", &ids, req.method, req.uri, req.headers, Buffer::read_buf(body_buf), false, Mode::HttpReq, 0, peer_addr, secure, &mut packet_buf.borrow_mut(), )?; // body consumed body_buf.clear(); Ok(msg) } // read full request and prepare outgoing zmq message. // return Ok(None) if client disconnects before providing a complete request header async fn server_req_read_header_and_body( id: &str, req_header: server::RequestHeader<'_, '_, R, W>, peer_addr: Option<&SocketAddr>, secure: bool, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, ) -> Result, Error> { let mut scratch = http1::ParseScratch::::new(); // receive request header // WARNING: the returned req_header must not be dropped and instead must // be consumed by discard_header(). 
be careful with early returns from // this function and do not use the ?-operator let (req_header, mut req_body) = { // ABR: discard_while match discard_while(zreceiver, pin!(req_header.recv(&mut scratch))).await { Ok(ret) => ret, Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(e), } }; let req_ref = req_header.get(); // log request { let host = get_host(req_ref.headers); let scheme = if secure { "https" } else { "http" }; debug!( "server-conn {}: request: {} {}://{}{}", id, req_ref.method, scheme, host, req_ref.uri ); } let result = server_req_read_body( id, &req_ref, &mut req_body, peer_addr, secure, body_buf, packet_buf, zreceiver, ) .await; // whether success or fail, toss req_header so we are able to respond req_body.discard_header(req_header); // NOTE: req_header is now consumed and we don't need to worry about it from here Ok(Some(result?)) } struct ReqRespond<'buf, 'st, R: AsyncRead, W: AsyncWrite> { header: server::ResponseHeader<'buf, 'st, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, } // consumes resp if successful #[allow(clippy::too_many_arguments)] async fn server_req_respond<'buf, 'st, R: AsyncRead, W: AsyncWrite>( id: &str, req: server::Request, resp: &mut Option>, resp_state: &'st mut server::ResponseState<'buf, R, W>, peer_addr: Option<&SocketAddr>, secure: bool, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, zsender: &AsyncLocalSender, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, ) -> Result>, Error> { let msg = { let req_header = req.recv_header(resp.as_mut().unwrap()); match server_req_read_header_and_body( id, req_header, peer_addr, secure, body_buf, packet_buf, zreceiver, ) .await? { Some(msg) => msg, None => return Ok(None), } }; // send message // ABR: discard_while discard_while(zreceiver, pin!(send_msg(zsender, msg))).await?; // receive message let zresp = loop { // ABR: direct read let (zresp, id_index) = Track::map_first(zreceiver.recv().await?); let zresp_ref = zresp.get().get(); if zresp_ref.ids[id_index].id != id.as_bytes() { // skip messages addressed to old ids continue; } if !zresp_ref.ptype_str.is_empty() { debug!("server-conn {}: handle packet: {}", id, zresp_ref.ptype_str); } else { debug!("server-conn {}: handle packet: (data)", id); } // skip non-data messages match &zresp_ref.ptype { zhttppacket::ResponsePacket::Data(_) => break zresp, _ => debug!( "server-conn {}: unexpected packet in req mode: {}", id, zresp_ref.ptype_str ), } }; let (header, prepare_body) = { let zresp = zresp.get().get(); let rdata = match &zresp.ptype { zhttppacket::ResponsePacket::Data(rdata) => rdata, _ => unreachable!(), // we confirmed the type above }; if body_buf.write_all(rdata.body).is_err() { return Err(Error::BufferExceeded); } // send response header let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX]; let mut headers_len = 0; for h in rdata.headers.iter() { if headers_len >= headers.len() { return Err(Error::BadMessage); } headers[headers_len] = http1::Header { name: h.name, value: h.value, }; headers_len += 1; } let headers = &headers[..headers_len]; let mut resp_take = resp.take().unwrap(); let (header, prepare_body) = match resp_take.prepare_header( rdata.code, rdata.reason, headers, http1::BodySize::Known(rdata.body.len()), resp_state, ) { Ok(ret) => ret, Err(e) => { *resp = Some(resp_take); return Err(e.into()); } }; (header, prepare_body) }; Ok(Some(ReqRespond { header, prepare_body, })) } // return true if persistent #[allow(clippy::too_many_arguments)] 
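// Req-mode flow: the entire request (including body) is buffered and forwarded to
// the handler as a single ZHTTP request message, then a single Data response is
// awaited and written back as the HTTP response. The returned bool reports
// whether the connection can be kept alive for another request.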
async fn server_req_handler( id: &str, stream: &mut S, peer_addr: Option<&SocketAddr>, secure: bool, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, zsender: &AsyncLocalSender, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, ) -> Result { let stream = RefCell::new(stream); let mut resp_state = server::ResponseState::default(); let r = { let (req, resp) = server::Request::new(io_split(&stream), buf1, buf2); let mut resp = Some(resp); let ret = match server_req_respond( id, req, &mut resp, &mut resp_state, peer_addr, secure, body_buf, packet_buf, zsender, zreceiver, ) .await { Ok(Some(ret)) => ret, Ok(None) => return Ok(false), // no request Err(e) => { // on error, resp is not consumed, so we can use it send_error_response(resp.take().unwrap(), zreceiver, &e).await?; return Err(e); } }; assert!(resp.is_none()); ret }; // ABR: discard_while let header_sent = discard_while(zreceiver, pin!(r.header.send())).await?; let resp_body = header_sent.start_body(r.prepare_body); // send response body let finished = loop { // fill the buffer as much as possible let size = resp_body.prepare(Buffer::read_buf(body_buf), true)?; body_buf.read_commit(size); // send the buffer let send = pin!(async { match resp_body.send().await { SendStatus::Complete(finished) => Ok(Some(finished)), SendStatus::EarlyResponse(_) => unreachable!(), // for requests only SendStatus::Partial((), _) => Ok(None), SendStatus::Error((), e) => Err(e), } }); // ABR: discard_while if let Some(finished) = discard_while(zreceiver, send).await? { break finished; } }; assert_eq!(body_buf.len(), 0); Ok(finished.is_persistent()) } #[allow(clippy::too_many_arguments)] async fn server_req_connection_inner( token: CancellationToken, cid: &mut ArrayString<32>, cid_provider: &mut P, mut stream: S, peer_addr: Option<&SocketAddr>, secure: bool, buffer_size: usize, body_buffer_size: usize, rb_tmp: &Rc, packet_buf: Rc>>, timeout: Duration, zsender: AsyncLocalSender, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, ) -> Result<(), Error> { let reactor = Reactor::current().unwrap(); let mut buf1 = VecRingBuffer::new(buffer_size, rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, rb_tmp); let mut body_buf = ContiguousBuffer::new(body_buffer_size); loop { stream.set_id(cid); // this was originally logged when starting the non-async state // machine, so we'll keep doing that debug!("server-conn {}: assigning id", cid); let reuse = { let handler = server_req_handler( cid.as_ref(), &mut stream, peer_addr, secure, &mut buf1, &mut buf2, &mut body_buf, &packet_buf, &zsender, zreceiver, ); let timeout = Timeout::new(reactor.now() + timeout); match select_3(pin!(handler), timeout.elapsed(), token.cancelled()).await { Select3::R1(ret) => ret?, Select3::R2(_) => return Err(Error::StreamTimeout), Select3::R3(_) => return Err(Error::Stopped), } }; if !reuse { break; } // note: buf1 is not cleared as there may be data to read buf2.clear(); body_buf.clear(); *cid = cid_provider.get_new_assigned_cid(); } // ABR: discard_while discard_while(zreceiver, pin!(stream.close())).await?; Ok(()) } #[allow(clippy::too_many_arguments)] pub async fn server_req_connection( token: CancellationToken, mut cid: ArrayString<32>, cid_provider: &mut P, stream: S, peer_addr: Option<&SocketAddr>, secure: bool, buffer_size: usize, body_buffer_size: usize, rb_tmp: &Rc, packet_buf: Rc>>, timeout: Duration, zsender: AsyncLocalSender, zreceiver: AsyncLocalReceiver<(arena::Rc, usize)>, ) { let 
value_active = TrackFlag::default(); let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active); match track_future( server_req_connection_inner( token, &mut cid, cid_provider, stream, peer_addr, secure, buffer_size, body_buffer_size, rb_tmp, packet_buf, timeout, zsender, &zreceiver, ), &value_active, ) .await { Ok(()) => debug!("server-conn {}: finished", cid), Err(e) => log!(e.log_level(), "server-conn {}: process error: {:?}", cid, e), } } async fn accept_handoff( zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R>, zsess_out: &ZhttpStreamSessionOut<'_>, ) -> Result<(), Error> where R: Fn(), { // discarding here is fine. the sender should cease sending // messages until we've replied with proceed discard_while( zsess_in.receiver, pin!(async { zsess_out.check_send().await; Ok::<(), Error>(()) }), ) .await?; let zreq = zhttppacket::Request::new_handoff_proceed(b"", &[]); // check_send just finished, so this should succeed zsess_out.try_send_msg(zreq)?; // unset to_addr so we don't send keep-alives zsess_in.shared.set_to_addr(None); // pause until we get a msg zsess_in.peek_msg().await?; Ok(()) } async fn server_accept_handoff( zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R>, zsess_out: &ZhttpServerStreamSessionOut<'_>, ) -> Result<(), Error> where R: Fn(), { // discarding here is fine. the sender should cease sending // messages until we've replied with proceed server_discard_while( zsess_in.receiver, pin!(async { zsess_out.check_send().await; Ok(()) }), ) .await?; let zresp = zhttppacket::Response::new_handoff_proceed(b"", &[]); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; // unset to_addr so we don't send keep-alives zsess_in.shared.set_to_addr(None); // pause until we get a msg zsess_in.peek_msg().await?; Ok(()) } // this function will either return immediately or await messages async fn handle_other( zresp: Track<'_, arena::Rc>, zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R>, zsess_out: &ZhttpStreamSessionOut<'_>, ) -> Result<(), Error> where R: Fn(), { match &zresp.get().get().ptype { zhttppacket::ResponsePacket::KeepAlive => Ok(()), zhttppacket::ResponsePacket::Credit(_) => Ok(()), zhttppacket::ResponsePacket::HandoffStart => { drop(zresp); accept_handoff(zsess_in, zsess_out).await?; Ok(()) } zhttppacket::ResponsePacket::Error(_) => Err(Error::Handler), zhttppacket::ResponsePacket::Cancel => Err(Error::HandlerCancel), _ => Err(Error::BadMessage), // unexpected type } } // this function will either return immediately or await messages async fn server_handle_other( zreq: Track<'_, arena::Rc>, zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R>, zsess_out: &ZhttpServerStreamSessionOut<'_>, ) -> Result<(), Error> where R: Fn(), { match &zreq.get().get().ptype { zhttppacket::RequestPacket::KeepAlive => Ok(()), zhttppacket::RequestPacket::Credit(_) => Ok(()), zhttppacket::RequestPacket::HandoffStart => { drop(zreq); server_accept_handoff(zsess_in, zsess_out).await?; Ok(()) } zhttppacket::RequestPacket::Error(_) => Err(Error::Handler), zhttppacket::RequestPacket::Cancel => Err(Error::HandlerCancel), _ => Err(Error::BadMessage), // unexpected type } } async fn stream_recv_body( tmp_buf: &RefCell>, bytes_read: &R1, req_body: server::RequestBody<'_, '_, R, W>, zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpStreamSessionOut<'_>, ) -> Result<(), Error> where R1: Fn(), R2: Fn(), R: AsyncRead, W: AsyncWrite, { let mut check_send = pin!(None); let mut add_to_buffer = pin!(None); loop { if zsess_in.credits() > 0 && 
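        // only offer to forward body data while the handler has outstanding
        // credits and no buffered read is pending; each Data packet sent below
        // subtracts its size from the credit balance, so the handler's credit
        // grants bound how much request body is read from the client at a time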
add_to_buffer.is_none() && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } // ABR: select contains read let ret = select_3( select_option(check_send.as_mut().as_pin_mut()), select_option(add_to_buffer.as_mut().as_pin_mut()), pin!(zsess_in.peek_msg()), ) .await; match ret { Select3::R1(()) => { check_send.set(None); let _defer = Defer::new(|| zsess_out.cancel_send()); assert!(zsess_in.credits() > 0); assert!(add_to_buffer.is_none()); let tmp_buf = &mut *tmp_buf.borrow_mut(); let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize); let (size, done) = match req_body.try_recv(&mut tmp_buf[..max_read])? { RecvStatus::Complete((), size) => (size, true), RecvStatus::Read((), size) => (size, false), RecvStatus::NeedBytes(()) => { add_to_buffer.set(Some(req_body.add_to_buffer())); continue; } }; bytes_read(); let body = &tmp_buf[..size]; zsess_in.subtract_credits(size as u32); let mut rdata = zhttppacket::RequestData::new(); rdata.body = body; rdata.more = !done; let zresp = zhttppacket::Request::new_data(b"", &[], rdata); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; if done { break; } } Select3::R2(ret) => { ret?; add_to_buffer.set(None); } Select3::R3(ret) => { let r = ret?; let zresp_ref = r.get().get(); match &zresp_ref.ptype { zhttppacket::ResponsePacket::Data(_) => break, _ => { // ABR: direct read let zresp = zsess_in.recv_msg().await?; // ABR: handle_other handle_other(zresp, zsess_in, zsess_out).await?; } } } } } Ok(()) } async fn server_stream_recv_body( tmp_buf: &RefCell>, bytes_read: &R1, resp_body: client::ResponseBody<'_, R>, zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpServerStreamSessionOut<'_>, ) -> Result where R1: Fn(), R2: Fn(), R: AsyncRead, { let mut check_send = pin!(None); let mut add_to_buffer = pin!(None); loop { if zsess_in.credits() > 0 && add_to_buffer.is_none() && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } // ABR: select contains read let ret = select_3( select_option(check_send.as_mut().as_pin_mut()), select_option(add_to_buffer.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), ) .await; match ret { Select3::R1(()) => { check_send.set(None); let _defer = Defer::new(|| zsess_out.cancel_send()); assert!(zsess_in.credits() > 0); assert!(add_to_buffer.is_none()); let tmp_buf = &mut *tmp_buf.borrow_mut(); let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize); let (size, mut finished) = match resp_body.try_recv(&mut tmp_buf[..max_read])? 
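                    // RecvStatus from try_recv: Complete carries the finished
                    // response along with the bytes produced, Read means more body
                    // remains, and NeedBytes means the parser wants more input, in
                    // which case we await add_to_buffer and try again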
{ RecvStatus::Complete(finished, size) => (size, Some(finished)), RecvStatus::Read((), size) => (size, None), RecvStatus::NeedBytes(()) => { add_to_buffer.set(Some(resp_body.add_to_buffer())); continue; } }; bytes_read(); let body = &tmp_buf[..size]; zsess_in.subtract_credits(size as u32); let mut rdata = zhttppacket::ResponseData::new(); rdata.body = body; rdata.more = finished.is_none(); let zresp = zhttppacket::Response::new_data(b"", &[], rdata); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; if let Some(finished) = finished.take() { return Ok(finished); } } Select3::R2(ret) => { ret?; add_to_buffer.set(None); } Select3::R3(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } } async fn stream_send_body( bytes_read: &R1, resp_body: server::ResponseBody<'_, R, W>, zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpStreamSessionOut<'_>, blocks_max: usize, blocks_avail: &mut CounterDec<'_>, ) -> Result where R1: Fn(), R2: Fn(), R: AsyncRead, W: AsyncWrite, { let mut out_credits = 0; let mut send = pin!(None); let mut check_send = pin!(None); let mut prepare_done = false; let finished = 'main: loop { let ret = { if send.is_none() && resp_body.can_send() { send.set(Some(resp_body.send())); } if !prepare_done && out_credits > 0 && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } let fill_recv_buffer = if send.is_none() { Some(resp_body.fill_recv_buffer()) } else { None }; // ABR: select contains read select_4( select_option(send.as_mut().as_pin_mut()), select_option(check_send.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), select_option(pin!(fill_recv_buffer).as_pin_mut()), ) .await }; match ret { Select4::R1(ret) => { send.set(None); match ret { SendStatus::Complete(finished) => break finished, SendStatus::EarlyResponse(_) => unreachable!(), // for requests only SendStatus::Partial((), size) => { out_credits += size as u32; if size > 0 { bytes_read(); } } SendStatus::Error(_, e) => return Err(e.into()), } } Select4::R2(()) => { check_send.set(None); let zreq = zhttppacket::Request::new_credit(b"", &[], out_credits); out_credits = 0; // check_send just finished, so this should succeed zsess_out.try_send_msg(zreq)?; } Select4::R3(ret) => { let zresp = ret?; match &zresp.get().get().ptype { zhttppacket::ResponsePacket::Data(rdata) => { let size = resp_body.prepare(rdata.body, !rdata.more)?; if size < rdata.body.len() { return Err(Error::BufferExceeded); } if rdata.more { out_credits += resp_body .expand_write_buffer(blocks_max, || blocks_avail.dec(1).is_ok())? 
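                            // expand_write_buffer may grow the response buffer by
                            // one block when blocks_avail permits; the added
                            // capacity (cast to u32 below) is credited back to the
                            // handler on the next successful check_send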
as u32; } else { prepare_done = true; } } zhttppacket::ResponsePacket::HandoffStart => { drop(zresp); // if handoff requested, flush what we can before accepting // so that the data is not delayed while we wait if send.is_none() && resp_body.can_send() { send.set(Some(resp_body.send())); } while let Some(fut) = send.as_mut().as_pin_mut() { // ABR: poll_async doesn't block let ret = match poll_async(fut).await { Poll::Ready(ret) => ret, Poll::Pending => break, }; send.set(None); match ret { SendStatus::Complete(resp) => break 'main resp, SendStatus::EarlyResponse(_) => unreachable!(), // for requests only SendStatus::Partial((), size) => { out_credits += size as u32; if size > 0 { bytes_read(); } } SendStatus::Error((), e) => return Err(e.into()), } if resp_body.can_send() { send.set(Some(resp_body.send())); } } // ABR: function contains read accept_handoff(zsess_in, zsess_out).await?; } _ => { // ABR: handle_other handle_other(zresp, zsess_in, zsess_out).await?; } } } Select4::R4(e) => return Err(e.into()), } }; Ok(finished) } struct Overflow { buf: ContiguousBuffer, end: bool, } #[allow(clippy::too_many_arguments)] async fn server_stream_send_body<'a, R1, R2, R, W>( bytes_read: &R1, req_body: client::RequestBody<'a, R, W>, mut overflow: Option, recv_buf_size: usize, zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpServerStreamSessionOut<'_>, blocks_max: usize, blocks_avail: &mut CounterDec<'_>, ) -> Result, Error> where R1: Fn(), R2: Fn(), R: AsyncRead, W: AsyncWrite, { // send initial body, including overflow, before offering credits let mut send = pin!(None); while send.is_some() || req_body.can_send() { if send.is_none() { send.set(Some(req_body.send())); } // ABR: select contains read let result = select_2( select_option(send.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), ) .await; match result { Select2::R1(ret) => { send.set(None); match ret { SendStatus::Complete(resp) => return Ok(resp), SendStatus::EarlyResponse(resp) => return Ok(resp), SendStatus::Partial((), _) => { if !req_body.can_send() { if let Some(overflow) = &mut overflow { let size = req_body.prepare(overflow.buf.read_buf(), overflow.end)?; overflow.buf.read_commit(size); } } } SendStatus::Error((), e) => return Err(e.into()), } } Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } assert!(!req_body.can_send()); let mut out_credits = recv_buf_size as u32; let mut send = pin!(None); let mut check_send = pin!(None); let mut prepare_done = false; let resp = 'main: loop { let ret = { if send.is_none() && req_body.can_send() { send.set(Some(req_body.send())); } if !prepare_done && out_credits > 0 && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } let fill_recv_buffer = if send.is_none() { Some(req_body.fill_recv_buffer()) } else { None }; // ABR: select contains read select_4( select_option(send.as_mut().as_pin_mut()), select_option(check_send.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), select_option(pin!(fill_recv_buffer).as_pin_mut()), ) .await }; match ret { Select4::R1(ret) => { send.set(None); match ret { SendStatus::Complete(resp) => break resp, SendStatus::EarlyResponse(resp) => break resp, SendStatus::Partial((), size) => { out_credits += size as u32; if size > 0 { bytes_read(); } } SendStatus::Error(_, e) => return Err(e.into()), } } Select4::R2(()) => { check_send.set(None); let zresp = zhttppacket::Response::new_credit(b"", &[], out_credits); out_credits = 0; // check_send just finished, 
so this should succeed zsess_out.try_send_msg(zresp)?; } Select4::R3(ret) => { let zreq = ret?; match &zreq.get().get().ptype { zhttppacket::RequestPacket::Data(rdata) => { let size = req_body.prepare(rdata.body, !rdata.more)?; if size < rdata.body.len() { return Err(Error::BufferExceeded); } if rdata.more { out_credits += req_body .expand_write_buffer(blocks_max, || blocks_avail.dec(1).is_ok())? as u32; } else { prepare_done = true; } } zhttppacket::RequestPacket::HandoffStart => { drop(zreq); // if handoff requested, flush what we can before accepting // so that the data is not delayed while we wait if send.is_none() && req_body.can_send() { send.set(Some(req_body.send())); } while let Some(fut) = send.as_mut().as_pin_mut() { // ABR: poll_async doesn't block let ret = match poll_async(fut).await { Poll::Ready(ret) => ret, Poll::Pending => break, }; send.set(None); match ret { SendStatus::Complete(resp) => break 'main resp, SendStatus::EarlyResponse(resp) => break 'main resp, SendStatus::Partial((), size) => { out_credits += size as u32; if size > 0 { bytes_read(); } } SendStatus::Error((), e) => return Err(e.into()), } if req_body.can_send() { send.set(Some(req_body.send())); } } // ABR: function contains read server_accept_handoff(zsess_in, zsess_out).await?; } _ => { // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } Select4::R4(e) => return Err(e.into()), } }; Ok(resp) } #[allow(clippy::too_many_arguments)] async fn stream_websocket( log_id: &str, stream: RefCell<&mut S>, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, blocks_max: usize, blocks_avail: &mut CounterDec<'_>, messages_max: usize, tmp_buf: &RefCell>, bytes_read: &R1, deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>, zsess_in: &mut ZhttpStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpStreamSessionOut<'_>, ) -> Result<(), Error> where S: AsyncRead + AsyncWrite, R1: Fn(), R2: Fn(), { let deflate_config = match deflate_config { Some((config, enc_buf_size)) => { let ebuf = VecRingBuffer::new(enc_buf_size, buf2.get_tmp()); Some((!config.server_no_context_takeover, ebuf)) } None => None, }; let handler = WebSocketHandler::new(io_split(&stream), buf1, buf2, deflate_config); let mut ws_in_tracker = MessageTracker::new(messages_max); let mut out_credits = 0; let mut check_send = pin!(None); let mut add_to_recv_buffer = pin!(None); let mut send_content = pin!(None); loop { let (do_send, do_recv) = match handler.state() { websocket::State::Connected => (true, true), websocket::State::PeerClosed => (true, false), websocket::State::Closing => (false, true), websocket::State::Finished => break, }; if out_credits > 0 || (do_recv && zsess_in.credits() > 0 && add_to_recv_buffer.is_none()) && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } if do_send && send_content.is_none() { if let Some((mtype, avail, done)) = ws_in_tracker.current() { if !handler.is_sending_message() { handler.send_message_start(mtype, None); } if avail > 0 || done { send_content.set(Some(handler.send_message_content(avail, done, bytes_read))); } } } // ABR: select contains read let ret = select_4( select_option(check_send.as_mut().as_pin_mut()), select_option(add_to_recv_buffer.as_mut().as_pin_mut()), select_option(send_content.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), ) .await; match ret { Select4::R1(()) => { check_send.set(None); let _defer = Defer::new(|| zsess_out.cancel_send()); if out_credits > 0 { let zreq = zhttppacket::Request::new_credit(b"", &[], out_credits); out_credits = 
0; // check_send just finished, so this should succeed zsess_out.try_send_msg(zreq)?; continue; } assert!(zsess_in.credits() > 0); assert!(add_to_recv_buffer.is_none()); let tmp_buf = &mut *tmp_buf.borrow_mut(); let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize); let (opcode, size, end) = match handler.try_recv_message_content(&mut tmp_buf[..max_read]) { Some(ret) => ret?, None => { add_to_recv_buffer.set(Some(handler.add_to_recv_buffer())); continue; } }; bytes_read(); let body = &tmp_buf[..size]; let zreq = match opcode { websocket::OPCODE_TEXT | websocket::OPCODE_BINARY => { if body.is_empty() && !end { // don't bother sending empty message continue; } let mut data = zhttppacket::RequestData::new(); data.body = body; data.content_type = if opcode == websocket::OPCODE_TEXT { Some(zhttppacket::ContentType::Text) } else { Some(zhttppacket::ContentType::Binary) }; data.more = !end; zhttppacket::Request::new_data(b"", &[], data) } websocket::OPCODE_CLOSE => { let status = if body.len() >= 2 { let mut arr = [0; 2]; arr[..].copy_from_slice(&body[..2]); let code = u16::from_be_bytes(arr); let reason = match str::from_utf8(&body[2..]) { Ok(reason) => reason, Err(e) => return Err(e.into()), }; Some((code, reason)) } else { None }; zhttppacket::Request::new_close(b"", &[], status) } websocket::OPCODE_PING => zhttppacket::Request::new_ping(b"", &[], body), websocket::OPCODE_PONG => zhttppacket::Request::new_pong(b"", &[], body), opcode => { debug!( "server-conn {}: unsupported websocket opcode: {}", log_id, opcode ); return Err(Error::BadFrame); } }; zsess_in.subtract_credits(size as u32); // check_send just finished, so this should succeed zsess_out.try_send_msg(zreq)?; } Select4::R2(ret) => { ret?; add_to_recv_buffer.set(None); } Select4::R3(ret) => { send_content.set(None); let (size, done) = ret?; ws_in_tracker.consumed(size, done); if handler.state() == websocket::State::Connected || handler.state() == websocket::State::PeerClosed { out_credits += size as u32; } } Select4::R4(ret) => { let zresp = ret?; match &zresp.get().get().ptype { zhttppacket::ResponsePacket::Data(rdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(rdata.body) { warn!( "received too much data from handler (size={}, credits={})", rdata.body.len(), avail, ); return Err(e); } out_credits += handler.expand_write_buffer(blocks_max, blocks_avail) as u32; let opcode = match &rdata.content_type { Some(zhttppacket::ContentType::Binary) => websocket::OPCODE_BINARY, _ => websocket::OPCODE_TEXT, }; if !ws_in_tracker.in_progress() { if ws_in_tracker.start(opcode).is_err() { return Err(Error::BufferExceeded); } } ws_in_tracker.extend(rdata.body.len()); if !rdata.more { ws_in_tracker.done(); } } _ => {} }, zhttppacket::ResponsePacket::Close(cdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let (code, reason) = cdata.status.unwrap_or((1000, "")); let arr: [u8; 2] = code.to_be_bytes(); // close content isn't limited by credits. 
if we // don't have space for it, just error out handler.accept_body(&arr)?; handler.accept_body(reason.as_bytes())?; if ws_in_tracker.start(websocket::OPCODE_CLOSE).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(arr.len() + reason.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::ResponsePacket::Ping(pdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(pdata.body) { warn!( "received too much data from handler (size={}, credits={})", pdata.body.len(), avail, ); return Err(e); } if ws_in_tracker.start(websocket::OPCODE_PING).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(pdata.body.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::ResponsePacket::Pong(pdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(pdata.body) { warn!( "received too much data from handler (size={}, credits={})", pdata.body.len(), avail, ); return Err(e); } if ws_in_tracker.start(websocket::OPCODE_PONG).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(pdata.body.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::ResponsePacket::HandoffStart => { drop(zresp); // if handoff requested, flush what we can before accepting // so that the data is not delayed while we wait loop { if send_content.is_none() { if let Some((mtype, avail, done)) = ws_in_tracker.current() { if !handler.is_sending_message() { handler.send_message_start(mtype, None); } if avail > 0 || done { send_content.set(Some( handler.send_message_content(avail, done, bytes_read), )); } } } if let Some(fut) = send_content.as_mut().as_pin_mut() { // ABR: poll_async doesn't block let ret = match poll_async(fut).await { Poll::Ready(ret) => ret, Poll::Pending => break, }; send_content.set(None); let (size, done) = ret?; ws_in_tracker.consumed(size, done); if handler.state() == websocket::State::Connected || handler.state() == websocket::State::PeerClosed { out_credits += size as u32; } } else { break; } } // ABR: function contains read accept_handoff(zsess_in, zsess_out).await?; } _ => { // ABR: handle_other handle_other(zresp, zsess_in, zsess_out).await?; } } } } } Ok(()) } #[allow(clippy::too_many_arguments)] async fn server_stream_websocket( log_id: &str, stream: RefCell<&mut S>, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, blocks_max: usize, blocks_avail: &mut CounterDec<'_>, messages_max: usize, tmp_buf: &RefCell>, bytes_read: &R1, deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>, zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpServerStreamSessionOut<'_>, ) -> Result<(), Error> where S: AsyncRead + AsyncWrite, R1: Fn(), R2: Fn(), { let deflate_config = match deflate_config { Some((config, enc_buf_size)) => { let ebuf = VecRingBuffer::new(enc_buf_size, buf2.get_tmp()); Some((!config.client_no_context_takeover, ebuf)) } None => None, }; let handler = WebSocketHandler::new(io_split(&stream), buf1, buf2, deflate_config); let mut ws_in_tracker = MessageTracker::new(messages_max); let mut out_credits = 0; let mut check_send = pin!(None); let mut add_to_recv_buffer = pin!(None); let mut send_content = pin!(None); loop { let (do_send, do_recv) = match handler.state() { websocket::State::Connected => (true, true), websocket::State::PeerClosed => (true, false), websocket::State::Closing => (false, true), websocket::State::Finished => 
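            // Finished: both sides' close frames have been processed, so leave the
            // loop. The mapping above mirrors stream_websocket: Connected sends
            // and receives, PeerClosed only flushes outgoing frames, Closing only
            // waits for the peer's close.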
break, }; if out_credits > 0 || (do_recv && zsess_in.credits() > 0 && add_to_recv_buffer.is_none()) && check_send.is_none() { check_send.set(Some(zsess_out.check_send())); } if do_send && send_content.is_none() { if let Some((mtype, avail, done)) = ws_in_tracker.current() { if !handler.is_sending_message() { handler.send_message_start(mtype, Some(gen_mask())); } if avail > 0 || done { send_content.set(Some(handler.send_message_content(avail, done, bytes_read))); } } } // ABR: select contains read let ret = select_4( select_option(check_send.as_mut().as_pin_mut()), select_option(add_to_recv_buffer.as_mut().as_pin_mut()), select_option(send_content.as_mut().as_pin_mut()), pin!(zsess_in.recv_msg()), ) .await; match ret { Select4::R1(()) => { check_send.set(None); let _defer = Defer::new(|| zsess_out.cancel_send()); if out_credits > 0 { let zresp = zhttppacket::Response::new_credit(b"", &[], out_credits); out_credits = 0; // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; continue; } assert!(zsess_in.credits() > 0); assert!(add_to_recv_buffer.is_none()); let tmp_buf = &mut *tmp_buf.borrow_mut(); let max_read = cmp::min(tmp_buf.len(), zsess_in.credits() as usize); let (opcode, size, end) = match handler.try_recv_message_content(&mut tmp_buf[..max_read]) { Some(ret) => ret?, None => { add_to_recv_buffer.set(Some(handler.add_to_recv_buffer())); continue; } }; bytes_read(); let body = &tmp_buf[..size]; let zresp = match opcode { websocket::OPCODE_TEXT | websocket::OPCODE_BINARY => { if body.is_empty() && !end { // don't bother sending empty message continue; } let mut data = zhttppacket::ResponseData::new(); data.body = body; data.content_type = if opcode == websocket::OPCODE_TEXT { Some(zhttppacket::ContentType::Text) } else { Some(zhttppacket::ContentType::Binary) }; data.more = !end; zhttppacket::Response::new_data(b"", &[], data) } websocket::OPCODE_CLOSE => { let status = if body.len() >= 2 { let mut arr = [0; 2]; arr[..].copy_from_slice(&body[..2]); let code = u16::from_be_bytes(arr); let reason = match str::from_utf8(&body[2..]) { Ok(reason) => reason, Err(e) => return Err(e.into()), }; Some((code, reason)) } else { None }; zhttppacket::Response::new_close(b"", &[], status) } websocket::OPCODE_PING => zhttppacket::Response::new_ping(b"", &[], body), websocket::OPCODE_PONG => zhttppacket::Response::new_pong(b"", &[], body), opcode => { debug!( "client-conn {}: unsupported websocket opcode: {}", log_id, opcode ); return Err(Error::BadFrame); } }; zsess_in.subtract_credits(size as u32); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; } Select4::R2(ret) => { ret?; add_to_recv_buffer.set(None); } Select4::R3(ret) => { send_content.set(None); let (size, done) = ret?; ws_in_tracker.consumed(size, done); if handler.state() == websocket::State::Connected || handler.state() == websocket::State::PeerClosed { out_credits += size as u32; } } Select4::R4(ret) => { let zreq = ret?; match &zreq.get().get().ptype { zhttppacket::RequestPacket::Data(rdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(rdata.body) { warn!( "received too much data from handler (size={}, credits={})", rdata.body.len(), avail, ); return Err(e); } out_credits += handler.expand_write_buffer(blocks_max, blocks_avail) as u32; let opcode = match &rdata.content_type { Some(zhttppacket::ContentType::Binary) => websocket::OPCODE_BINARY, _ => 
websocket::OPCODE_TEXT, }; if !ws_in_tracker.in_progress() { if ws_in_tracker.start(opcode).is_err() { return Err(Error::BufferExceeded); } } ws_in_tracker.extend(rdata.body.len()); if !rdata.more { ws_in_tracker.done(); } } _ => {} }, zhttppacket::RequestPacket::Close(cdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let (code, reason) = cdata.status.unwrap_or((1000, "")); let arr: [u8; 2] = code.to_be_bytes(); // close content isn't limited by credits. if we // don't have space for it, just error out handler.accept_body(&arr)?; handler.accept_body(reason.as_bytes())?; if ws_in_tracker.start(websocket::OPCODE_CLOSE).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(arr.len() + reason.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::RequestPacket::Ping(pdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(pdata.body) { warn!( "received too much data from handler (size={}, credits={})", pdata.body.len(), avail, ); return Err(e); } if ws_in_tracker.start(websocket::OPCODE_PING).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(pdata.body.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::RequestPacket::Pong(pdata) => match handler.state() { websocket::State::Connected | websocket::State::PeerClosed => { let avail = handler.accept_avail(); if let Err(e) = handler.accept_body(pdata.body) { warn!( "received too much data from handler (size={}, credits={})", pdata.body.len(), avail, ); return Err(e); } if ws_in_tracker.start(websocket::OPCODE_PONG).is_err() { return Err(Error::BadFrame); } ws_in_tracker.extend(pdata.body.len()); ws_in_tracker.done(); } _ => {} }, zhttppacket::RequestPacket::HandoffStart => { drop(zreq); // if handoff requested, flush what we can before accepting // so that the data is not delayed while we wait loop { if send_content.is_none() { if let Some((mtype, avail, done)) = ws_in_tracker.current() { if !handler.is_sending_message() { handler.send_message_start(mtype, Some(gen_mask())); } if avail > 0 || done { send_content.set(Some( handler.send_message_content(avail, done, bytes_read), )); } } } if let Some(fut) = send_content.as_mut().as_pin_mut() { // ABR: poll_async doesn't block let ret = match poll_async(fut).await { Poll::Ready(ret) => ret, Poll::Pending => break, }; send_content.set(None); let (size, done) = ret?; ws_in_tracker.consumed(size, done); if handler.state() == websocket::State::Connected || handler.state() == websocket::State::PeerClosed { out_credits += size as u32; } } else { break; } } // ABR: function contains read server_accept_handoff(zsess_in, zsess_out).await?; } _ => { // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } } } Ok(()) } struct WsReqData { accept: ArrayString, deflate_config: Option<(websocket::PerMessageDeflateConfig, usize)>, } #[allow(clippy::too_many_arguments)] fn server_stream_process_req_header( id: &str, req: &http1::Request<'_, '_>, peer_addr: Option<&SocketAddr>, secure: bool, allow_compression: bool, packet_buf: &RefCell>, instance_id: &str, shared: &StreamSharedData, recv_buf_size: usize, ) -> Result<(zmq::Message, Option), Error> { let mut websocket = false; let mut ws_version = None; let mut ws_key = None; let mut ws_deflate_config = None; for h in req.headers.iter() { if h.name.eq_ignore_ascii_case("Upgrade") && h.value == b"websocket" { websocket = true; } if 
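        // Sec-WebSocket-Version must be "13" and a Sec-WebSocket-Key must be
        // present for validate_ws_request to succeed; the first workable
        // permessage-deflate offer (when compression is allowed) is answered with
        // a response config whose encode buffer is sized to recv_buf_size / 4.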
h.name.eq_ignore_ascii_case("Sec-WebSocket-Version") { ws_version = Some(h.value); } if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = Some(h.value); } if h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") { for value in http1::parse_header_value(h.value) { let (name, params) = match value { Ok(v) => v, Err(_) => return Err(Error::InvalidWebSocketRequest), }; match name { "permessage-deflate" => { // the client can present multiple offers. take // the first that works. if none work, it's not // an error. we'll just not use compression if allow_compression && ws_deflate_config.is_none() { if let Ok(config) = websocket::PerMessageDeflateConfig::from_params(params) { if let Ok(resp_config) = config.create_response() { // set the encoded buffer to be 25% the size of the // recv buffer let enc_buf_size = recv_buf_size / 4; ws_deflate_config = Some((resp_config, enc_buf_size)); } } } } name => { debug!("ignoring unsupported websocket extension: {}", name); continue; } } } } } // log request let host = get_host(req.headers); let scheme = if websocket { if secure { "wss" } else { "ws" } } else { if secure { "https" } else { "http" } }; debug!( "server-conn {}: request: {} {}://{}{}", id, req.method, scheme, host, req.uri ); let ws_req_data: Option = if websocket { let accept = match validate_ws_request(req, ws_version, ws_key) { Ok(s) => s, Err(_) => return Err(Error::InvalidWebSocketRequest), }; Some(WsReqData { accept, deflate_config: ws_deflate_config, }) } else { None }; let ids = [zhttppacket::Id { id: id.as_bytes(), seq: Some(shared.out_seq()), }]; let (mode, more) = if websocket { (Mode::WebSocket, false) } else { let more = match req.body_size { http1::BodySize::NoBody => false, http1::BodySize::Known(x) => x > 0, http1::BodySize::Unknown => true, }; (Mode::HttpStream, more) }; let msg = make_zhttp_request( instance_id, &ids, req.method, req.uri, req.headers, b"", more, mode, recv_buf_size as u32, peer_addr, secure, &mut packet_buf.borrow_mut(), )?; shared.inc_out_seq(); Ok((msg, ws_req_data)) } // read request header and prepare outgoing zmq message. // return Ok(None) if client disconnects before providing a complete request header #[allow(clippy::too_many_arguments)] async fn server_stream_read_header<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( id: &str, req_header: server::RequestHeader<'a, 'b, R, W>, peer_addr: Option<&SocketAddr>, secure: bool, allow_compression: bool, packet_buf: &RefCell>, instance_id: &str, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, shared: &StreamSharedData, recv_buf_size: usize, ) -> Result< Option<( zmq::Message, http1::BodySize, Option, server::RequestBody<'a, 'b, R, W>, )>, Error, > { let mut scratch = http1::ParseScratch::::new(); // receive request header // WARNING: the returned req_header must not be dropped and instead must // be consumed by discard_header(). 
be careful with early returns from // this function and do not use the ?-operator let (req_header, req_body) = { // ABR: discard_while match discard_while(zreceiver, pin!(req_header.recv(&mut scratch))).await { Ok(ret) => ret, Err(Error::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof => return Ok(None), Err(e) => return Err(e), } }; let req_ref = req_header.get(); let result = server_stream_process_req_header( id, &req_ref, peer_addr, secure, allow_compression, packet_buf, instance_id, shared, recv_buf_size, ); let body_size = req_ref.body_size; // whether success or fail, toss req_header so we are able to respond let req_body = req_body.discard_header(req_header); // NOTE: req_header is now consumed and we don't need to worry about it from here let (msg, ws_req_data) = result?; Ok(Some((msg, body_size, ws_req_data, req_body))) } struct StreamRespondProceed<'buf, 'st, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { header: server::ResponseHeader<'buf, 'st, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, zsess_in: ZhttpStreamSessionIn<'zs, 'tr, R2>, ws_config: Option>, } struct StreamRespondWebSocketRejected<'buf, 'st, R: AsyncRead, W: AsyncWrite> { header: server::ResponseHeader<'buf, 'st, R, W>, prepare_body: server::ResponsePrepareBody<'buf, 'st, R, W>, } enum StreamRespond<'buf, 'st, 'zs, 'tr, R: AsyncRead, W: AsyncWrite, R2> { Proceed(StreamRespondProceed<'buf, 'st, 'zs, 'tr, R, W, R2>), WebSocketRejected(StreamRespondWebSocketRejected<'buf, 'st, R, W>), } // consumes resp if successful #[allow(clippy::too_many_arguments)] async fn server_stream_respond<'buf, 'st, 'zs, 'tr, R, W, R1, R2>( id: &'zs str, req: server::Request, resp: &mut Option>, resp_state: &'st mut server::ResponseState<'buf, R, W>, peer_addr: Option<&SocketAddr>, secure: bool, send_buf_size: usize, recv_buf_size: usize, allow_compression: bool, packet_buf: &RefCell>, tmp_buf: &RefCell>, instance_id: &str, zsender: &AsyncLocalSender, zsess_out: &ZhttpStreamSessionOut<'_>, zreceiver: &'zs TrackedAsyncLocalReceiver<'tr, (arena::Rc, usize)>, shared: &'zs StreamSharedData, refresh_stream_timeout: &R1, refresh_session_timeout: &'zs R2, ) -> Result>, Error> where R: AsyncRead, W: AsyncWrite, R1: Fn(), R2: Fn(), { let req_header = req.recv_header(resp.as_mut().unwrap()); // receive request header let result = server_stream_read_header( id, req_header, peer_addr, secure, allow_compression, packet_buf, instance_id, zreceiver, shared, recv_buf_size, ) .await?; let (msg, body_size, ws_req_data, req_body) = match result { Some(ret) => ret, None => return Ok(None), }; refresh_stream_timeout(); // send request message // ABR: discard_while discard_while(zreceiver, pin!(send_msg(zsender, msg))).await?; let mut zsess_in = ZhttpStreamSessionIn::new( id, send_buf_size, ws_req_data.is_some(), zreceiver, shared, refresh_session_timeout, ); // receive any message, in order to get a handler address // ABR: direct read zsess_in.peek_msg().await?; if body_size != http1::BodySize::NoBody { // receive request body and send to handler // ABR: function contains read stream_recv_body( tmp_buf, refresh_stream_timeout, req_body, &mut zsess_in, zsess_out, ) .await?; } // receive response message let zresp = loop { let mut resp_take = resp.take().unwrap(); // ABR: select contains read let ret = select_2( pin!(zsess_in.recv_msg()), pin!(resp_take.fill_recv_buffer()), ) .await; *resp = Some(resp_take); match ret { Select2::R1(ret) => { let zresp = ret?; match zresp.get().get().ptype { zhttppacket::ResponsePacket::Data(_) | 
zhttppacket::ResponsePacket::Error(_) => break zresp, _ => { // ABR: handle_other handle_other(zresp, &mut zsess_in, zsess_out).await?; } } } Select2::R2(e) => return Err(e.into()), } }; // determine how to respond let rdata = match &zresp.get().get().ptype { zhttppacket::ResponsePacket::Data(rdata) => rdata, zhttppacket::ResponsePacket::Error(edata) => { if ws_req_data.is_some() && edata.condition == "rejected" { // send websocket rejection let rdata = edata.rejected_info.as_ref().unwrap(); if rdata.body.len() > recv_buf_size { return Err(Error::WebSocketRejectionTooLarge(recv_buf_size)); } let (header, mut prepare_body) = { let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX]; let mut headers_len = 0; for h in rdata.headers.iter() { // don't send these headers if h.name.eq_ignore_ascii_case("Upgrade") || h.name.eq_ignore_ascii_case("Connection") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") { continue; } if headers_len >= headers.len() { return Err(Error::BadMessage); } headers[headers_len] = http1::Header { name: h.name, value: h.value, }; headers_len += 1; } let headers = &headers[..headers_len]; let mut resp_take = resp.take().unwrap(); match resp_take.prepare_header( rdata.code, rdata.reason, headers, http1::BodySize::Known(rdata.body.len()), resp_state, ) { Ok(ret) => ret, Err(e) => { *resp = Some(resp_take); return Err(e.into()); } } }; // first call can't fail let (size, overflowed) = prepare_body .prepare(rdata.body, true) .expect("infallible prepare call failed"); if overflowed > 0 { debug!("server-conn {}: overflowing {} bytes", id, overflowed); } // we confirmed above that the data will fit in the buffer assert!(size == rdata.body.len()); return Ok(Some(StreamRespond::WebSocketRejected( StreamRespondWebSocketRejected { header, prepare_body, }, ))); } else { // ABR: handle_other return Err(handle_other(zresp, &mut zsess_in, zsess_out) .await .unwrap_err()); } } _ => unreachable!(), // we confirmed the type above }; if rdata.body.len() > recv_buf_size { return Err(Error::BufferExceeded); } // send response header let (header, mut prepare_body) = { let mut headers = [http1::EMPTY_HEADER; HEADERS_MAX]; let mut headers_len = 0; let mut body_size = http1::BodySize::Unknown; for h in rdata.headers.iter() { if ws_req_data.is_some() { // don't send these headers if h.name.eq_ignore_ascii_case("Upgrade") || h.name.eq_ignore_ascii_case("Connection") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") { continue; } } else { if h.name.eq_ignore_ascii_case("Content-Length") { let s = str::from_utf8(h.value)?; let clen: usize = match s.parse() { Ok(clen) => clen, Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidInput).into()), }; body_size = http1::BodySize::Known(clen); } } if headers_len >= headers.len() { return Err(Error::BadMessage); } headers[headers_len] = http1::Header { name: h.name, value: h.value, }; headers_len += 1; } if body_size == http1::BodySize::Unknown && !rdata.more { body_size = http1::BodySize::Known(rdata.body.len()); } let mut ws_ext = ArrayVec::::new(); if let Some(ws_req_data) = &ws_req_data { let accept_data = &ws_req_data.accept; if headers_len + 4 > headers.len() { return Err(Error::BadMessage); } headers[headers_len] = http1::Header { name: "Upgrade", value: b"websocket", }; headers_len += 1; headers[headers_len] = http1::Header { name: "Connection", value: b"Upgrade", }; headers_len += 1; headers[headers_len] = 
http1::Header { name: "Sec-WebSocket-Accept", value: accept_data.as_bytes(), }; headers_len += 1; if let Some((config, _)) = &ws_req_data.deflate_config { if write_ws_ext_header_value(config, &mut ws_ext).is_err() { return Err(Error::Compression); } headers[headers_len] = http1::Header { name: "Sec-WebSocket-Extensions", value: ws_ext.as_ref(), }; headers_len += 1; } } let headers = &headers[..headers_len]; let mut resp_take = resp.take().unwrap(); match resp_take.prepare_header(rdata.code, rdata.reason, headers, body_size, resp_state) { Ok(ret) => ret, Err(e) => { *resp = Some(resp_take); return Err(e.into()); } } }; // first call can't fail let (size, overflowed) = prepare_body .prepare(rdata.body, !rdata.more) .expect("infallible prepare call failed"); if overflowed > 0 { debug!("server-conn {}: overflowing {} bytes", id, overflowed); } // we confirmed above that the data will fit in the buffer assert!(size == rdata.body.len()); let ws_config = if let Some(ws_req_data) = ws_req_data { Some(ws_req_data.deflate_config) } else { None }; Ok(Some(StreamRespond::Proceed(StreamRespondProceed { header, prepare_body, ws_config, zsess_in, }))) } // return true if persistent #[allow(clippy::too_many_arguments)] async fn server_stream_handler( id: &str, stream: &mut S, peer_addr: Option<&SocketAddr>, secure: bool, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, blocks_max: usize, blocks_avail: &mut CounterDec<'_>, messages_max: usize, allow_compression: bool, packet_buf: &RefCell>, tmp_buf: &RefCell>, instance_id: &str, zsender: &AsyncLocalSender, zsender_stream: &AsyncLocalSender<(ArrayVec, zmq::Message)>, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, shared: &StreamSharedData, refresh_stream_timeout: &R1, refresh_session_timeout: &R2, ) -> Result where S: AsyncRead + AsyncWrite, R1: Fn(), R2: Fn(), { let stream = RefCell::new(stream); let send_buf_size = buf1.capacity(); // for sending to handler let recv_buf_size = buf2.capacity(); // for receiving from handler let zsess_out = ZhttpStreamSessionOut::new(instance_id, id, packet_buf, zsender_stream, shared); let mut resp_state = server::ResponseState::default(); let respond = { let (req, resp) = server::Request::new(io_split(&stream), buf1, buf2); let mut resp = Some(resp); let ret = match server_stream_respond( id, req, &mut resp, &mut resp_state, peer_addr, secure, send_buf_size, recv_buf_size, allow_compression, packet_buf, tmp_buf, instance_id, zsender, &zsess_out, zreceiver, shared, refresh_stream_timeout, refresh_session_timeout, ) .await { Ok(Some(ret)) => ret, Ok(None) => return Ok(false), // no request Err(e) => { // on error, resp is not consumed, so we can use it send_error_response(resp.take().unwrap(), zreceiver, &e).await?; return Err(e); } }; assert!(resp.is_none()); ret }; let (header, mut prepare_body, ws_config, mut zsess_in) = match respond { StreamRespond::Proceed(p) => (p.header, p.prepare_body, p.ws_config, p.zsess_in), StreamRespond::WebSocketRejected(r) => { // ABR: discard_while let header_sent = discard_while(zreceiver, pin!(r.header.send())).await?; let resp_body = header_sent.start_body(r.prepare_body); loop { // send the buffer let send = async { match resp_body.send().await { SendStatus::Complete(finished) => Ok(Some(finished)), SendStatus::EarlyResponse(_) => unreachable!(), // for requests only SendStatus::Partial((), _) => Ok(None), SendStatus::Error((), e) => Err(e), } }; // ABR: discard_while if let Some(_finished) = discard_while(zreceiver, pin!(send)).await? 
{ break; } } return Ok(false); } }; let header_sent = { let mut send = pin!(header.send()); loop { // ABR: select contains read let ret = select_2(send.as_mut(), pin!(zsess_in.recv_msg())).await; match ret { Select2::R1(ret) => break ret?, Select2::R2(ret) => { let zresp = ret?; match &zresp.get().get().ptype { zhttppacket::ResponsePacket::Data(rdata) => { let (size, overflowed) = prepare_body.prepare(rdata.body, !rdata.more)?; if overflowed > 0 { debug!("server-conn {}: overflowing {} bytes", id, overflowed); } if size < rdata.body.len() { return Err(Error::BufferExceeded); } } _ => { // ABR: handle_other handle_other(zresp, &mut zsess_in, &zsess_out).await?; } } } } } }; let resp_body = header_sent.start_body(prepare_body); refresh_stream_timeout(); if let Some(deflate_config) = ws_config { // reduce size of future #[allow(clippy::drop_non_drop)] drop(resp_body); // handle as websocket connection // ABR: function contains read stream_websocket( id, stream, buf1, buf2, blocks_max, blocks_avail, messages_max, tmp_buf, refresh_stream_timeout, deflate_config, &mut zsess_in, &zsess_out, ) .await?; Ok(false) } else { // send response body // ABR: function contains read let finished = stream_send_body( refresh_stream_timeout, resp_body, &mut zsess_in, &zsess_out, blocks_max, blocks_avail, ) .await?; Ok(finished.is_persistent()) } } #[allow(clippy::too_many_arguments)] async fn server_stream_connection_inner( token: CancellationToken, cid: &mut ArrayString<32>, cid_provider: &mut P, mut stream: S, peer_addr: Option<&SocketAddr>, secure: bool, buffer_size: usize, blocks_max: usize, blocks_avail: &Counter, messages_max: usize, rb_tmp: &Rc, packet_buf: Rc>>, tmp_buf: Rc>>, stream_timeout_duration: Duration, allow_compression: bool, instance_id: &str, zsender: AsyncLocalSender, zsender_stream: AsyncLocalSender<(ArrayVec, zmq::Message)>, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, shared: arena::Rc, ) -> Result<(), Error> { let reactor = Reactor::current().unwrap(); let mut buf1 = VecRingBuffer::new(buffer_size, rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, rb_tmp); loop { stream.set_id(cid); // this was originally logged when starting the non-async state // machine, so we'll keep doing that debug!("server-conn {}: assigning id", cid); let reuse = { let stream_timeout = Timeout::new(reactor.now() + stream_timeout_duration); let session_timeout = Timeout::new(reactor.now() + ZHTTP_SESSION_TIMEOUT); let refresh_stream_timeout = || { stream_timeout.set_deadline(reactor.now() + stream_timeout_duration); }; let refresh_session_timeout = || { session_timeout.set_deadline(reactor.now() + ZHTTP_SESSION_TIMEOUT); }; let mut blocks_avail = CounterDec::new(blocks_avail); let handler = pin!(server_stream_handler( cid.as_ref(), &mut stream, peer_addr, secure, &mut buf1, &mut buf2, blocks_max, &mut blocks_avail, messages_max, allow_compression, &packet_buf, &tmp_buf, instance_id, &zsender, &zsender_stream, zreceiver, shared.get(), &refresh_stream_timeout, &refresh_session_timeout, )); let ret = match select_4( handler, stream_timeout.elapsed(), session_timeout.elapsed(), token.cancelled(), ) .await { Select4::R1(ret) => ret, Select4::R2(_) => Err(Error::StreamTimeout), Select4::R3(_) => return Err(Error::SessionTimeout), Select4::R4(_) => return Err(Error::Stopped), }; match ret { Ok(reuse) => reuse, Err(e) => { let handler_caused = matches!( &e, Error::BadMessage | Error::Handler | Error::HandlerCancel ); if !handler_caused { let shared = shared.get(); let msg = if let Some(addr) = 
shared.to_addr().get() { let id = cid.as_ref(); let mut zreq = zhttppacket::Request::new_cancel(b"", &[]); let ids = [zhttppacket::Id { id: id.as_bytes(), seq: Some(shared.out_seq()), }]; zreq.from = instance_id.as_bytes(); zreq.ids = &ids; zreq.multi = true; let packet_buf = &mut *packet_buf.borrow_mut(); let size = zreq.serialize(packet_buf)?; let msg = zmq::Message::from(&packet_buf[..size]); let addr = match ArrayVec::try_from(addr) { Ok(v) => v, Err(_) => { return Err(io::Error::from(io::ErrorKind::InvalidInput).into()) } }; Some((addr, msg)) } else { None }; if let Some((addr, msg)) = msg { // best effort let _ = zsender_stream.try_send((addr, msg)); shared.inc_out_seq(); } } return Err(e); } } }; if !reuse { break; } // note: buf1 is not cleared as there may be data to read buf2.clear(); buf2.resize(buffer_size); shared.get().reset(); *cid = cid_provider.get_new_assigned_cid(); } // ABR: discard_while discard_while(zreceiver, pin!(stream.close())).await?; Ok(()) } #[allow(clippy::too_many_arguments)] pub async fn server_stream_connection( token: CancellationToken, mut cid: ArrayString<32>, cid_provider: &mut P, stream: S, peer_addr: Option<&SocketAddr>, secure: bool, buffer_size: usize, blocks_max: usize, blocks_avail: &Counter, messages_max: usize, rb_tmp: &Rc, packet_buf: Rc>>, tmp_buf: Rc>>, timeout: Duration, allow_compression: bool, instance_id: &str, zsender: AsyncLocalSender, zsender_stream: AsyncLocalSender<(ArrayVec, zmq::Message)>, zreceiver: AsyncLocalReceiver<(arena::Rc, usize)>, shared: arena::Rc, ) { let value_active = TrackFlag::default(); let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active); match track_future( server_stream_connection_inner( token, &mut cid, cid_provider, stream, peer_addr, secure, buffer_size, blocks_max, blocks_avail, messages_max, rb_tmp, packet_buf, tmp_buf, timeout, allow_compression, instance_id, zsender, zsender_stream, &zreceiver, shared, ), &value_active, ) .await { Ok(()) => debug!("server-conn {}: finished", cid), Err(e) => log!(e.log_level(), "server-conn {}: process error: {:?}", cid, e), } } enum Stream { Plain(std::net::TcpStream), Tls(TlsStream), } impl Read for Stream { fn read(&mut self, buf: &mut [u8]) -> Result { match self { Self::Plain(stream) => stream.read(buf), Self::Tls(stream) => stream.read(buf), } } } enum AsyncStream<'a> { Plain(AsyncTcpStream), Tls(AsyncTlsStream<'a>), } impl AsyncStream<'_> { fn into_inner(self) -> Stream { match self { Self::Plain(stream) => Stream::Plain(stream.into_std()), Self::Tls(stream) => Stream::Tls(stream.into_std()), } } } #[derive(Clone, Eq, Hash, PartialEq)] struct ConnectionPoolKey { addr: std::net::SocketAddr, tls: bool, host: String, } impl ConnectionPoolKey { fn new(addr: std::net::SocketAddr, tls: bool, host: String) -> Self { Self { addr, tls, host } } } pub struct ConnectionPool { inner: Arc>>, thread: Option>, done: Option>, } impl ConnectionPool { pub fn new(capacity: usize) -> Self { let inner = Arc::new(Mutex::new(Pool::::new(capacity))); let (s, r) = mpsc::sync_channel(1); let thread = { let inner = Arc::clone(&inner); thread::Builder::new() .name("connection-pool".into()) .spawn(move || { while let Err(mpsc::RecvTimeoutError::Timeout) = r.recv_timeout(Duration::from_secs(1)) { let now = Instant::now(); while let Some((key, _)) = inner.lock().unwrap().expire(now) { debug!("closing idle connection to {:?} for {}", key.addr, key.host); } } }) .unwrap() }; Self { inner, thread: Some(thread), done: Some(s), } } #[allow(clippy::result_large_err)] fn push( &self, 
addr: std::net::SocketAddr, tls: bool, host: String, stream: Stream, ttl: Duration, ) -> Result<(), Stream> { self.inner.lock().unwrap().add( ConnectionPoolKey::new(addr, tls, host), stream, Instant::now() + ttl, ) } fn take(&self, addr: std::net::SocketAddr, tls: bool, host: &str) -> Option { let key = ConnectionPoolKey::new(addr, tls, host.to_string()); // take the first connection that returns WouldBlock when attempting a read. // anything else is considered an error and the connection is discarded while let Some(mut stream) = self.inner.lock().unwrap().take(&key) { match stream.read(&mut [0]) { Err(e) if e.kind() == io::ErrorKind::WouldBlock => return Some(stream), _ => {} } debug!( "discarding broken connection to {:?} for {}", key.addr, key.host ); } None } } impl Drop for ConnectionPool { fn drop(&mut self) { self.done = None; let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } fn is_allowed(addr: &IpAddr, deny: &[IpNet]) -> bool { for net in deny { if net.contains(addr) { return false; } } true } #[allow(clippy::too_many_arguments)] async fn client_connect<'a>( log_id: &str, rdata: &zhttppacket::RequestData<'_, '_>, uri: &url::Url, resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, deny: &[IpNet], pool: &ConnectionPool, tls_waker_data: &'a RefWakerData, ) -> Result<(std::net::SocketAddr, bool, AsyncStream<'a>), Error> { let use_tls = ["https", "wss"].contains(&uri.scheme()); let uri_host = match uri.host_str() { Some(s) => s, None => return Err(Error::BadRequest), }; let default_port = if use_tls { 443 } else { 80 }; let (connect_host, connect_port) = if !rdata.connect_host.is_empty() { (rdata.connect_host, rdata.connect_port) } else { (uri_host, uri.port().unwrap_or(default_port)) }; let resolver = resolver::AsyncResolver::new(resolver); debug!("client-conn {}: resolving: [{}]", log_id, connect_host); let resolver_results = resolver.resolve(connect_host).await?; let mut addrs = ArrayVec::::new(); let mut denied = false; let mut reuse_stream = None; for addr in resolver_results { if !is_allowed(&addr, deny) { denied = true; continue; } let addr = std::net::SocketAddr::new(addr, connect_port); if let Some(stream) = pool.take(addr, use_tls, uri_host) { reuse_stream = Some((addr, stream)); break; } addrs.push(addr); } let (peer_addr, mut stream, is_new) = if let Some((peer_addr, stream)) = reuse_stream { debug!( "client-conn {}: reusing connection to {:?}", log_id, peer_addr, ); let stream = match stream { Stream::Plain(stream) => AsyncStream::Plain(AsyncTcpStream::from_std(stream)), Stream::Tls(stream) => { AsyncStream::Tls(AsyncTlsStream::from_std(stream, tls_waker_data)) } }; (peer_addr, stream, false) } else { if addrs.is_empty() && denied { return Err(Error::PolicyViolation); } debug!("client-conn {}: connecting to one of {:?}", log_id, addrs); let stream = AsyncTcpStream::connect(&addrs).await?; let peer_addr = stream.peer_addr()?; debug!("client-conn {}: connected to {}", log_id, peer_addr); let stream = if use_tls { let host = if rdata.trust_connect_host { connect_host } else { uri_host }; let verify_mode = if rdata.ignore_tls_errors { VerifyMode::None } else { VerifyMode::Full }; let stream = match AsyncTlsStream::connect( host, stream, verify_mode, tls_waker_data, tls_config_cache, ) { Ok(stream) => stream, Err(e) => { debug!("client-conn {}: tls connect error: {}", log_id, e); return Err(Error::Tls); } }; AsyncStream::Tls(stream) } else { AsyncStream::Plain(stream) }; (peer_addr, stream, true) }; if let AsyncStream::Tls(stream) = &mut stream 
{ if stream.inner().set_id(log_id).is_err() { warn!("client-conn {}: log id too long for TlsStream", log_id); return Err(Error::BadRequest); } if is_new { if let Err(e) = stream.ensure_handshake().await { debug!("client-conn {}: tls handshake error: {:?}", log_id, e); return Err(Error::Tls); } } } Ok((peer_addr, use_tls, stream)) } // return Some if fully valid redirect response, else return None. fn check_redirect( method: &str, base_url: &url::Url, resp: &http1::Response, schemes: &[&str], ) -> Option<(url::Url, bool)> { if resp.code >= 300 && resp.code <= 399 { let mut location = None; for h in resp.headers.iter() { if h.name.eq_ignore_ascii_case("Location") { location = Some(h.value); break; } } // must have location header if let Some(s) = location { // must be UTF-8 if let Ok(s) = str::from_utf8(s) { // must be valid URL if let Ok(url) = base_url.join(s) { // must have an acceptable scheme if schemes.contains(&url.scheme()) { let use_get = resp.code >= 301 && resp.code <= 303 && method == "POST"; // all is well! return Some((url, use_get)); } } } } } None } enum ClientHandlerDone { Complete(T, bool), Redirect(bool, url::Url, bool), // rare alloc } impl ClientHandlerDone { fn is_persistent(&self) -> bool { match self { ClientHandlerDone::Complete(_, persistent) => *persistent, ClientHandlerDone::Redirect(persistent, _, _) => *persistent, } } } // return (_, true) if persistent #[allow(clippy::too_many_arguments)] async fn client_req_handler( log_id: &str, id: Option<&[u8]>, stream: &mut S, zreq: &zhttppacket::Request<'_, '_, '_>, method: &str, url: &url::Url, include_body: bool, follow_redirects: bool, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, ) -> Result, Error> where S: AsyncRead + AsyncWrite, { let stream = RefCell::new(stream); let req = client::Request::new(io_split(&stream), buf1, buf2); let req_header = { let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(data) => data, _ => return Err(Error::BadRequest), }; let host_port = &url[url::Position::BeforeHost..url::Position::AfterPort]; let mut headers = ArrayVec::::new(); headers.push(http1::Header { name: "Host", value: host_port.as_bytes(), }); for h in rdata.headers.iter() { if headers.remaining_capacity() == 0 { return Err(Error::BadRequest); } // host comes from the uri if h.name.eq_ignore_ascii_case("Host") { continue; } headers.push(http1::Header { name: h.name, value: h.value, }); } let path = &url[url::Position::BeforePath..]; let body_size = if include_body { body_buf.write_all(rdata.body)?; http1::BodySize::Known(rdata.body.len()) } else { http1::BodySize::NoBody }; req.prepare_header(method, path, &headers, body_size, false, &[], false)? 
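// Illustrative sketch (exact serialization is up to http1::Request; the URL and
// method here are made-up examples): for method "GET" and url
// "http://example.com:8080/path?q=1", the header prepared above would come out
// roughly as
//
//     GET /path?q=1 HTTP/1.1
//     Host: example.com:8080
//     ...remaining headers copied from rdata...
//
// i.e. the request target is the path+query slice of the url, and the Host
// header always comes from the url's host[:port] slice (any client-supplied
// Host header is skipped above).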
}; let resp = { // send request header let req_body = req_header.send().await?; // send request body loop { // fill the buffer as much as possible let size = req_body.prepare(Buffer::read_buf(body_buf), true)?; body_buf.read_commit(size); // send the buffer match req_body.send().await { SendStatus::Complete(resp) => break resp, SendStatus::EarlyResponse(resp) => { body_buf.clear(); break resp; } SendStatus::Partial((), _) => {} SendStatus::Error((), e) => return Err(e.into()), } } }; assert_eq!(body_buf.len(), 0); // receive response header let mut scratch = http1::ParseScratch::::new(); let (resp, resp_body) = resp.recv_header(&mut scratch).await?; let (zresp, finished) = { let resp_ref = resp.get(); debug!( "client-conn {}: response: {} {}", log_id, resp_ref.code, resp_ref.reason ); // receive response body let finished = { loop { match resp_body.try_recv(body_buf.write_buf())? { RecvStatus::Complete(finished, size) => { body_buf.write_commit(size); break finished; } RecvStatus::Read((), size) => { body_buf.write_commit(size); if size == 0 { return Err(Error::BufferExceeded); } } RecvStatus::NeedBytes(()) => resp_body.add_to_buffer().await?, } } }; if follow_redirects { if let Some((url, use_get)) = check_redirect(method, url, &resp_ref, &["http", "https"]) { let finished = finished.discard_header(resp); debug!("client-conn {}: redirecting to {}", log_id, url); return Ok(ClientHandlerDone::Redirect( finished.is_persistent(), url, use_get, )); } } let mut zheaders = ArrayVec::::new(); for h in resp_ref.headers { zheaders.push(zhttppacket::Header { name: h.name, value: h.value, }); } let rdata = zhttppacket::ResponseData { credits: 0, more: false, code: resp_ref.code, reason: resp_ref.reason, headers: &zheaders, content_type: None, body: Buffer::read_buf(body_buf), }; let zresp = make_zhttp_req_response( id, zhttppacket::ResponsePacket::Data(rdata), &mut packet_buf.borrow_mut(), )?; (zresp, finished) }; let finished = finished.discard_header(resp); Ok(ClientHandlerDone::Complete(zresp, finished.is_persistent())) } #[allow(clippy::too_many_arguments)] async fn client_req_connect( log_id: &str, id: Option<&[u8]>, zreq: arena::Rc, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, body_buf: &mut ContiguousBuffer, packet_buf: &RefCell>, deny: &[IpNet], resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, ) -> Result { let zreq = zreq.get().get(); let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(data) => data, _ => return Err(Error::BadRequest), }; let initial_url = match url::Url::parse(rdata.uri) { Ok(url) => url, Err(_) => return Err(Error::BadRequest), }; // must be an http url if !["http", "https"].contains(&initial_url.scheme()) { return Err(Error::BadRequest); } // must have a method if rdata.method.is_empty() { return Err(Error::BadRequest); } debug!( "client-conn {}: request: {} {}", log_id, rdata.method, rdata.uri, ); let deny = if rdata.ignore_policies { &[] } else { deny }; let mut last_redirect: Option<(url::Url, bool)> = None; let mut redirect_count = 0; let zresp = loop { let (method, url, include_body) = match &last_redirect { Some((url, use_get)) => { let (method, include_body) = if *use_get { ("GET", false) } else { (rdata.method, true) }; (method, url, include_body) } None => (rdata.method, &initial_url, true), }; let url_host = match url.host_str() { Some(s) => s, None => return Err(Error::BadRequest), }; let tls_waker_data = RefWakerData::new(TlsWaker::new()); let (peer_addr, using_tls, mut stream) = client_connect( 
log_id, rdata, url, resolver, tls_config_cache, deny, pool, &tls_waker_data, ) .await?; let done = match &mut stream { AsyncStream::Plain(stream) => { client_req_handler( log_id, id, stream, zreq, method, url, include_body, rdata.follow_redirects, buf1, buf2, body_buf, packet_buf, ) .await? } AsyncStream::Tls(stream) => { client_req_handler( log_id, id, stream, zreq, method, url, include_body, rdata.follow_redirects, buf1, buf2, body_buf, packet_buf, ) .await? } }; if done.is_persistent() { if pool .push( peer_addr, using_tls, url_host.to_string(), stream.into_inner(), CONNECTION_POOL_TTL, ) .is_ok() { debug!("client-conn {}: leaving connection intact", log_id); } } match done { ClientHandlerDone::Complete(zresp, _) => break zresp, ClientHandlerDone::Redirect(_, url, mut use_get) => { if redirect_count >= REDIRECTS_MAX { return Err(Error::TooManyRedirects); } redirect_count += 1; if let Some((_, b)) = &last_redirect { use_get = use_get || *b; } last_redirect = Some((url, use_get)); } } }; Ok(zresp) } #[allow(clippy::too_many_arguments)] async fn client_req_connection_inner( token: CancellationToken, log_id: &str, id: Option<&[u8]>, zreq: (MultipartHeader, arena::Rc), buffer_size: usize, body_buffer_size: usize, rb_tmp: &Rc, packet_buf: Rc>>, timeout: Duration, deny: &[IpNet], resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, zsender: AsyncLocalSender<(MultipartHeader, zmq::Message)>, ) -> Result<(), Error> { let reactor = Reactor::current().unwrap(); let (zheader, zreq) = zreq; let mut buf1 = VecRingBuffer::new(buffer_size, rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, rb_tmp); let mut body_buf = ContiguousBuffer::new(body_buffer_size); let handler = client_req_connect( log_id, id, zreq, &mut buf1, &mut buf2, &mut body_buf, &packet_buf, deny, resolver, tls_config_cache, pool, ); let timeout = Timeout::new(reactor.now() + timeout); let ret = match select_3(pin!(handler), timeout.elapsed(), token.cancelled()).await { Select3::R1(ret) => ret, Select3::R2(_) => Err(Error::StreamTimeout), Select3::R3(_) => return Err(Error::Stopped), }; match ret { Ok(zresp) => zsender.send((zheader, zresp)).await?, Err(e) => { let zresp = make_zhttp_req_response( id, zhttppacket::ResponsePacket::Error(zhttppacket::ResponseErrorData { condition: e.to_condition(), rejected_info: None, }), &mut packet_buf.borrow_mut(), )?; zsender.send((zheader, zresp)).await?; return Err(e); } } Ok(()) } #[allow(clippy::too_many_arguments)] pub async fn client_req_connection( token: CancellationToken, log_id: &str, id: Option<&[u8]>, zreq: (MultipartHeader, arena::Rc), buffer_size: usize, body_buffer_size: usize, rb_tmp: &Rc, packet_buf: Rc>>, timeout: Duration, deny: &[IpNet], resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, zsender: AsyncLocalSender<(MultipartHeader, zmq::Message)>, ) { match client_req_connection_inner( token, log_id, id, zreq, buffer_size, body_buffer_size, rb_tmp, packet_buf, timeout, deny, resolver, tls_config_cache, pool, zsender, ) .await { Ok(()) => debug!("client-conn {}: finished", log_id), Err(e) => log!( e.log_level(), "client-conn {}: process error: {:?}", log_id, e ), } } // return true if persistent #[allow(clippy::too_many_arguments)] async fn client_stream_handler( log_id: &str, stream: &mut S, zreq: &zhttppacket::Request<'_, '_, '_>, method: &str, url: &url::Url, include_body: bool, mut follow_redirects: bool, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, blocks_max: usize, blocks_avail: &mut 
CounterDec<'_>, messages_max: usize, allow_compression: bool, tmp_buf: &RefCell>, zsess_in: &mut ZhttpServerStreamSessionIn<'_, '_, R2>, zsess_out: &ZhttpServerStreamSessionOut<'_>, response_received: &mut bool, refresh_stream_timeout: &R1, ) -> Result, Error> where S: AsyncRead + AsyncWrite, R1: Fn(), R2: Fn(), { let stream = RefCell::new(stream); let send_buf_size = buf1.capacity(); // for sending to handler let recv_buf_size = buf2.capacity(); // for receiving from handler let req = client::Request::new(io_split(&stream), buf1, buf2); let (req_header, ws_key, overflow) = { let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(data) => data, _ => return Err(Error::BadRequest), }; let websocket = ["wss", "ws"].contains(&url.scheme()); let host_port = &url[url::Position::BeforeHost..url::Position::AfterPort]; let ws_key = if websocket { Some(gen_ws_key()) } else { None }; if !websocket && rdata.more { follow_redirects = false; } let mut ws_ext = ArrayVec::::new(); let mut headers = ArrayVec::::new(); headers.push(http1::Header { name: "Host", value: host_port.as_bytes(), }); if let Some(ws_key) = &ws_key { headers.push(http1::Header { name: "Upgrade", value: b"websocket", }); headers.push(http1::Header { name: "Connection", value: b"Upgrade", }); headers.push(http1::Header { name: "Sec-WebSocket-Version", value: b"13", }); headers.push(http1::Header { name: "Sec-WebSocket-Key", value: ws_key.as_bytes(), }); if allow_compression { if write_ws_ext_header_value( &websocket::PerMessageDeflateConfig::default(), &mut ws_ext, ) .is_err() { return Err(Error::Compression); } headers.push(http1::Header { name: "Sec-WebSocket-Extensions", value: ws_ext.as_slice(), }); } } let mut body_size = if websocket || !include_body { http1::BodySize::NoBody } else { http1::BodySize::Unknown }; for h in rdata.headers.iter() { // host comes from the uri if h.name.eq_ignore_ascii_case("Host") { continue; } if websocket { // don't send these headers if h.name.eq_ignore_ascii_case("Connection") || h.name.eq_ignore_ascii_case("Upgrade") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Version") || h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { continue; } } else { if h.name.eq_ignore_ascii_case("Content-Length") { let s = str::from_utf8(h.value)?; let clen: usize = match s.parse() { Ok(clen) => clen, Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidInput).into()), }; body_size = http1::BodySize::Known(clen); } } if headers.remaining_capacity() == 0 { return Err(Error::BadRequest); } headers.push(http1::Header { name: h.name, value: h.value, }); } let method = if websocket { "GET" } else { method }; let path = &url[url::Position::BeforePath..]; if body_size == http1::BodySize::Unknown && !rdata.more { body_size = http1::BodySize::Known(rdata.body.len()); } let mut overflow = None; let req_header = if websocket { req.prepare_header(method, path, &headers, body_size, true, &[], true)? } else { let (initial_body, end) = if include_body { if rdata.body.len() > recv_buf_size { let body = &rdata.body[..recv_buf_size]; let mut remainder = ContiguousBuffer::new(rdata.body.len() - body.len()); remainder.write_all(&rdata.body[body.len()..])?; debug!( "initial={} overflow={} end={}", body.len(), remainder.len(), !rdata.more ); overflow = Some(Overflow { buf: remainder, end: !rdata.more, }); (body, false) } else { (rdata.body, !rdata.more) } } else { (&[][..], true) }; req.prepare_header(method, path, &headers, body_size, false, initial_body, end)? 
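// Note on the overflow path above: when include_body is set and the initial
// rdata.body does not fit in the receive buffer, only the first recv_buf_size
// bytes are attached to the header; the remainder is parked in `overflow` and
// handed to server_stream_send_body() below. For example (sizes are
// illustrative), with a 16 KiB receive buffer and a 20 KiB initial body,
// 16 KiB rides along with the header, 4 KiB is carried in the overflow
// buffer, and end = !rdata.more is deferred until the overflow is flushed.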
}; (req_header, ws_key, overflow) }; // send request header let req_body = { let mut send_header = pin!(req_header.send()); loop { // ABR: select contains read let result = select_2(send_header.as_mut(), pin!(zsess_in.recv_msg())).await; match result { Select2::R1(ret) => break ret?, Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } }; refresh_stream_timeout(); // send request body // ABR: function contains read let resp = server_stream_send_body( refresh_stream_timeout, req_body, overflow, recv_buf_size, zsess_in, zsess_out, blocks_max, blocks_avail, ) .await?; // receive response header let (resp_body, ws_config) = { let mut scratch = http1::ParseScratch::::new(); let mut recv_header = pin!(resp.recv_header(&mut scratch)); let (resp, resp_body) = loop { // ABR: select contains read let result = select_2(recv_header.as_mut(), pin!(zsess_in.recv_msg())).await; match result { Select2::R1(ret) => break ret?, Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } }; let ws_config = { let resp_ref = resp.get(); debug!( "client-conn {}: response: {} {}", log_id, resp_ref.code, resp_ref.reason ); loop { // ABR: select contains read let result = select_2(pin!(zsess_out.check_send()), pin!(zsess_in.recv_msg())).await; match result { Select2::R1(()) => break, Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } if follow_redirects { let schemes = if ws_key.is_some() { ["ws", "wss"] } else { ["http", "https"] }; if let Some((url, use_get)) = check_redirect(method, url, &resp_ref, &schemes) { // eat response body let finished = loop { let ret = { let mut buf = [0; 4_096]; resp_body.try_recv(&mut buf)? 
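// the 4 KiB stack buffer above is purely a drain target: the redirect
// response body is read and discarded so the connection can be left in a
// reusable state before following the Location header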
}; match ret { RecvStatus::Complete(finished, _) => break finished, RecvStatus::Read((), size) => { // buf is non-empty so this can never be zero assert!(size > 0); } RecvStatus::NeedBytes(()) => { let mut add_to_buffer = pin!(resp_body.add_to_buffer()); loop { // ABR: select contains read let result = select_2(add_to_buffer.as_mut(), pin!(zsess_in.recv_msg())) .await; match result { Select2::R1(ret) => { ret?; break; } Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } } } }; let finished = finished.discard_header(resp); debug!("client-conn {}: redirecting to {}", log_id, url); return Ok(ClientHandlerDone::Redirect( finished.is_persistent(), url, use_get, )); } } let mut zheaders = ArrayVec::::new(); let mut ws_accept = None; let mut ws_deflate_config = None; for h in resp_ref.headers { if ws_key.is_some() { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Accept") { ws_accept = Some(h.value); } if h.name.eq_ignore_ascii_case("Sec-WebSocket-Extensions") { for value in http1::parse_header_value(h.value) { let (name, params) = match value { Ok(v) => v, Err(_) => return Err(Error::InvalidWebSocketResponse), }; match name { "permessage-deflate" => { // we must have offered, and server must // provide one response at most if !allow_compression || ws_deflate_config.is_some() { return Err(Error::InvalidWebSocketResponse); } if let Ok(config) = websocket::PerMessageDeflateConfig::from_params(params) { if config.check_response().is_ok() { // set the encoded buffer to be 25% the size of the // recv buffer let enc_buf_size = recv_buf_size / 4; ws_deflate_config = Some((config, enc_buf_size)); } } } name => { debug!("ignoring unsupported websocket extension: {}", name); continue; } } } } } zheaders.push(zhttppacket::Header { name: h.name, value: h.value, }); } if let Some(ws_key) = &ws_key { if resp_ref.code == 101 { if validate_ws_response(ws_key.as_bytes(), ws_accept).is_err() { return Err(Error::InvalidWebSocketResponse); } } else { // websocket request rejected // we need to allocate to collect the response body, // since buf1 holds bytes read from the socket, and // resp is using buf2's inner buffer let mut body_buf = ContiguousBuffer::new(send_buf_size); // receive response body let finished = loop { match resp_body.try_recv(body_buf.write_buf())? 
{ RecvStatus::Complete(finished, size) => { body_buf.write_commit(size); break finished; } RecvStatus::Read((), size) => { body_buf.write_commit(size); if size == 0 { return Err(Error::WebSocketRejectionTooLarge(send_buf_size)); } } RecvStatus::NeedBytes(()) => { let mut add_to_buffer = pin!(resp_body.add_to_buffer()); loop { // ABR: select contains read let result = select_2(add_to_buffer.as_mut(), pin!(zsess_in.recv_msg())) .await; match result { Select2::R1(ret) => { ret?; break; } Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, zsess_in, zsess_out).await?; } } } } } }; let edata = zhttppacket::ResponseErrorData { condition: "rejected", rejected_info: Some(zhttppacket::RejectedInfo { code: resp_ref.code, reason: resp_ref.reason, headers: &zheaders, body: body_buf.read_buf(), }), }; let zresp = zhttppacket::Response::new_error(b"", &[], edata); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; drop(zheaders); let finished = finished.discard_header(resp); return Ok(ClientHandlerDone::Complete((), finished.is_persistent())); } } let credits = if ws_key.is_some() { // for websockets, provide credits when sending response to handler recv_buf_size as u32 } else { // for http, it is not necessary to provide credits when responding 0 }; let rdata = zhttppacket::ResponseData { credits, more: ws_key.is_none(), code: resp_ref.code, reason: resp_ref.reason, headers: &zheaders, content_type: None, body: b"", }; let zresp = zhttppacket::Response::new_data(b"", &[], rdata); // check_send just finished, so this should succeed zsess_out.try_send_msg(zresp)?; if ws_key.is_some() { Some(ws_deflate_config) } else { None } }; let resp_body = resp_body.discard_header(resp)?; (resp_body, ws_config) }; *response_received = true; if let Some(deflate_config) = ws_config { // handle as websocket connection // ABR: function contains read server_stream_websocket( log_id, stream, buf1, buf2, blocks_max, blocks_avail, messages_max, tmp_buf, refresh_stream_timeout, deflate_config, zsess_in, zsess_out, ) .await?; Ok(ClientHandlerDone::Complete((), false)) } else { // receive response body // ABR: function contains read let finished = server_stream_recv_body( tmp_buf, refresh_stream_timeout, resp_body, zsess_in, zsess_out, ) .await?; Ok(ClientHandlerDone::Complete((), finished.is_persistent())) } } #[allow(clippy::too_many_arguments)] async fn client_stream_connect( log_id: &str, id: &[u8], zreq: arena::Rc, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, buffer_size: usize, blocks_max: usize, blocks_avail: &Counter, messages_max: usize, allow_compression: bool, packet_buf: &RefCell>, tmp_buf: &RefCell>, deny: &[IpNet], instance_id: &str, resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, zsender: &AsyncLocalSender<(Option>, zmq::Message)>, shared: &StreamSharedData, enable_routing: &E, response_received: &mut bool, refresh_stream_timeout: &R1, refresh_session_timeout: &R2, ) -> Result<(), Error> where E: Fn(), R1: Fn(), R2: Fn(), { let zreq = zreq.get().get(); // assign address so we can send replies let addr: ArrayVec = match ArrayVec::try_from(zreq.from) { Ok(v) => v, Err(_) => return Err(Error::BadRequest), }; shared.set_to_addr(Some(addr)); let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(data) => data, _ => return Err(Error::BadRequest), }; let initial_url = match url::Url::parse(rdata.uri) { Ok(url) => url, Err(_) => return 
Err(Error::BadRequest), }; // must be an http or websocket url if !["http", "https", "ws", "wss"].contains(&initial_url.scheme()) { return Err(Error::BadRequest); } // http requests must have a method if ["http", "https"].contains(&initial_url.scheme()) && rdata.method.is_empty() { return Err(Error::BadRequest); } let method = if !rdata.method.is_empty() { rdata.method } else { "_" }; shared.set_router_resp(rdata.router_resp); debug!("client-conn {}: request: {} {}", log_id, method, rdata.uri); let zsess_out = ZhttpServerStreamSessionOut::new(instance_id, id, packet_buf, zsender, shared); // ack request // ABR: discard_while server_discard_while( zreceiver, pin!(async { zsess_out.check_send().await; Ok(()) }), ) .await?; zsess_out.try_send_msg(zhttppacket::Response::new_keep_alive(b"", &[]))?; let mut zsess_in = ZhttpServerStreamSessionIn::new( log_id, id, rdata.credits, zreceiver, shared, refresh_session_timeout, ); // allow receiving subsequent messages enable_routing(); let deny = if rdata.ignore_policies { &[] } else { deny }; let mut last_redirect: Option<(url::Url, bool)> = None; let mut redirect_count = 0; loop { let (method, url, include_body) = match &last_redirect { Some((url, use_get)) => { let (method, include_body) = if *use_get { ("GET", false) } else { (rdata.method, true) }; (method, url, include_body) } None => (rdata.method, &initial_url, true), }; let url_host = match url.host_str() { Some(s) => s, None => return Err(Error::BadRequest), }; let tls_waker_data = RefWakerData::new(TlsWaker::new()); let (peer_addr, using_tls, mut stream) = { let mut client_connect = pin!(client_connect( log_id, rdata, url, resolver, tls_config_cache, deny, pool, &tls_waker_data, )); loop { // ABR: select contains read let ret = select_2(client_connect.as_mut(), pin!(zsess_in.recv_msg())).await; match ret { Select2::R1(ret) => break ret?, Select2::R2(ret) => { let zreq = ret?; // ABR: handle_other server_handle_other(zreq, &mut zsess_in, &zsess_out).await?; } } } }; let mut blocks_avail = CounterDec::new(blocks_avail); let done = match &mut stream { AsyncStream::Plain(stream) => { client_stream_handler( log_id, stream, zreq, method, url, include_body, rdata.follow_redirects, buf1, buf2, blocks_max, &mut blocks_avail, messages_max, allow_compression, tmp_buf, &mut zsess_in, &zsess_out, response_received, refresh_stream_timeout, ) .await? } AsyncStream::Tls(stream) => { client_stream_handler( log_id, stream, zreq, method, url, include_body, rdata.follow_redirects, buf1, buf2, blocks_max, &mut blocks_avail, messages_max, allow_compression, tmp_buf, &mut zsess_in, &zsess_out, response_received, refresh_stream_timeout, ) .await? 
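// (the Plain and Tls arms above are identical apart from the stream value:
// client_stream_handler is generic over AsyncRead + AsyncWrite, and the match
// only exists to hand it each concrete stream type)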
} }; if done.is_persistent() { buf2.resize(buffer_size); if pool .push( peer_addr, using_tls, url_host.to_string(), stream.into_inner(), CONNECTION_POOL_TTL, ) .is_ok() { debug!("client-conn {}: leaving connection intact", log_id); } } match done { ClientHandlerDone::Complete((), _) => break, ClientHandlerDone::Redirect(_, url, mut use_get) => { if redirect_count >= REDIRECTS_MAX { return Err(Error::TooManyRedirects); } redirect_count += 1; if let Some((_, b)) = &last_redirect { use_get = use_get || *b; } last_redirect = Some((url, use_get)); } } } Ok(()) } #[allow(clippy::too_many_arguments)] async fn client_stream_connection_inner( token: CancellationToken, log_id: &str, id: &[u8], zreq: arena::Rc, buffer_size: usize, blocks_max: usize, blocks_avail: &Counter, messages_max: usize, rb_tmp: &Rc, packet_buf: Rc>>, tmp_buf: Rc>>, stream_timeout_duration: Duration, allow_compression: bool, deny: &[IpNet], instance_id: &str, resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, zreceiver: &TrackedAsyncLocalReceiver<'_, (arena::Rc, usize)>, zsender: AsyncLocalSender<(Option>, zmq::Message)>, shared: arena::Rc, enable_routing: &E, ) -> Result<(), Error> where E: Fn(), { let reactor = Reactor::current().unwrap(); let mut buf1 = VecRingBuffer::new(buffer_size, rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, rb_tmp); let stream_timeout = Timeout::new(reactor.now() + stream_timeout_duration); let session_timeout = Timeout::new(reactor.now() + ZHTTP_SESSION_TIMEOUT); let refresh_stream_timeout = || { stream_timeout.set_deadline(reactor.now() + stream_timeout_duration); }; let refresh_session_timeout = || { session_timeout.set_deadline(reactor.now() + ZHTTP_SESSION_TIMEOUT); }; let mut response_received = false; let ret = { let handler = pin!(client_stream_connect( log_id, id, zreq, &mut buf1, &mut buf2, buffer_size, blocks_max, blocks_avail, messages_max, allow_compression, &packet_buf, &tmp_buf, deny, instance_id, resolver, tls_config_cache, pool, zreceiver, &zsender, shared.get(), enable_routing, &mut response_received, &refresh_stream_timeout, &refresh_session_timeout, )); match select_4( handler, stream_timeout.elapsed(), session_timeout.elapsed(), token.cancelled(), ) .await { Select4::R1(ret) => ret, Select4::R2(_) => Err(Error::StreamTimeout), Select4::R3(_) => return Err(Error::SessionTimeout), Select4::R4(_) => return Err(Error::Stopped), } }; match ret { Ok(()) => {} Err(e) => { let handler_caused = matches!( &e, Error::BadMessage | Error::Handler | Error::HandlerCancel ); if !handler_caused { let shared = shared.get(); let resp = if let Some(addr) = shared.to_addr().get() { let mut zresp = if response_received { zhttppacket::Response::new_cancel(b"", &[]) } else { zhttppacket::Response::new_error( b"", &[], zhttppacket::ResponseErrorData { condition: e.to_condition(), rejected_info: None, }, ) }; let ids = [zhttppacket::Id { id, seq: Some(shared.out_seq()), }]; zresp.from = instance_id.as_bytes(); zresp.ids = &ids; zresp.multi = true; let packet_buf = &mut *packet_buf.borrow_mut(); Some(make_zhttp_response( addr, shared.router_resp(), zresp, packet_buf, )?) 
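// the packet built here is serialized up front but only delivered below on a
// best-effort basis; if the handler never advertised an address there is
// nothing to notify and the error is simply propagated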
} else { None }; if let Some((addr, msg)) = resp { // best effort let _ = zsender.try_send((addr, msg)); shared.inc_out_seq(); } } return Err(e); } } Ok(()) } #[allow(clippy::too_many_arguments)] pub async fn client_stream_connection( token: CancellationToken, log_id: &str, id: &[u8], zreq: arena::Rc, buffer_size: usize, blocks_max: usize, blocks_avail: &Counter, messages_max: usize, rb_tmp: &Rc, packet_buf: Rc>>, tmp_buf: Rc>>, timeout: Duration, allow_compression: bool, deny: &[IpNet], instance_id: &str, resolver: &resolver::Resolver, tls_config_cache: &TlsConfigCache, pool: &ConnectionPool, zreceiver: AsyncLocalReceiver<(arena::Rc, usize)>, zsender: AsyncLocalSender<(Option>, zmq::Message)>, shared: arena::Rc, enable_routing: &E, ) where E: Fn(), { let value_active = TrackFlag::default(); let zreceiver = TrackedAsyncLocalReceiver::new(zreceiver, &value_active); match track_future( client_stream_connection_inner( token, log_id, id, zreq, buffer_size, blocks_max, blocks_avail, messages_max, rb_tmp, packet_buf, tmp_buf, timeout, allow_compression, deny, instance_id, resolver, tls_config_cache, pool, &zreceiver, zsender, shared, enable_routing, ), &value_active, ) .await { Ok(()) => debug!("client-conn {}: finished", log_id), Err(e) => log!( e.log_level(), "client-conn {}: process error: {:?}", log_id, e ), } } pub mod testutil { use super::*; use crate::core::buffer::TmpBuffer; use crate::core::channel; use crate::core::waker; use std::fmt; use std::future::Future; use std::io::Read; use std::rc::Rc; use std::sync::Arc; use std::task::{Context, Poll, Waker}; use std::time::Instant; pub struct NoopWaker {} #[allow(clippy::new_without_default)] impl NoopWaker { pub fn new() -> Self { Self {} } pub fn into_std(self: Rc) -> Waker { waker::into_std(self) } } impl waker::RcWake for NoopWaker { fn wake(self: Rc) {} } pub struct StepExecutor<'a, F> { reactor: &'a Reactor, fut: Pin>, } impl<'a, F> StepExecutor<'a, F> where F: Future, { pub fn new(reactor: &'a Reactor, fut: F) -> Self { Self { reactor, fut: Box::pin(fut), } } pub fn step(&mut self) -> Poll { self.reactor.poll_nonblocking(self.reactor.now()).unwrap(); let waker = Rc::new(NoopWaker::new()).into_std(); let mut cx = Context::from_waker(&waker); self.fut.as_mut().poll(&mut cx) } pub fn advance_time(&mut self, now: Instant) { self.reactor.poll_nonblocking(now).unwrap(); } } #[track_caller] pub fn check_poll(p: Poll>) -> Option where E: fmt::Debug, { match p { Poll::Ready(v) => match v { Ok(t) => Some(t), Err(e) => panic!("check_poll error: {:?}", e), }, Poll::Pending => None, } } #[track_caller] pub fn check_poll_err(p: Poll>) -> Option where T: fmt::Debug, { match p { Poll::Ready(v) => match v { Ok(t) => panic!("check_poll_err ok: {:?}", t), Err(e) => Some(e), }, Poll::Pending => None, } } pub struct FakeSock { inbuf: Vec, outbuf: Vec, out_allow: usize, closed: bool, } #[allow(clippy::new_without_default)] impl FakeSock { pub fn new() -> Self { Self { inbuf: Vec::with_capacity(16384), outbuf: Vec::with_capacity(16384), out_allow: 0, closed: false, } } pub fn add_readable(&mut self, buf: &[u8]) { self.inbuf.extend_from_slice(buf); } pub fn take_writable(&mut self) -> Vec { mem::take(&mut self.outbuf) } pub fn allow_write(&mut self, size: usize) { self.out_allow += size; } pub fn clear_write_allowed(&mut self) { self.out_allow = 0; } pub fn close(&mut self) { self.closed = true; } } impl Read for FakeSock { fn read(&mut self, buf: &mut [u8]) -> Result { if self.closed { return Ok(0); } if self.inbuf.is_empty() { return 
Err(io::Error::from(io::ErrorKind::WouldBlock)); } let size = cmp::min(buf.len(), self.inbuf.len()); buf[..size].copy_from_slice(&self.inbuf[..size]); let mut rest = self.inbuf.split_off(size); mem::swap(&mut self.inbuf, &mut rest); Ok(size) } } impl Write for FakeSock { fn write(&mut self, buf: &[u8]) -> Result { if self.closed { return Err(io::Error::from(io::ErrorKind::BrokenPipe)); } if !buf.is_empty() && self.out_allow == 0 { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } let size = cmp::min(buf.len(), self.out_allow); let buf = &buf[..size]; self.outbuf.extend_from_slice(buf); self.out_allow -= size; Ok(buf.len()) } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { if self.closed { return Err(io::Error::from(io::ErrorKind::BrokenPipe)); } let mut total = 0; for buf in bufs { if !buf.is_empty() && self.out_allow == 0 { if total == 0 { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } break; } let size = cmp::min(buf.len(), self.out_allow); let buf = &buf[..size]; self.outbuf.extend_from_slice(buf.as_ref()); self.out_allow -= size; total += buf.len(); } Ok(total) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } pub struct AsyncFakeSock { pub inner: Rc>, } impl AsyncFakeSock { pub fn new(sock: Rc>) -> Self { Self { inner: sock } } } impl AsyncRead for AsyncFakeSock { fn poll_read( self: Pin<&mut Self>, _cx: &mut Context, buf: &mut [u8], ) -> Poll> { let inner = &mut *self.inner.borrow_mut(); match inner.read(buf) { Ok(usize) => Poll::Ready(Ok(usize)), Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } } fn cancel(&mut self) {} } impl AsyncWrite for AsyncFakeSock { fn poll_write( self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8], ) -> Poll> { let inner = &mut *self.inner.borrow_mut(); match inner.write(buf) { Ok(usize) => Poll::Ready(Ok(usize)), Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } } fn poll_write_vectored( self: Pin<&mut Self>, _cx: &mut Context, bufs: &[io::IoSlice], ) -> Poll> { let inner = &mut *self.inner.borrow_mut(); match inner.write_vectored(bufs) { Ok(usize) => Poll::Ready(Ok(usize)), Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending, Err(e) => Poll::Ready(Err(e)), } } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn is_writable(&self) -> bool { true } fn cancel(&mut self) {} } impl Identify for AsyncFakeSock { fn set_id(&mut self, _id: &str) { // do nothing } } pub struct SimpleCidProvider { pub cid: ArrayString<32>, } impl CidProvider for SimpleCidProvider { fn get_new_assigned_cid(&mut self) -> ArrayString<32> { self.cid } } #[allow(clippy::too_many_arguments)] async fn server_req_handler_fut( sock: Rc>, secure: bool, s_from_conn: channel::LocalSender, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, packet_buf: Rc>>, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, body_buf: &mut ContiguousBuffer, ) -> Result { let mut sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = AsyncLocalSender::new(s_from_conn); server_req_handler( "1", &mut sock, None, secure, buf1, buf2, body_buf, &packet_buf, &s_from_conn, &r_to_conn, ) .await } pub struct BenchServerReqHandlerArgs { sock: Rc>, buf1: VecRingBuffer, buf2: VecRingBuffer, body_buf: ContiguousBuffer, } pub struct BenchServerReqHandler { reactor: Reactor, msg_mem: Arc>, scratch_mem: Rc>>>, resp_mem: 
Rc>, rb_tmp: Rc, packet_buf: Rc>>, } #[allow(clippy::new_without_default)] impl BenchServerReqHandler { pub fn new() -> Self { Self { reactor: Reactor::new(100), msg_mem: Arc::new(arena::ArcMemory::new(1)), scratch_mem: Rc::new(arena::RcMemory::new(1)), resp_mem: Rc::new(arena::RcMemory::new(1)), rb_tmp: Rc::new(TmpBuffer::new(1024)), packet_buf: Rc::new(RefCell::new(vec![0; 2048])), } } pub fn init(&self) -> BenchServerReqHandlerArgs { let buffer_size = 1024; BenchServerReqHandlerArgs { sock: Rc::new(RefCell::new(FakeSock::new())), buf1: VecRingBuffer::new(buffer_size, &self.rb_tmp), buf2: VecRingBuffer::new(buffer_size, &self.rb_tmp), body_buf: ContiguousBuffer::new(buffer_size), } } pub fn run(&self, args: &mut BenchServerReqHandlerArgs) { let reactor = &self.reactor; let msg_mem = &self.msg_mem; let scratch_mem = &self.scratch_mem; let resp_mem = &self.resp_mem; let packet_buf = &self.packet_buf; let sock = &args.sock; let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = args.sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_handler_fut( sock, false, s_from_conn, r_to_conn, packet_buf.clone(), &mut args.buf1, &mut args.buf2, &mut args.body_buf, ) }; let mut executor = StepExecutor::new(reactor, fut); assert_eq!(check_poll(executor.step()), None); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); // read message let _ = r_from_conn.try_recv().unwrap(); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem) .unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), Some(false)); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } } async fn server_req_connection_inner_fut( token: CancellationToken, sock: Rc>, secure: bool, s_from_conn: channel::LocalSender, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, rb_tmp: Rc, packet_buf: Rc>>, ) -> Result<(), Error> { let mut cid = ArrayString::from_str("1").unwrap(); let mut cid_provider = SimpleCidProvider { cid }; let sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = AsyncLocalSender::new(s_from_conn); let buffer_size = 1024; let timeout = Duration::from_millis(5_000); server_req_connection_inner( token, &mut cid, &mut cid_provider, sock, None, secure, buffer_size, buffer_size, &rb_tmp, packet_buf, timeout, s_from_conn, &r_to_conn, ) .await } pub struct BenchServerReqConnection { reactor: Reactor, msg_mem: Arc>, scratch_mem: Rc>>>, resp_mem: Rc>, rb_tmp: Rc, packet_buf: 
Rc>>, } #[allow(clippy::new_without_default)] impl BenchServerReqConnection { pub fn new() -> Self { Self { reactor: Reactor::new(100), msg_mem: Arc::new(arena::ArcMemory::new(1)), scratch_mem: Rc::new(arena::RcMemory::new(1)), resp_mem: Rc::new(arena::RcMemory::new(1)), rb_tmp: Rc::new(TmpBuffer::new(1024)), packet_buf: Rc::new(RefCell::new(vec![0; 2048])), } } pub fn init(&self) -> Rc> { Rc::new(RefCell::new(FakeSock::new())) } pub fn run(&self, sock: &Rc>) { let reactor = &self.reactor; let msg_mem = &self.msg_mem; let scratch_mem = &self.scratch_mem; let resp_mem = &self.resp_mem; let rb_tmp = &self.rb_tmp; let packet_buf = &self.packet_buf; let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_connection_inner_fut( token, sock, false, s_from_conn, r_to_conn, rb_tmp.clone(), packet_buf.clone(), ) }; let mut executor = StepExecutor::new(reactor, fut); assert_eq!(check_poll(executor.step()), None); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); // read message let _ = r_from_conn.try_recv().unwrap(); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem) .unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } } #[allow(clippy::too_many_arguments)] async fn server_stream_handler_fut( sock: Rc>, secure: bool, s_from_conn: channel::LocalSender, s_stream_from_conn: channel::LocalSender<(ArrayVec, zmq::Message)>, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, packet_buf: Rc>>, tmp_buf: Rc>>, buf1: &mut VecRingBuffer, buf2: &mut VecRingBuffer, shared: arena::Rc, ) -> Result { let mut sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = AsyncLocalSender::new(s_from_conn); let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn); server_stream_handler( "1", &mut sock, None, secure, buf1, buf2, 2, &mut CounterDec::new(&Counter::new(0)), 10, false, &packet_buf, &tmp_buf, "test", &s_from_conn, &s_stream_from_conn, &r_to_conn, shared.get(), &|| {}, &|| {}, ) .await } pub struct BenchServerStreamHandlerArgs { sock: Rc>, buf1: VecRingBuffer, buf2: VecRingBuffer, } pub struct BenchServerStreamHandler { reactor: Reactor, msg_mem: Arc>, scratch_mem: Rc>>>, resp_mem: Rc>, shared_mem: Rc>, rb_tmp: Rc, packet_buf: Rc>>, tmp_buf: 
Rc>>, } #[allow(clippy::new_without_default)] impl BenchServerStreamHandler { pub fn new() -> Self { Self { reactor: Reactor::new(100), msg_mem: Arc::new(arena::ArcMemory::new(1)), scratch_mem: Rc::new(arena::RcMemory::new(1)), resp_mem: Rc::new(arena::RcMemory::new(1)), shared_mem: Rc::new(arena::RcMemory::new(1)), rb_tmp: Rc::new(TmpBuffer::new(1024)), packet_buf: Rc::new(RefCell::new(vec![0; 2048])), tmp_buf: Rc::new(RefCell::new(vec![0; 1024])), } } pub fn init(&self) -> BenchServerStreamHandlerArgs { let buffer_size = 1024; BenchServerStreamHandlerArgs { sock: Rc::new(RefCell::new(FakeSock::new())), buf1: VecRingBuffer::new(buffer_size, &self.rb_tmp), buf2: VecRingBuffer::new(buffer_size, &self.rb_tmp), } } pub fn run(&self, args: &mut BenchServerStreamHandlerArgs) { let reactor = &self.reactor; let msg_mem = &self.msg_mem; let scratch_mem = &self.scratch_mem; let resp_mem = &self.resp_mem; let shared_mem = &self.shared_mem; let packet_buf = &self.packet_buf; let tmp_buf = &self.tmp_buf; let sock = &args.sock; let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = args.sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared = arena::Rc::new(StreamSharedData::new(), shared_mem).unwrap(); server_stream_handler_fut( sock, false, s_from_conn, s_stream_from_conn, r_to_conn, packet_buf.clone(), tmp_buf.clone(), &mut args.buf1, &mut args.buf2, shared, ) }; let mut executor = StepExecutor::new(reactor, fut); assert_eq!(check_poll(executor.step()), None); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); // read message let _ = r_from_conn.try_recv().unwrap(); let msg = concat!( "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#", "4:body,6:hello\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem) .unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), Some(true)); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } } #[allow(clippy::too_many_arguments)] async fn server_stream_connection_inner_fut( token: CancellationToken, sock: Rc>, secure: bool, s_from_conn: channel::LocalSender, s_stream_from_conn: channel::LocalSender<(ArrayVec, zmq::Message)>, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, rb_tmp: Rc, packet_buf: Rc>>, tmp_buf: Rc>>, shared: arena::Rc, ) -> Result<(), Error> { let mut cid = ArrayString::from_str("1").unwrap(); let mut cid_provider = SimpleCidProvider { cid }; let sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let 
s_from_conn = AsyncLocalSender::new(s_from_conn); let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn); let buffer_size = 1024; let timeout = Duration::from_millis(5_000); server_stream_connection_inner( token, &mut cid, &mut cid_provider, sock, None, secure, buffer_size, 2, &Counter::new(0), 10, &rb_tmp, packet_buf, tmp_buf, timeout, false, "test", s_from_conn, s_stream_from_conn, &r_to_conn, shared, ) .await } pub struct BenchServerStreamConnection { reactor: Reactor, msg_mem: Arc>, scratch_mem: Rc>>>, resp_mem: Rc>, shared_mem: Rc>, rb_tmp: Rc, packet_buf: Rc>>, tmp_buf: Rc>>, } #[allow(clippy::new_without_default)] impl BenchServerStreamConnection { pub fn new() -> Self { Self { reactor: Reactor::new(100), msg_mem: Arc::new(arena::ArcMemory::new(1)), scratch_mem: Rc::new(arena::RcMemory::new(1)), resp_mem: Rc::new(arena::RcMemory::new(1)), shared_mem: Rc::new(arena::RcMemory::new(1)), rb_tmp: Rc::new(TmpBuffer::new(1024)), packet_buf: Rc::new(RefCell::new(vec![0; 2048])), tmp_buf: Rc::new(RefCell::new(vec![0; 1024])), } } pub fn init(&self) -> Rc> { Rc::new(RefCell::new(FakeSock::new())) } pub fn run(&self, sock: &Rc>) { let reactor = &self.reactor; let msg_mem = &self.msg_mem; let scratch_mem = &self.scratch_mem; let resp_mem = &self.resp_mem; let shared_mem = &self.shared_mem; let rb_tmp = &self.rb_tmp; let packet_buf = &self.packet_buf; let tmp_buf = &self.tmp_buf; let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared = arena::Rc::new(StreamSharedData::new(), shared_mem).unwrap(); server_stream_connection_inner_fut( token, sock, false, s_from_conn, s_stream_from_conn, r_to_conn, rb_tmp.clone(), packet_buf.clone(), tmp_buf.clone(), shared, ) }; let mut executor = StepExecutor::new(reactor, fut); assert_eq!(check_poll(executor.step()), None); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); // read message let _ = r_from_conn.try_recv().unwrap(); let msg = concat!( "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#", "4:body,6:hello\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), scratch_mem) .unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); // connection reusable assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } } } #[cfg(test)] mod tests { use super::testutil::*; use super::*; use crate::connmgr::websocket::Decoder; use crate::core::buffer::TmpBuffer; use 
crate::core::channel; use std::rc::Rc; use std::sync::Arc; use std::task::Poll; use std::time::Instant; use test_log::test; #[test] fn ws_ext_header() { let config = websocket::PerMessageDeflateConfig::default(); let mut dest = ArrayVec::::new(); write_ws_ext_header_value(&config, &mut dest).unwrap(); let expected = "permessage-deflate"; assert_eq!(str::from_utf8(&dest).unwrap(), expected); let mut config = websocket::PerMessageDeflateConfig::default(); config.client_no_context_takeover = true; let mut dest = ArrayVec::::new(); write_ws_ext_header_value(&config, &mut dest).unwrap(); let expected = "permessage-deflate; client_no_context_takeover"; assert_eq!(str::from_utf8(&dest).unwrap(), expected); } #[test] fn message_tracker() { let mut t = MessageTracker::new(2); assert_eq!(t.in_progress(), false); assert_eq!(t.current(), None); t.start(websocket::OPCODE_TEXT).unwrap(); assert_eq!(t.in_progress(), true); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, false))); t.extend(5); assert_eq!(t.in_progress(), true); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 5, false))); t.consumed(2, false); assert_eq!(t.in_progress(), true); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, false))); t.done(); assert_eq!(t.in_progress(), false); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, true))); t.start(websocket::OPCODE_TEXT).unwrap(); assert_eq!(t.in_progress(), true); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 3, true))); t.consumed(3, false); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, true))); t.consumed(0, true); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, false))); t.done(); assert_eq!(t.current(), Some((websocket::OPCODE_TEXT, 0, true))); t.consumed(0, true); assert_eq!(t.current(), None); for _ in 0..t.items.capacity() { t.start(websocket::OPCODE_TEXT).unwrap(); t.done(); } let r = t.start(websocket::OPCODE_TEXT); assert!(r.is_err()); } async fn server_req_fut( token: CancellationToken, sock: Rc>, secure: bool, s_from_conn: channel::LocalSender, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, ) -> Result<(), Error> { let mut cid = ArrayString::from_str("1").unwrap(); let mut cid_provider = SimpleCidProvider { cid }; let sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = AsyncLocalSender::new(s_from_conn); let buffer_size = 1024; let rb_tmp = Rc::new(TmpBuffer::new(1024)); let packet_buf = Rc::new(RefCell::new(vec![0; 2048])); let timeout = Duration::from_millis(5_000); server_req_connection_inner( token, &mut cid, &mut cid_provider, sock, None, secure, buffer_size, buffer_size, &rb_tmp, packet_buf, timeout, s_from_conn, &r_to_conn, ) .await } #[test] fn server_req_without_body() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_fut(token, sock, false, s_from_conn, r_to_conn) }; let mut executor = 
StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T148:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,23:http://example.com/path,7:headers,52:22:4:Host,11:exa", "mple.com,]22:10:Connection,5:close,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_req_with_body() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_fut(token, sock, false, s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "Connection: close\r\n", "\r\n", "hello\n" ) .as_bytes(); 
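// Editor's note: in req mode the Content-Length-framed body arrives in the same
// socket read as the headers, so the single zhttp request message asserted below
// carries the full "hello\n" payload ("4:body,6:hello\n") instead of delivering it
// in a follow-up data message the way the stream-mode tests do.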
sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T191:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,4:POST,3:u", "ri,23:http://example.com/path,7:headers,78:22:4:Host,11:ex", "ample.com,]22:14:Content-Length,1:6,]22:10:Connection,5:cl", "ose,]]4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_req_timeout() { let now = Instant::now(); let reactor = Reactor::new_with_time(100, now); let sock = Rc::new(RefCell::new(FakeSock::new())); let (_s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, _r_from_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); server_req_fut(token, sock, false, s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); executor.advance_time(now + Duration::from_millis(5_000)); match executor.step() { Poll::Ready(Err(Error::StreamTimeout)) => {} _ => panic!("unexpected state"), } } #[test] fn server_req_pipeline() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_fut(token, sock, false, s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet 
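// (These tests deliberately pre-fill the single-slot outbound channel with an
// empty zmq::Message so the connection blocks on its first send; the empty
// "bogus" message is drained before the real request message is read.)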
assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "GET /path1 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", "GET /path2 HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n", ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:http://example.com/path1,7:headers,26:22:4:Host,11:ex", "ample.com,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T123:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:http://example.com/path2,7:headers,26:22:4:Host,11:ex", "ample.com,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_req_secure() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock 
= Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_req_fut(token, sock, true, s_from_conn, r_to_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T149:2:id,1:1,3:ext,15:5:multi,4:true!}6:method,3:GET,3:ur", "i,24:https://example.com/path,7:headers,52:22:4:Host,11:ex", "ample.com,]22:10:Connection,5:close,]]}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T100:2:id,1:1,4:code,3:200#6:reason,2:OK,7:h", "eaders,34:30:12:Content-Type,10:text/plain,]]4:body,6:hell", "o\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } async fn server_stream_fut( token: CancellationToken, sock: Rc>, secure: bool, allow_compression: bool, s_from_conn: channel::LocalSender, s_stream_from_conn: channel::LocalSender<(ArrayVec, zmq::Message)>, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, ) -> Result<(), Error> { let mut cid = ArrayString::from_str("1").unwrap(); let mut cid_provider = SimpleCidProvider { cid }; let sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = AsyncLocalSender::new(s_from_conn); let s_stream_from_conn = AsyncLocalSender::new(s_stream_from_conn); let buffer_size = 1024; let rb_tmp = Rc::new(TmpBuffer::new(1024)); let 
packet_buf = Rc::new(RefCell::new(vec![0; 2048])); let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size])); let timeout = Duration::from_millis(5_000); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); server_stream_connection_inner( token, &mut cid, &mut cid_provider, sock, None, secure, buffer_size, 3, &Counter::new(1), 10, &rb_tmp, packet_buf, tmp_buf, timeout, allow_compression, "test", s_from_conn, s_stream_from_conn, &r_to_conn, shared, ) .await } #[test] fn server_stream_without_body() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T201:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#", "4:body,6:hello\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); // connection reusable assert_eq!(check_poll(executor.step()), None); 
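// Editor's sketch: later tests in this module build follow-up packets with the
// zhttppacket API instead of hand-encoding tnetstrings. A minimal sketch of that
// pattern, mirroring server_stream_expand_write_buffer below (values here are
// illustrative only):
//
//     let mut rdata = zhttppacket::ResponseData::new();
//     rdata.body = &b"hello\n"[..];
//     rdata.more = true;
//     let resp = zhttppacket::Response::new_data(
//         b"handler",
//         &[zhttppacket::Id { id: b"1", seq: Some(1) }],
//         rdata,
//     );
//     let mut buf = [0; 2048];
//     let size = resp.serialize(&mut buf).unwrap();
//     let msg = zmq::Message::from(&buf[..size]);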
let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_with_body() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T242:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he", "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:", "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!11:rout", "er-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!("T69:7:credits,4:1024#3:seq,1:0#2:id,1:1,4:from,7:handler,4:type,6:credit,}",); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_stream_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); assert_eq!(addr.as_ref(), "handler".as_bytes()); let buf = &msg[..]; let expected = concat!( "T74:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T127:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", 
"e,10:text/plain,]]3:seq,1:1#4:from,7:handler,4:code,3:200#", "4:body,6:hello\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); // connection reusable assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_chunked() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let resp_mem = Rc::new(arena::RcMemory::new(2)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T201:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T125:4:more,4:true!2:id,1:1,6:reason,2:OK,7:headers,34:30:", "12:Content-Type,10:text/plain,]]3:seq,1:0#4:from,7:handler", ",4:code,3:200#}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = 
zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let msg = concat!("T52:3:seq,1:1#2:id,1:1,4:from,7:handler,4:body,6:hello\n,}"); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); // connection reusable assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\n", "hello\n", "\r\n", "0\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_early_response() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", ) .as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T242:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,4:POST,3:uri,23:http://example.com/path,7:he", "aders,52:22:4:Host,11:example.com,]22:14:Content-Length,1:", "6,]]7:credits,4:1024#4:more,4:true!6:stream,4:true!11:rout", "er-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( 
"T150:2:id,1:1,6:reason,11:Bad Request,7:headers,34:30:12:C", "ontent-Type,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:c", "ode,3:400#4:body,18:stopping this now\n,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), Some(())); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 400 Bad Request\r\n", "Content-Type: text/plain\r\n", "Connection: close\r\n", "Content-Length: 18\r\n", "\r\n", "stopping this now\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); } #[test] fn server_stream_expand_write_buffer() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); assert_eq!(check_poll(executor.step()), None); let msg = r_from_conn.try_recv().unwrap(); // no other messages assert!(r_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "T201:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); sock.borrow_mut().allow_write(1024); let msg = concat!( "T125:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#", "4:more,4:true!}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); // no more messages yet 
assert!(r_stream_from_conn.try_recv().is_err()); sock.borrow_mut().clear_write_allowed(); let body = vec![0; 1024]; let mut rdata = zhttppacket::ResponseData::new(); rdata.body = body.as_slice(); rdata.more = true; let resp = zhttppacket::Response::new_data( b"handler", &[zhttppacket::Id { id: b"1", seq: Some(1), }], rdata, ); let mut buf = [0; 2048]; let size = resp.serialize(&mut buf).unwrap(); let msg = zmq::Message::from(&buf[..size]); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), None); // read message let (_, msg) = r_stream_from_conn.try_recv().unwrap(); // no other messages assert!(r_stream_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn server_stream_disconnect() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let resp_mem = Rc::new(arena::RcMemory::new(1)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, _r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the connection's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); let req_data = concat!("GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); sock.borrow_mut().add_readable(req_data); // connection won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now connection will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T201:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,23:http://example.com/path,7:hea", "ders,26:22:4:Host,11:example.com,]]7:credits,4:1024#6:stre", "am,4:true!11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T125:2:id,1:1,6:reason,2:OK,7:headers,34:30:12:Content-Typ", "e,10:text/plain,]]3:seq,1:0#4:from,7:handler,4:code,3:200#", "4:more,4:true!}", 
); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); // data received so far let expected = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); sock.borrow_mut().close(); // closed, task should error out let e = check_poll_err(executor.step()).unwrap(); assert!(matches!(e, Error::CoreHttp(CoreHttpError::Io(e)) if e.kind() == io::ErrorKind::UnexpectedEof )); let data = sock.borrow_mut().take_writable(); assert!(data.is_empty()); } #[test] fn server_websocket() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let resp_mem = Rc::new(arena::RcMemory::new(2)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T277:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,21:ws://example.com/path,7:heade", "rs,119:22:4:Host,11:example.com,]22:7:Upgrade,9:websocket,", "]30:21:Sec-WebSocket-Version,2:13,]29:17:Sec-WebSocket-Key", ",5:abcde,]]7:credits,4:1024#11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T98:2:id,1:1,6:reason,19:Switching Protocols,3:seq,1:0#4:f", "rom,7:handler,4:code,3:101#7:credits,4:1024#}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); 
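// Editor's note: per RFC 6455, the Sec-WebSocket-Accept value checked below is
// base64(SHA-1(client key + "258EAFA5-E914-47DA-95CA-C5AB0DC85B11")), computed
// here for the test key "abcde".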
let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = b"hello"; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_stream_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); assert_eq!(addr.as_ref(), "handler".as_bytes()); let buf = &msg[..]; let expected = concat!( "T96:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = &data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; assert_eq!(str::from_utf8(content).unwrap(), "world"); } #[test] fn server_websocket_with_deflate() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let resp_mem = Rc::new(arena::RcMemory::new(2)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); server_stream_fut( token, sock, false, true, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: abcde\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T331:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,21:ws://example.com/path,7:heade", "rs,173:22:4:Host,11:example.com,]22:7:Upgrade,9:websocket,", 
"]30:21:Sec-WebSocket-Version,2:13,]29:17:Sec-WebSocket-Key", ",5:abcde,]50:24:Sec-WebSocket-Extensions,18:permessage-def", "late,]]7:credits,4:1024#11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T98:2:id,1:1,6:reason,19:Switching Protocols,3:seq,1:0#4:f", "rom,7:handler,4:code,3:101#7:credits,4:1024#}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = { let src = b"hello"; let mut enc = websocket::DeflateEncoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = enc.encode(src, true, &mut dest).unwrap(); assert_eq!(read, 5); assert_eq!(output_end, true); dest.truncate(written); dest }; let size = websocket::write_header( true, true, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(&body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_stream_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); assert_eq!(addr.as_ref(), "handler".as_bytes()); let buf = &msg[..]; let expected = concat!( "T96:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = { let src = &data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; let mut dec = websocket::DeflateDecoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = dec.decode(src, true, &mut dest).unwrap(); assert_eq!(read, src.len()); assert_eq!(output_end, true); dest.truncate(written); dest }; assert_eq!(str::from_utf8(&content).unwrap(), "world"); } #[test] fn 
server_websocket_expand_write_buffer() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let resp_mem = Rc::new(arena::RcMemory::new(2)); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (s_stream_from_conn, r_stream_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let (_cancel, token) = CancellationToken::new(&reactor.local_registration_memory()); let fut = { let sock = sock.clone(); server_stream_fut( token, sock, false, false, s_from_conn, s_stream_from_conn, r_to_conn, ) }; let mut executor = StepExecutor::new(&reactor, fut); let req_data = concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n" ) .as_bytes(); sock.borrow_mut().add_readable(req_data); assert_eq!(check_poll(executor.step()), None); // read message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T277:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:t", "rue!}6:method,3:GET,3:uri,21:ws://example.com/path,7:heade", "rs,119:22:4:Host,11:example.com,]22:7:Upgrade,9:websocket,", "]30:21:Sec-WebSocket-Version,2:13,]29:17:Sec-WebSocket-Key", ",5:abcde,]]7:credits,4:1024#11:router-resp,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!( "T98:2:id,1:1,6:reason,19:Switching Protocols,3:seq,1:0#4:f", "rom,7:handler,4:code,3:101#7:credits,4:1024#}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert_eq!(s_to_conn.try_send((resp, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); assert_eq!(data.is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let data = sock.borrow_mut().take_writable(); let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); assert_eq!(str::from_utf8(&data).unwrap(), expected); sock.borrow_mut().clear_write_allowed(); let body = vec![0; 1024]; let mut rdata = zhttppacket::ResponseData::new(); rdata.body = body.as_slice(); rdata.content_type = Some(zhttppacket::ContentType::Text); let resp = zhttppacket::Response::new_data( b"handler", &[zhttppacket::Id { id: b"1", seq: Some(1), }], rdata, ); let mut buf = [0; 2048]; let size = resp.serialize(&mut buf).unwrap(); let msg = zmq::Message::from(&buf[..size]); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let resp = zhttppacket::OwnedResponse::parse(msg, 0, scratch).unwrap(); let resp = arena::Rc::new(resp, &resp_mem).unwrap(); assert!(s_to_conn.try_send((resp, 0)).is_ok()); assert_eq!(check_poll(executor.step()), None); // read message let (_, msg) = 
r_stream_from_conn.try_recv().unwrap(); // no other messages assert!(r_stream_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } async fn client_req_fut( id: Option>, zreq: arena::Rc, sock: Rc>, s_from_conn: channel::LocalSender, ) -> Result<(), Error> { let mut sock = AsyncFakeSock::new(sock); let s_from_conn = AsyncLocalSender::new(s_from_conn); let buffer_size = 1024; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut buf1 = VecRingBuffer::new(buffer_size, &rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, &rb_tmp); let mut body_buf = ContiguousBuffer::new(buffer_size); let packet_buf = RefCell::new(vec![0; 2048]); let zreq = zreq.get().get(); let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(rdata) => rdata, _ => panic!("unexpected init packet"), }; let url = url::Url::parse(rdata.uri).unwrap(); let msg = match client_req_handler( "test", id.as_deref(), &mut sock, zreq, rdata.method, &url, true, false, &mut buf1, &mut buf2, &mut body_buf, &packet_buf, ) .await? { ClientHandlerDone::Complete(r, _) => r, ClientHandlerDone::Redirect(_, _, _) => panic!("unexpected redirect"), }; s_from_conn.send(msg).await?; Ok(()) } #[test] fn client_req_without_id() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_mem = Rc::new(arena::RcMemory::new(1)); let data = concat!( "T74:7:headers,16:12:3:Foo,3:Bar,]]3:uri,19:https://example.co", "m,6:method,3:GET,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); client_req_fut(None, zreq, sock, s_from_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "GET / HTTP/1.1\r\n", "Host: example.com\r\n", "Foo: Bar\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); // now handler will be able to send a 
message and finish assert_eq!(check_poll(executor.step()), Some(())); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T117:4:code,3:200#6:reason,2:OK,7:headers,60:30:12:Content", "-Type,10:text/plain,]22:14:Content-Length,1:6,]]4:body,6:h", "ello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_req_with_id() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(1)); let scratch_mem = Rc::new(arena::RcMemory::new(1)); let req_mem = Rc::new(arena::RcMemory::new(1)); let data = concat!( "T83:7:headers,16:12:3:Foo,3:Bar,]]3:uri,19:https://example.co", "m,6:method,3:GET,2:id,1:1,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); client_req_fut(Some(b"1".to_vec()), zreq, sock, s_from_conn) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // no messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); // fill the handler's outbound message queue assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_ok(), true); assert_eq!(s_from_conn.try_send(zmq::Message::new()).is_err(), true); drop(s_from_conn); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "GET / HTTP/1.1\r\n", "Host: example.com\r\n", "Foo: Bar\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let msg = r_from_conn.try_recv().unwrap(); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); // now handler will be able to send a message and finish assert_eq!(check_poll(executor.step()), Some(())); // read real message let msg = r_from_conn.try_recv().unwrap(); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T126:2:id,1:1,4:code,3:200#6:reason,2:OK,7:headers,60:30:1", "2:Content-Type,10:text/plain,]22:14:Content-Length,1:6,]]4", ":body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } async fn client_stream_fut( id: Vec, zreq: arena::Rc, sock: Rc>, allow_compression: bool, r_to_conn: channel::LocalReceiver<(arena::Rc, usize)>, s_from_conn: channel::LocalSender<(Option>, zmq::Message)>, shared: arena::Rc, ) -> Result<(), Error> { let mut sock = AsyncFakeSock::new(sock); let f = TrackFlag::default(); let r_to_conn = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r_to_conn), &f); let s_from_conn = 
AsyncLocalSender::new(s_from_conn); let buffer_size = 1024; let rb_tmp = Rc::new(TmpBuffer::new(buffer_size)); let mut buf1 = VecRingBuffer::new(buffer_size, &rb_tmp); let mut buf2 = VecRingBuffer::new(buffer_size, &rb_tmp); let packet_buf = RefCell::new(vec![0; 2048]); let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size])); let mut response_received = false; let refresh_stream_timeout = || {}; let refresh_session_timeout = || {}; let zreq = zreq.get().get(); let rdata = match &zreq.ptype { zhttppacket::RequestPacket::Data(rdata) => rdata, _ => panic!("unexpected init packet"), }; let url = url::Url::parse(rdata.uri).unwrap(); let log_id = "test"; let instance_id = "test"; let zsess_out = ZhttpServerStreamSessionOut::new( instance_id, &id, &packet_buf, &s_from_conn, shared.get(), ); zsess_out.check_send().await; zsess_out.try_send_msg(zhttppacket::Response::new_keep_alive(b"", &[]))?; let mut zsess_in = ZhttpServerStreamSessionIn::new( log_id, &id, rdata.credits, &r_to_conn, shared.get(), &refresh_session_timeout, ); let _persistent = client_stream_handler( "test", &mut sock, zreq, rdata.method, &url, true, false, &mut buf1, &mut buf2, 3, &mut CounterDec::new(&Counter::new(1)), 10, allow_compression, &tmp_buf, &mut zsess_in, &zsess_out, &mut response_received, &refresh_stream_timeout, ) .await?; Ok(()) } #[test] fn client_stream() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T165:7:credits,4:1024#4:more,4:true!7:headers,34:30:12:Conten", "t-Type,10:text/plain,]]3:uri,24:https://example.com/path,6:me", "thod,4:POST,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_ok(), true ); assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_err(), true ); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages 
assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:mu", "lti,4:true!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!("T52:3:seq,1:1#2:id,1:1,4:from,7:handler,4:body,6:hello\n,}"); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let expected = concat!("6\r\nhello\n\r\n0\r\n\r\n",); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); assert_eq!(check_poll(executor.step()), None); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T173:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:m", "ulti,4:true!}4:code,3:200#6:reason,2:OK,7:headers,60:30:12", ":Content-Type,10:text/plain,]22:14:Content-Length,1:6,]]4:", "more,4:true!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); assert_eq!(check_poll(executor.step()), Some(())); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T74:4:from,4:test,2:id,1:1,3:seq,1:3#3:ext,15:5:mu", "lti,4:true!}4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_stream_router_resp() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T187:7:credits,4:1024#4:more,4:true!7:headers,34:30:12:Conten", "t-Type,10:text/plain,]]3:uri,24:https://example.com/path,6:me", "thod,4:POST,3:seq,1:0#2:id,1:1,4:from,7:handler,11:router-res", "p,4:true!}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = 
arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); shared.get().set_router_resp(true); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_ok(), true ); assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_err(), true ); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a message assert_eq!(check_poll(executor.step()), None); let expected_addr = b"handler".as_slice(); // read real message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert_eq!(addr.as_deref(), Some(expected_addr)); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:multi,4:tr", "ue!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert_eq!(addr.as_deref(), Some(expected_addr)); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4:tr", "ue!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let msg = concat!("T52:3:seq,1:1#2:id,1:1,4:from,7:handler,4:body,6:hello\n,}"); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let expected = concat!("6\r\nhello\n\r\n0\r\n\r\n",); let buf = sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); 
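// The chunked request body has been relayed to the origin. The following steps feed the
// origin's response into the socket; because router-resp is enabled for this connection,
// the resulting zhttp response packets are expected to carry the destination in the
// separate address part (Some(b"handler")) rather than as a prefix on the message payload.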
assert_eq!(check_poll(executor.step()), None); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let resp_data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ) .as_bytes(); sock.borrow_mut().add_readable(resp_data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert_eq!(addr.as_deref(), Some(expected_addr)); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T173:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:multi,4:t", "rue!}4:code,3:200#6:reason,2:OK,7:headers,60:30:12:Content", "-Type,10:text/plain,]22:14:Content-Length,1:6,]]4:more,4:t", "rue!}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); assert_eq!(check_poll(executor.step()), Some(())); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert_eq!(addr.as_deref(), Some(expected_addr)); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "T74:4:from,4:test,2:id,1:1,3:seq,1:3#3:ext,15:5:multi,4:tr", "ue!}4:body,6:hello\n,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_stream_expand_write_buffer() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T165:7:credits,4:1024#4:more,4:true!7:headers,34:30:12:Conten", "t-Type,10:text/plain,]]3:uri,24:https://example.com/path,6:me", "thod,4:POST,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert!(r_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert!(sock.borrow_mut().take_writable().is_empty()); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let expected = concat!( "POST /path HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", ); let buf = 
sock.borrow_mut().take_writable(); assert_eq!(str::from_utf8(&buf).unwrap(), expected); sock.borrow_mut().clear_write_allowed(); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert!(r_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "handler T91:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:mu", "lti,4:true!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); let body = vec![0; 1024]; let mut rdata = zhttppacket::RequestData::new(); rdata.body = body.as_slice(); rdata.more = true; let req = zhttppacket::Request::new_data( b"handler", &[zhttppacket::Id { id: b"1", seq: Some(1), }], rdata, ); let mut buf = [0; 2048]; let size = req.serialize(&mut buf).unwrap(); let msg = zmq::Message::from(&buf[..size]); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert!(s_to_conn.try_send((req, 0)).is_ok()); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert!(r_from_conn.try_recv().is_err()); let buf = &msg[..]; let expected = concat!( "handler T91:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:mu", "lti,4:true!}4:type,6:credit,7:credits,4:1024#}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); } #[test] fn client_websocket() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T115:7:credits,4:1024#7:headers,16:12:3:Foo,3:Bar,]]3:uri,22:", "wss://example.com/path,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, false, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_ok(), true ); assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_err(), true ); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler 
will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let buf = sock.borrow_mut().take_writable(); // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let mut req = httparse::Request::new(&mut headers); match req.parse(&buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "Foo: Bar\r\n", "\r\n", ), ws_key ); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "\r\n", ), ws_accept ); sock.borrow_mut().add_readable(resp_data.as_bytes()); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = format!( concat!( "handler T249:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:m", "ulti,4:true!}}4:code,3:101#6:reason,19:Switching Protocols", ",7:headers,114:22:7:Upgrade,9:websocket,]24:10:Connection,", "7:Upgrade,]56:20:Sec-WebSocket-Accept,28:{},]]7:credits,4:", "1024#}}", ), ws_accept ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = b"hello"; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); let buf = &msg[..]; let expected = concat!( "handler T96:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:mu", "lti,4:true!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), 
true); assert_eq!(check_poll(executor.step()), None); let mut data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = &mut data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; websocket::apply_mask(content, fi.mask.unwrap(), 0); assert_eq!(str::from_utf8(content).unwrap(), "world"); } #[test] fn client_websocket_with_deflate() { let reactor = Reactor::new(100); let msg_mem = Arc::new(arena::ArcMemory::new(2)); let scratch_mem = Rc::new(arena::RcMemory::new(2)); let req_mem = Rc::new(arena::RcMemory::new(2)); let data = concat!( "T115:7:credits,4:1024#7:headers,16:12:3:Foo,3:Bar,]]3:uri,22:", "wss://example.com/path,3:seq,1:0#2:id,1:1,4:from,7:handler,}", ) .as_bytes(); let msg = zmq::Message::from(data); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let zreq = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let zreq = arena::Rc::new(zreq, &req_mem).unwrap(); let sock = Rc::new(RefCell::new(FakeSock::new())); let (s_to_conn, r_to_conn) = channel::local_channel(1, 1, &reactor.local_registration_memory()); let (s_from_conn, r_from_conn) = channel::local_channel(1, 2, &reactor.local_registration_memory()); let fut = { let sock = sock.clone(); let s_from_conn = s_from_conn .try_clone(&reactor.local_registration_memory()) .unwrap(); let shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &shared_mem).unwrap(); let addr = ArrayVec::try_from(b"handler".as_slice()).unwrap(); shared.get().set_to_addr(Some(addr)); client_stream_fut( b"1".to_vec(), zreq, sock, true, r_to_conn, s_from_conn, shared, ) }; let mut executor = StepExecutor::new(&reactor, fut); // fill the handler's outbound message queue assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_ok(), true ); assert_eq!( s_from_conn.try_send((None, zmq::Message::new())).is_err(), true ); drop(s_from_conn); // handler won't be able to send a message yet assert_eq!(check_poll(executor.step()), None); // read bogus message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); assert_eq!(msg.is_empty(), true); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); // now handler will be able to send a message assert_eq!(check_poll(executor.step()), None); // read real message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = concat!( "handler T79:4:from,4:test,2:id,1:1,3:seq,1:0#3:ext,15:5:mu", "lti,4:true!}4:type,10:keep-alive,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // no data yet assert_eq!(sock.borrow_mut().take_writable().is_empty(), true); sock.borrow_mut().allow_write(1024); assert_eq!(check_poll(executor.step()), None); let buf = sock.borrow_mut().take_writable(); // use httparse to fish out Sec-WebSocket-Key let ws_key = { let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let mut req = httparse::Request::new(&mut headers); match req.parse(&buf) { Ok(httparse::Status::Complete(_)) => {} _ => panic!("unexpected parse status"), } let mut ws_key = String::new(); for h in req.headers { if h.name.eq_ignore_ascii_case("Sec-WebSocket-Key") { ws_key = 
String::from_utf8(h.value.to_vec()).unwrap(); } } ws_key }; let expected = format!( concat!( "GET /path HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: {}\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "Foo: Bar\r\n", "\r\n", ), ws_key ); assert_eq!(str::from_utf8(&buf).unwrap(), expected); // no more messages yet assert_eq!(r_from_conn.try_recv().is_err(), true); let ws_accept = calculate_ws_accept(ws_key.as_bytes()).unwrap(); let resp_data = format!( concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: {}\r\n", "Sec-WebSocket-Extensions: permessage-deflate\r\n", "\r\n", ), ws_accept ); sock.borrow_mut().add_readable(resp_data.as_bytes()); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); // no other messages assert_eq!(r_from_conn.try_recv().is_err(), true); let buf = &msg[..]; let expected = format!( concat!( "handler T303:4:from,4:test,2:id,1:1,3:seq,1:1#3:ext,15:5:m", "ulti,4:true!}}4:code,3:101#6:reason,19:Switching Protocols", ",7:headers,168:22:7:Upgrade,9:websocket,]24:10:Connection,", "7:Upgrade,]56:20:Sec-WebSocket-Accept,28:{},]50:24:Sec-Web", "Socket-Extensions,18:permessage-deflate,]]7:credits,4:1024", "#}}", ), ws_accept ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // send message let mut data = vec![0; 1024]; let body = { let src = b"hello"; let mut enc = websocket::DeflateEncoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = enc.encode(src, true, &mut dest).unwrap(); assert_eq!(read, 5); assert_eq!(output_end, true); dest.truncate(written); dest }; let size = websocket::write_header( true, true, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(&body); let data = &data[..(size + body.len())]; sock.borrow_mut().add_readable(data); assert_eq!(check_poll(executor.step()), None); // read message let (addr, msg) = r_from_conn.try_recv().unwrap(); assert!(addr.is_none()); let buf = &msg[..]; let expected = concat!( "handler T96:4:from,4:test,2:id,1:1,3:seq,1:2#3:ext,15:5:mu", "lti,4:true!}12:content-type,4:text,4:body,5:hello,}", ); assert_eq!(str::from_utf8(buf).unwrap(), expected); // recv message let msg = concat!( "T99:4:from,7:handler,2:id,1:1,3:seq,1:1#3:ext,15:5:multi,4", ":true!}12:content-type,4:text,4:body,5:world,}", ); let msg = zmq::Message::from(msg.as_bytes()); let msg = arena::Arc::new(msg, &msg_mem).unwrap(); let scratch = arena::Rc::new(RefCell::new(zhttppacket::ParseScratch::new()), &scratch_mem).unwrap(); let req = zhttppacket::OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = arena::Rc::new(req, &req_mem).unwrap(); assert_eq!(s_to_conn.try_send((req, 0)).is_ok(), true); assert_eq!(check_poll(executor.step()), None); let mut data = sock.borrow_mut().take_writable(); let fi = websocket::read_header(&data).unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, websocket::OPCODE_TEXT); assert!(data.len() >= fi.payload_offset + fi.payload_size); let content = { let src = &mut data[fi.payload_offset..(fi.payload_offset + fi.payload_size)]; websocket::apply_mask(src, fi.mask.unwrap(), 0); let mut dec = websocket::DeflateDecoder::new(); let mut dest = vec![0; 1024]; let (read, written, output_end) = dec.decode(src, true, &mut dest).unwrap(); assert_eq!(read, src.len()); assert_eq!(output_end, 
true); dest.truncate(written); dest }; assert_eq!(str::from_utf8(&content).unwrap(), "world"); } #[test] fn bench_server_req_handler() { let t = BenchServerReqHandler::new(); t.run(&mut t.init()); } #[test] fn bench_server_req_connection() { let t = BenchServerReqConnection::new(); t.run(&mut t.init()); } #[test] fn bench_server_stream_handler() { let t = BenchServerStreamHandler::new(); t.run(&mut t.init()); } #[test] fn bench_server_stream_connection() { let t = BenchServerStreamConnection::new(); t.run(&mut t.init()); } } pushpin-1.41.0/src/connmgr/counter.rs000066400000000000000000000065041504671364300175600ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use std::sync::atomic::{AtomicUsize, Ordering}; #[derive(Debug)] pub struct CounterError; /// An unsigned integer that can be shared between threads. Counter is backed by an AtomicUsize /// and performs operations with Relaxed memory ordering, so its value cannot be reliably assumed /// to be in sync with other atomic values, including other Counter values. pub struct Counter(AtomicUsize); impl Counter { pub fn new(value: usize) -> Self { Self(AtomicUsize::new(value)) } pub fn inc(&self, amount: usize) -> Result<(), CounterError> { if amount == 0 { return Ok(()); } loop { let value = self.0.load(Ordering::Relaxed); if amount > usize::MAX - value { return Err(CounterError); } if self .0 .compare_exchange(value, value + amount, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { break; } } Ok(()) } pub fn dec(&self, amount: usize) -> Result<(), CounterError> { if amount == 0 { return Ok(()); } loop { let value = self.0.load(Ordering::Relaxed); if amount > value { return Err(CounterError); } if self .0 .compare_exchange(value, value - amount, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { break; } } Ok(()) } } pub struct CounterDec<'a> { counter: &'a Counter, amount: usize, } impl<'a> CounterDec<'a> { pub fn new(counter: &'a Counter) -> Self { Self { counter, amount: 0 } } pub fn dec(&mut self, amount: usize) -> Result<(), CounterError> { self.counter.dec(amount)?; self.amount += amount; Ok(()) } } impl Drop for CounterDec<'_> { fn drop(&mut self) { assert!(self.counter.inc(self.amount).is_ok()); } } #[cfg(test)] mod tests { use super::*; #[test] fn counter() { let c = Counter::new(2); assert!(c.dec(1).is_ok()); assert!(c.dec(1).is_ok()); assert!(c.dec(1).is_err()); assert!(c.inc(1).is_ok()); assert!(c.dec(2).is_err()); assert!(c.dec(1).is_ok()); assert!(c.inc(usize::MAX).is_ok()); assert!(c.inc(1).is_err()); } #[test] fn counter_dec() { let c = Counter::new(2); { let mut c = CounterDec::new(&c); assert!(c.dec(1).is_ok()); assert!(c.dec(1).is_ok()); assert!(c.dec(1).is_err()); } assert!(c.dec(2).is_ok()); assert!(c.dec(1).is_err()); } } pushpin-1.41.0/src/connmgr/listener.rs000066400000000000000000000176651504671364300177400ustar00rootroot00000000000000/* * Copyright (C) 2020-2022 Fanout, Inc. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::arena::recycle_vec; use crate::core::channel; use crate::core::executor::Executor; use crate::core::net::{AsyncNetListener, NetAcceptFuture, NetListener, NetStream, SocketAddr}; use crate::core::reactor::Reactor; use crate::core::select::{select_2, select_slice, Select2}; use log::{debug, error}; use std::cmp; use std::sync::mpsc; use std::thread; const REACTOR_REGISTRATIONS_MAX: usize = 128; const EXECUTOR_TASKS_MAX: usize = 1; pub struct Listener { thread: Option>, stop: channel::Sender<()>, } impl Listener { pub fn new( name: &str, listeners: Vec, senders: Vec>, ) -> Listener { let (s, r) = channel::channel(1); let thread = thread::Builder::new() .name(name.to_string()) .spawn(move || { let reactor = Reactor::new(REACTOR_REGISTRATIONS_MAX); let executor = Executor::new(EXECUTOR_TASKS_MAX); executor.spawn(Self::run(r, listeners, senders)).unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); }) .unwrap(); Self { thread: Some(thread), stop: s, } } async fn run( stop: channel::Receiver<()>, listeners: Vec, senders: Vec>, ) { let stop = channel::AsyncReceiver::new(stop); let mut listeners: Vec = listeners.into_iter().map(AsyncNetListener::new).collect(); let mut senders: Vec> = senders.into_iter().map(channel::AsyncSender::new).collect(); let mut listeners_pos = 0; let mut senders_pos = 0; let mut sender_tasks_mem: Vec> = Vec::with_capacity(senders.len()); let mut listener_tasks_mem: Vec = Vec::with_capacity(listeners.len()); let mut slice_scratch = Vec::with_capacity(cmp::max(senders.len(), listeners.len())); let mut stop_recv = stop.recv(); 'accept: loop { // wait for a sender to become writable let mut sender_tasks = recycle_vec(sender_tasks_mem); for s in senders.iter_mut() { sender_tasks.push(s.wait_writable()); } let result = select_2( &mut stop_recv, select_slice(&mut sender_tasks, &mut slice_scratch), ) .await; sender_tasks_mem = recycle_vec(sender_tasks); match result { Select2::R1(_) => break, Select2::R2(_) => {} } // accept a connection let mut listener_tasks = recycle_vec(listener_tasks_mem); let (b, a) = listeners.split_at_mut(listeners_pos); for l in a.iter_mut().chain(b.iter_mut()) { listener_tasks.push(l.accept()); } let (pos, stream, peer_addr) = loop { match select_2( &mut stop_recv, select_slice(&mut listener_tasks, &mut slice_scratch), ) .await { Select2::R1(_) => break 'accept, Select2::R2((pos, result)) => match result { Ok((stream, peer_addr)) => break (pos, stream, peer_addr), Err(e) => error!("accept error: {:?}", e), }, } }; listener_tasks_mem = recycle_vec(listener_tasks); let pos = (listeners_pos + pos) % listeners.len(); debug!("accepted connection from {}", peer_addr); listeners_pos = (pos + 1) % listeners.len(); // write connection to sender let mut pending_sock = Some((pos, stream, peer_addr)); for _ in 0..senders.len() { let sender = &mut senders[senders_pos]; if !sender.is_writable() { senders_pos = (senders_pos + 1) % senders.len(); continue; } let s = pending_sock.take().unwrap(); match 
sender.try_send(s) { Ok(()) => {} Err(mpsc::TrySendError::Full(s)) => pending_sock = Some(s), Err(mpsc::TrySendError::Disconnected(_)) => { // this could happen during shutdown debug!("receiver disconnected"); } } senders_pos = (senders_pos + 1) % senders.len(); if pending_sock.is_none() { break; } } } } } impl Drop for Listener { fn drop(&mut self) { // this should never fail. receiver won't disconnect unless // we tell it to self.stop.send(()).unwrap(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } #[cfg(test)] mod tests { use super::*; use crate::core::event; use mio::net::TcpListener; use std::io::{Read, Write}; use std::mem; use std::sync::mpsc; #[test] fn test_accept() { let mut addrs = Vec::new(); let mut listeners = Vec::new(); let mut senders = Vec::new(); let mut receivers = Vec::new(); for _ in 0..2 { let addr = "127.0.0.1:0".parse().unwrap(); let l = TcpListener::bind(addr).unwrap(); addrs.push(l.local_addr().unwrap()); listeners.push(NetListener::Tcp(l)); let (sender, receiver) = channel::channel(0); senders.push(sender); receivers.push(receiver); } let _l = Listener::new("listener-test", listeners, senders); let mut poller = event::Poller::new(1024).unwrap(); let mut client = std::net::TcpStream::connect(&addrs[0]).unwrap(); poller .register_custom( receivers[0].get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); let result = receivers[0].try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); loop { poller.poll(None).unwrap(); let mut done = false; for event in poller.iter_events() { match event.token() { mio::Token(1) => { assert_eq!(event.is_readable(), true); done = true; break; } _ => unreachable!(), } } if done { break; } } let (lnum, peer_client, _) = receivers[0].try_recv().unwrap(); assert_eq!(lnum, 0); let mut peer_client = match peer_client { NetStream::Tcp(s) => s, _ => unreachable!(), }; peer_client.write(b"hello").unwrap(); mem::drop(peer_client); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!(&buf, b"hello"); } } pushpin-1.41.0/src/connmgr/mod.rs000066400000000000000000000324621504671364300166620ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ mod batch; mod counter; mod listener; mod pool; mod track; mod zhttppacket; mod zhttpsocket; pub mod client; pub mod connection; pub mod resolver; pub mod server; pub mod tls; pub mod websocket; use self::client::Client; use self::server::{Server, MSG_RETAINED_PER_CONNECTION_MAX, MSG_RETAINED_PER_WORKER_MAX}; use crate::core::zmq::SpecInfo; use ipnet::IpNet; use log::{debug, info}; use signal_hook; use signal_hook::consts::TERM_SIGNALS; use signal_hook::iterator::Signals; use std::cmp; use std::error::Error; use std::path::PathBuf; use std::sync::atomic::AtomicBool; use std::sync::Arc; use std::time::Duration; const INIT_HWM: usize = 128; fn make_specs(base: &str, is_server: bool) -> Result<(String, String, String), String> { if base.starts_with("ipc:") { if is_server { Ok(( format!("{}-{}", base, "in"), format!("{}-{}", base, "in-stream"), format!("{}-{}", base, "out"), )) } else { Ok(( format!("{}-{}", base, "out"), format!("{}-{}", base, "out-stream"), format!("{}-{}", base, "in"), )) } } else if base.starts_with("tcp:") { match base.rfind(':') { Some(pos) => match base[(pos + 1)..base.len()].parse::() { Ok(port) => Ok(( format!("{}:{}", &base[..pos], port), format!("{}:{}", &base[..pos], port + 1), format!("{}:{}", &base[..pos], port + 2), )), Err(e) => Err(format!("error parsing tcp port in base spec: {}", e)), }, None => Err("tcp base spec must specify port".into()), } } else { Err("base spec must be ipc or tcp".into()) } } pub enum ListenSpec { Tcp { addr: std::net::SocketAddr, tls: bool, default_cert: Option, }, Local { path: PathBuf, mode: Option, user: Option, group: Option, }, } pub struct ListenConfig { pub spec: ListenSpec, pub stream: bool, } pub struct Config { pub instance_id: String, pub workers: usize, pub req_maxconn: usize, pub stream_maxconn: usize, pub buffer_size: usize, pub body_buffer_size: usize, pub blocks_max: usize, pub connection_blocks_max: usize, pub messages_max: usize, pub req_timeout: Duration, pub stream_timeout: Duration, pub listen: Vec, pub zclient_req: Vec, pub zclient_stream: Vec, pub zclient_connect: bool, pub zserver_req: Vec, pub zserver_stream: Vec, pub zserver_connect: bool, pub ipc_file_mode: u32, pub certs_dir: PathBuf, pub allow_compression: bool, pub deny: Vec, } pub struct App { _server: Option, _client: Option, } impl App { pub fn new(config: &Config) -> Result { if config.req_maxconn < config.workers { return Err("req maxconn must be >= workers".into()); } if config.stream_maxconn < config.workers { return Err("stream maxconn must be >= workers".into()); } let zmq_context = Arc::new(zmq::Context::new()); // set hwm to 5% of maxconn let other_hwm = cmp::max((config.req_maxconn + config.stream_maxconn) / 20, 1); let handle_bound = cmp::max(other_hwm / config.workers, 1); let maxconn = config.req_maxconn + config.stream_maxconn; let server = if !config.listen.is_empty() { let mut any_req = false; let mut any_stream = false; for lc in config.listen.iter() { if lc.stream { any_stream = true; } else { any_req = true; } } let mut zsockman = zhttpsocket::ClientSocketManager::new( Arc::clone(&zmq_context), &config.instance_id, (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * config.workers), INIT_HWM, other_hwm, handle_bound, ); if any_req { let mut specs = Vec::new(); for spec in config.zclient_req.iter() { if config.zclient_connect { debug!("zhttp client connect {}", spec); } else { debug!("zhttp client bind {}", spec); } specs.push(SpecInfo { spec: spec.clone(), bind: !config.zclient_connect, ipc_file_mode: 
config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_req_specs(&specs) { return Err(format!("failed to set zhttp client req specs: {}", e)); } } if any_stream { let mut out_specs = Vec::new(); let mut out_stream_specs = Vec::new(); let mut in_specs = Vec::new(); for spec in config.zclient_stream.iter() { let (out_spec, out_stream_spec, in_spec) = make_specs(spec, false)?; if config.zclient_connect { debug!( "zhttp client connect {} {} {}", out_spec, out_stream_spec, in_spec ); } else { debug!( "zhttp client bind {} {} {}", out_spec, out_stream_spec, in_spec ); } out_specs.push(SpecInfo { spec: out_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); out_stream_specs.push(SpecInfo { spec: out_stream_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); in_specs.push(SpecInfo { spec: in_spec, bind: !config.zclient_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_client_stream_specs(&out_specs, &out_stream_specs, &in_specs) { return Err(format!("failed to set zhttp client stream specs: {}", e)); } } Some(Server::new( &config.instance_id, config.workers, config.req_maxconn, config.stream_maxconn, config.buffer_size, config.body_buffer_size, config.blocks_max, config.connection_blocks_max, config.messages_max, config.req_timeout, config.stream_timeout, &config.listen, config.certs_dir.as_path(), config.allow_compression, zsockman, handle_bound, )?) } else { None }; let client = if !config.zserver_req.is_empty() || !config.zserver_stream.is_empty() { let mut zsockman = zhttpsocket::ServerSocketManager::new( Arc::clone(&zmq_context), &config.instance_id, (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * config.workers), INIT_HWM, other_hwm, handle_bound, config.stream_maxconn, ); if !config.zserver_req.is_empty() { let mut specs = Vec::new(); for spec in config.zserver_req.iter() { if config.zserver_connect { debug!("zhttp server connect {}", spec); } else { debug!("zhttp server bind {}", spec); } specs.push(SpecInfo { spec: spec.clone(), bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = zsockman.set_server_req_specs(&specs) { return Err(format!("failed to set zhttp server req specs: {}", e)); } } let zsockman = Arc::new(zsockman); let client = Client::new( &config.instance_id, config.workers, config.req_maxconn, config.stream_maxconn, config.buffer_size, config.body_buffer_size, config.blocks_max, config.connection_blocks_max, config.messages_max, config.req_timeout, config.stream_timeout, config.allow_compression, &config.deny, zsockman.clone(), handle_bound, )?; // stream specs must only be applied after client is initialized if !config.zserver_stream.is_empty() { let mut in_specs = Vec::new(); let mut in_stream_specs = Vec::new(); let mut out_specs = Vec::new(); for spec in config.zserver_stream.iter() { let (in_spec, in_stream_spec, out_spec) = make_specs(spec, true)?; if config.zserver_connect { debug!( "zhttp server connect {} {} {}", in_spec, in_stream_spec, out_spec ); } else { debug!( "zhttp server bind {} {} {}", in_spec, in_stream_spec, out_spec ); } in_specs.push(SpecInfo { spec: in_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); in_stream_specs.push(SpecInfo { spec: in_stream_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); out_specs.push(SpecInfo { spec: out_spec, bind: !config.zserver_connect, ipc_file_mode: config.ipc_file_mode, }); } if let Err(e) = 
zsockman.set_server_stream_specs(&in_specs, &in_stream_specs, &out_specs) { return Err(format!("failed to set zhttp server stream specs: {}", e)); } } Some(client) } else { None }; Ok(Self { _server: server, _client: client, }) } pub fn wait_for_term(&self) { let mut signals = Signals::new(TERM_SIGNALS).unwrap(); let term_now = Arc::new(AtomicBool::new(false)); // ensure two term signals in a row causes the app to immediately exit for signal_type in TERM_SIGNALS { signal_hook::flag::register_conditional_shutdown( *signal_type, 1, // exit code Arc::clone(&term_now), ) .unwrap(); signal_hook::flag::register(*signal_type, Arc::clone(&term_now)).unwrap(); } // wait for termination let signal_type = signals.into_iter().next().unwrap(); assert!(TERM_SIGNALS.contains(&signal_type)); } pub fn sizes() -> Vec<(String, usize)> { let mut out = Vec::new(); out.extend(Server::task_sizes()); out.extend(Client::task_sizes()); out.push(( "deflate_codec_state".to_string(), websocket::deflate_codec_state_size(), )); out } } pub fn run(config: &Config) -> Result<(), Box> { debug!("starting..."); { let a = match App::new(config) { Ok(a) => a, Err(e) => { return Err(e.into()); } }; info!("started"); a.wait_for_term(); info!("stopping..."); } debug!("stopped"); Ok(()) } pushpin-1.41.0/src/connmgr/pool.rs000066400000000000000000000113651504671364300170530ustar00rootroot00000000000000/* * Copyright (C) 2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
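// [Illustrative sketch -- added for clarity, not part of the original source.]
// Minimal driver for run() above: build a Config and block until termination.
// Per wait_for_term(), the first SIGTERM/SIGINT starts a clean shutdown and a
// second one exits the process immediately with status 1. example_config() is
// the hypothetical helper sketched earlier, not a real function in this crate.
fn example_main() -> Result<(), Box<dyn std::error::Error>> {
    let config = example_config();
    run(&config) // returns once a termination signal has been handled
}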
*/ use crate::core::list; use crate::core::timer::TimerWheel; use slab::Slab; use std::borrow::Borrow; use std::collections::HashMap; use std::hash::Hash; use std::time::{Duration, Instant}; const TICK_DURATION_MS: u64 = 10; fn duration_to_ticks_round_down(d: Duration) -> u64 { (d.as_millis() / (TICK_DURATION_MS as u128)) as u64 } struct PoolItem { key: K, value: V, timer_id: usize, } pub struct Pool { nodes: Slab>>, by_key: HashMap, wheel: TimerWheel, start: Instant, current_ticks: u64, } impl Pool where K: Clone + Eq + Hash + PartialEq, { pub fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), by_key: HashMap::with_capacity(capacity), wheel: TimerWheel::new(capacity), start: Instant::now(), current_ticks: 0, } } pub fn add(&mut self, key: K, value: V, expires: Instant) -> Result<(), V> { if self.nodes.len() == self.nodes.capacity() { return Err(value); } let expires = self.get_ticks(expires); let nkey = { let entry = self.nodes.vacant_entry(); let nkey = entry.key(); let timer_id = self.wheel.add(expires, nkey).unwrap(); entry.insert(list::Node::new(PoolItem { key: key.clone(), value, timer_id, })); nkey }; let l = self.by_key.entry(key).or_default(); l.push_back(&mut self.nodes, nkey); Ok(()) } pub fn take(&mut self, key: &Q) -> Option where K: Borrow, Q: Hash + Eq + ?Sized, { let l = self.by_key.get_mut(key)?; let nkey = l.pop_front(&mut self.nodes)?; if l.is_empty() { self.by_key.remove(key); } let pi = self.nodes.remove(nkey).value; self.wheel.remove(pi.timer_id); Some(pi.value) } pub fn expire(&mut self, now: Instant) -> Option<(K, V)> { let ticks = self.get_ticks(now); if ticks > self.current_ticks { self.wheel.update(ticks); self.current_ticks = ticks; } let nkey = match self.wheel.take_expired() { Some((_, nkey)) => nkey, None => return None, }; let pi = &self.nodes[nkey].value; let l = self.by_key.get_mut(&pi.key).unwrap(); l.remove(&mut self.nodes, nkey); if l.is_empty() { let pi = &self.nodes[nkey].value; self.by_key.remove(&pi.key); } let pi = self.nodes.remove(nkey).value; Some((pi.key, pi.value)) } fn get_ticks(&self, t: Instant) -> u64 { let d = if t > self.start { t - self.start } else { Duration::from_millis(0) }; duration_to_ticks_round_down(d) } } #[cfg(test)] mod tests { use super::*; #[test] fn pool_add_take() { let mut pool = Pool::new(3); let now = Instant::now(); pool.add(1, "a", now).unwrap(); pool.add(1, "b", now).unwrap(); pool.add(2, "c", now).unwrap(); assert_eq!(pool.add(2, "d", now).is_ok(), false); assert_eq!(pool.take(&1), Some("a")); assert_eq!(pool.take(&2), Some("c")); assert_eq!(pool.take(&1), Some("b")); assert_eq!(pool.take(&2), None); } #[test] fn pool_expire() { let mut pool = Pool::new(3); let now = Instant::now(); pool.add(1, "a", now + Duration::from_secs(1)).unwrap(); pool.add(1, "b", now + Duration::from_secs(2)).unwrap(); pool.add(2, "c", now + Duration::from_secs(3)).unwrap(); assert_eq!(pool.expire(now), None); assert_eq!(pool.expire(now + Duration::from_secs(1)), Some((1, "a"))); assert_eq!(pool.expire(now + Duration::from_secs(1)), None); assert_eq!(pool.expire(now + Duration::from_secs(5)), Some((1, "b"))); assert_eq!(pool.expire(now + Duration::from_secs(5)), Some((2, "c"))); assert_eq!(pool.expire(now + Duration::from_secs(5)), None); } } pushpin-1.41.0/src/connmgr/resolver.rs000066400000000000000000000336171504671364300177470ustar00rootroot00000000000000/* * Copyright (C) 2022 Fanout, Inc. 
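// [Illustrative sketch -- added for clarity, not part of the original source.]
// Typical use of the Pool defined above: park idle values keyed by host, reuse
// them with take(), and reap anything past its deadline with expire(). The
// key/value types and timings are made-up example choices.
fn pool_example() {
    use std::time::{Duration, Instant};

    let mut pool: Pool<String, u32> = Pool::new(8);
    let now = Instant::now();

    // park two idle values under the same key, with different deadlines
    pool.add("example.org".to_string(), 1, now + Duration::from_secs(5))
        .unwrap();
    pool.add("example.org".to_string(), 2, now + Duration::from_secs(10))
        .unwrap();

    // take() returns the oldest parked value for the key
    assert_eq!(pool.take("example.org"), Some(1));

    // anything whose deadline has passed is handed back by expire()
    assert_eq!(
        pool.expire(now + Duration::from_secs(10)),
        Some(("example.org".to_string(), 2))
    );
    assert_eq!(pool.take("example.org"), None);
}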
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::event; use crate::core::list; use crate::core::reactor::CustomEvented; use crate::core::task::get_reactor; use arrayvec::{ArrayString, ArrayVec}; use mio::Interest; use slab::Slab; use std::collections::VecDeque; use std::future::Future; use std::io; use std::net::{IpAddr, ToSocketAddrs}; use std::pin::Pin; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::{Arc, Condvar, Mutex}; use std::task::{Context, Poll}; use std::thread; pub const REGISTRATIONS_PER_QUERY: usize = 1; pub const ADDRS_MAX: usize = 16; pub type Hostname = ArrayString<255>; pub type Addrs = ArrayVec; fn std_resolve(host: &str) -> Result { match (host, 0).to_socket_addrs() { Ok(addrs) => Ok(addrs.take(ADDRS_MAX).map(|addr| addr.ip()).collect()), Err(e) => Err(e), } } struct QueryItem { host: Hostname, result: Option>, set_readiness: event::SetReadiness, invalidated: Option>, } struct QueriesInner { stop: bool, nodes: Slab>, next: list::List, registrations: VecDeque<(event::Registration, event::SetReadiness)>, invalidated_count: u32, } #[derive(Clone)] struct Queries { inner: Arc<(Mutex, Condvar)>, } impl Queries { fn new(queries_max: usize) -> Self { let mut registrations = VecDeque::with_capacity(queries_max); for _ in 0..registrations.capacity() { registrations.push_back(event::Registration::new()); } let inner = QueriesInner { stop: false, nodes: Slab::with_capacity(queries_max), next: list::List::default(), registrations, invalidated_count: 0, }; Self { inner: Arc::new((Mutex::new(inner), Condvar::new())), } } fn set_stop_flag(&self) { let (lock, cvar) = &*self.inner; let mut queries = lock.lock().unwrap(); queries.stop = true; cvar.notify_all(); } fn add(&self, host: &str) -> Result<(usize, event::Registration), ()> { let (lock, cvar) = &*self.inner; let queries = &mut *lock.lock().unwrap(); if queries.nodes.len() == queries.nodes.capacity() { return Err(()); } let (reg, sr) = queries.registrations.pop_back().unwrap(); let nkey = match Hostname::from(host) { Ok(host) => { let nkey = queries.nodes.insert(list::Node::new(QueryItem { host, result: None, set_readiness: sr, invalidated: None, })); queries.next.push_back(&mut queries.nodes, nkey); cvar.notify_one(); nkey } Err(_) => { sr.set_readiness(Interest::READABLE).unwrap(); queries.nodes.insert(list::Node::new(QueryItem { host: Hostname::new(), result: Some(Err(io::Error::from(io::ErrorKind::InvalidInput))), set_readiness: sr, invalidated: None, })) } }; Ok((nkey, reg)) } // block until a query is available, or stopped fn get_next(&self, invalidated: &Arc) -> Option<(usize, Hostname)> { let (lock, cvar) = &*self.inner; let mut queries_guard = lock.lock().unwrap(); loop { let queries = &mut *queries_guard; if queries.stop { return None; } if let Some(nkey) = queries.next.pop_front(&mut queries.nodes) { let qi = &mut queries.nodes[nkey].value; invalidated.store(false, Ordering::Relaxed); qi.invalidated = Some(invalidated.clone()); return Some((nkey, qi.host)); } queries_guard = 
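// [Illustrative sketch -- added for clarity, not part of the original source.]
// What the blocking std_resolve() above does: it defers to the standard
// library's ToSocketAddrs with a dummy port and keeps at most ADDRS_MAX (16)
// of the returned addresses. Literal IPs resolve without touching DNS.
fn std_resolve_example() {
    use std::net::IpAddr;

    let addrs = std_resolve("127.0.0.1").unwrap();
    assert_eq!(addrs.as_slice(), &[IpAddr::from([127, 0, 0, 1])]);

    // Hostnames longer than the 255-byte Hostname limit never reach a worker
    // thread: Queries::add() fails them immediately with
    // io::ErrorKind::InvalidInput.
}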
cvar.wait(queries_guard).unwrap(); } } fn set_result( &self, item_key: usize, result: Result, invalidated: &AtomicBool, ) { let mut queries = self.inner.0.lock().unwrap(); if !invalidated.load(Ordering::Relaxed) { let qi = &mut queries.nodes[item_key].value; qi.result = Some(result); qi.invalidated = None; qi.set_readiness.set_readiness(Interest::READABLE).unwrap(); } else { queries.invalidated_count += 1; } } fn take_result(&self, item_key: usize) -> Option> { let queries = &mut *self.inner.0.lock().unwrap(); queries.nodes[item_key].value.result.take() } fn remove(&self, item_key: usize, registration: event::Registration) { let queries = &mut *self.inner.0.lock().unwrap(); // remove from next list if present queries.next.remove(&mut queries.nodes, item_key); let qi = queries.nodes.remove(item_key).value; if let Some(invalidated) = &qi.invalidated { invalidated.store(true, Ordering::Relaxed); } queries .registrations .push_back((registration, qi.set_readiness)); } #[cfg(test)] fn invalidated_count(&self) -> u32 { let queries = &mut *self.inner.0.lock().unwrap(); queries.invalidated_count } } struct ResolverInner { workers: Vec>, queries: Queries, } impl ResolverInner { fn new(num_threads: usize, queries_max: usize, resolve_fn: Arc) -> Self where F: Fn(&str) -> Result + Send + Sync + 'static, { let mut workers = Vec::with_capacity(num_threads); let queries = Queries::new(queries_max); for _ in 0..workers.capacity() { let queries = queries.clone(); let resolve_fn = resolve_fn.clone(); let thread = thread::Builder::new() .name("resolver".to_string()) .spawn(move || { let invalidated = Arc::new(AtomicBool::new(false)); loop { assert_eq!(Arc::strong_count(&invalidated), 1); let (item_key, host) = match queries.get_next(&invalidated) { Some(ret) => ret, None => break, }; let ret = resolve_fn(host.as_str()); queries.set_result(item_key, ret, &invalidated); } }) .unwrap(); workers.push(thread); } Self { workers, queries } } #[allow(clippy::result_unit_err)] fn resolve(&self, host: &str) -> Result { let (item_key, reg) = self.queries.add(host)?; Ok(Query { queries: self.queries.clone(), item_key, registration: Some(reg), }) } fn stop(&mut self) { self.queries.set_stop_flag(); for worker in self.workers.drain(..) 
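// [Illustrative sketch -- added for clarity, not part of the original source.]
// The invalidation handshake above in miniature: the requester flips a shared
// AtomicBool when it loses interest (Queries::remove), and the worker checks
// that flag before publishing its result (Queries::set_result), counting the
// result as discarded otherwise. This standalone version uses only std types.
fn invalidation_sketch() {
    use std::sync::atomic::{AtomicBool, Ordering};
    use std::sync::Arc;

    let invalidated = Arc::new(AtomicBool::new(false));

    // requester side: the query was dropped before the lookup finished
    invalidated.store(true, Ordering::Relaxed);

    // worker side: only publish if the query is still wanted
    let mut published = 0;
    let mut discarded = 0;
    if !invalidated.load(Ordering::Relaxed) {
        published += 1;
    } else {
        discarded += 1;
    }
    assert_eq!((published, discarded), (0, 1));
}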
{ worker.join().unwrap(); } } } impl Drop for ResolverInner { fn drop(&mut self) { self.stop(); } } pub struct Resolver { inner: ResolverInner, } impl Resolver { pub fn new(num_threads: usize, queries_max: usize) -> Self { let inner = ResolverInner::new(num_threads, queries_max, Arc::new(std_resolve)); Self { inner } } #[allow(clippy::result_unit_err)] pub fn resolve(&self, host: &str) -> Result { self.inner.resolve(host) } } pub struct Query { queries: Queries, item_key: usize, registration: Option, } impl Query { pub fn get_read_registration(&self) -> &event::Registration { self.registration.as_ref().unwrap() } pub fn process(&self) -> Option> { self.queries.take_result(self.item_key) } } impl Drop for Query { fn drop(&mut self) { let reg = self.registration.take().unwrap(); self.queries.remove(self.item_key, reg); } } pub struct AsyncResolver<'a> { resolver: &'a Resolver, } impl<'a> AsyncResolver<'a> { pub fn new(resolver: &'a Resolver) -> Self { Self { resolver } } pub fn resolve(&self, host: &str) -> QueryFuture { let query = self.resolver.resolve(host).ok(); QueryFuture { evented: None, query, } } } pub struct QueryFuture { evented: Option, query: Option, } impl Future for QueryFuture { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let query = match &f.query { Some(query) => query, None => return Poll::Ready(Err(io::Error::from(io::ErrorKind::OutOfMemory))), }; let evented = match &f.evented { Some(evented) => evented, None => { let evented = CustomEvented::new( query.get_read_registration(), mio::Interest::READABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); f.evented = Some(evented); f.evented.as_ref().unwrap() } }; evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !evented.registration().is_ready() { return Poll::Pending; } match query.process() { Some(ret) => Poll::Ready(ret), None => { evented.registration().set_ready(false); Poll::Pending } } } } impl Drop for QueryFuture { fn drop(&mut self) { if let Some(evented) = &self.evented { let query = self.query.as_ref().unwrap(); // normally, a registration will deregister itself when dropped. 
// however, the query's registration is not dropped when the // query is dropped, so we need to explicitly deregister evented .registration() .deregister_custom(query.get_read_registration()) .unwrap(); } } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use crate::core::reactor::Reactor; #[test] fn resolve() { let mut poller = event::Poller::new(1).unwrap(); let resolver = Resolver::new(1, 1); let query = resolver.resolve("127.0.0.1").unwrap(); // queries_max is 1, so this should error assert_eq!(resolver.resolve("127.0.0.1").is_err(), true); // register query interest with poller poller .register_custom( query.get_read_registration(), mio::Token(1), Interest::READABLE, ) .unwrap(); // wait for completion let result = loop { if let Some(result) = query.process() { break result; } poller.poll(None).unwrap(); for _ in poller.iter_events() {} }; // deregister query interest poller .deregister_custom(query.get_read_registration()) .unwrap(); assert_eq!(result.unwrap().as_slice(), &[IpAddr::from([127, 0, 0, 1])]); } #[test] fn invalidate_query() { let mut inner = { let cond = Arc::new((Mutex::new(false), Condvar::new())); let resolve_fn = { let cond = cond.clone(); Arc::new(move |_: &str| { let (lock, cvar) = &*cond; let guard = lock.lock().unwrap(); // let main thread know we've started cvar.notify_one(); // wait for query to be removed let _guard = cvar.wait(guard).unwrap(); Ok(Addrs::new()) }) }; let (lock, cvar) = &*cond; let guard = lock.lock().unwrap(); let inner = ResolverInner::new(1, 1, resolve_fn); let query = inner.resolve("127.0.0.1").unwrap(); // wait for resolve_fn to start let _guard = cvar.wait(guard).unwrap(); drop(query); // let worker know the query has been removed cvar.notify_one(); inner }; inner.stop(); assert_eq!(inner.queries.invalidated_count(), 1); } #[test] fn async_resolve() { let reactor = Reactor::new(1); let executor = Executor::new(1); executor .spawn(async { let resolver = Resolver::new(1, 1); let resolver = AsyncResolver::new(&resolver); let f1 = resolver.resolve("127.0.0.1"); let f2 = resolver.resolve("127.0.0.1"); // will error, since queries_max=1 let addrs = f1.await.unwrap(); let e = f2.await.unwrap_err(); assert_eq!(addrs.as_slice(), &[IpAddr::from([127, 0, 0, 1])]); assert_eq!(e.kind(), io::ErrorKind::OutOfMemory); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } } pushpin-1.41.0/src/connmgr/server.rs000066400000000000000000003073231504671364300174120ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::connmgr::batch::{Batch, BatchKey}; use crate::connmgr::connection::{ server_req_connection, server_stream_connection, CidProvider, Identify, StreamSharedData, }; use crate::connmgr::counter::Counter; use crate::connmgr::listener::Listener; use crate::connmgr::tls::{AsyncTlsStream, IdentityCache, TlsAcceptor, TlsStream, TlsWaker}; use crate::connmgr::zhttppacket; use crate::connmgr::zhttpsocket; use crate::connmgr::{ListenConfig, ListenSpec}; use crate::core::arena; use crate::core::buffer::TmpBuffer; use crate::core::channel::{self, AsyncLocalReceiver, AsyncLocalSender, AsyncReceiver}; use crate::core::event; use crate::core::executor::{Executor, Spawner}; use crate::core::fs::{set_group, set_user}; use crate::core::list; use crate::core::net::{ set_socket_opts, AsyncTcpStream, AsyncUnixStream, NetListener, NetStream, SocketAddr, }; use crate::core::reactor::Reactor; use crate::core::select::{ select_2, select_3, select_6, select_8, select_option, Select2, Select3, Select6, Select8, }; use crate::core::task::{self, yield_to_local_events, CancellationSender, CancellationToken}; use crate::core::time::Timeout; use crate::core::tnetstring; use crate::core::waker::RefWakerData; use crate::core::zmq::SpecInfo; use arrayvec::{ArrayString, ArrayVec}; use log::{debug, error, info, warn}; use mio::net::{TcpListener, TcpStream, UnixListener}; use mio::unix::SourceFd; use slab::Slab; use socket2::{Domain, Socket, Type}; use std::cell::RefCell; use std::collections::VecDeque; use std::fs; use std::io; use std::io::Write; use std::mem; use std::net::{IpAddr, Ipv4Addr}; use std::os::unix::fs::PermissionsExt; use std::os::unix::io::{FromRawFd, IntoRawFd}; use std::path::Path; use std::pin::pin; use std::rc::Rc; use std::str::{self, FromStr}; use std::sync::{mpsc, Arc}; use std::thread; use std::time::Duration; const RESP_SENDER_BOUND: usize = 1; const HANDLE_ACCEPT_BOUND: usize = 100; // we read and process each response message one at a time, wrapping it in an // rc, and sending it to connections via channels. on the other side of each // channel, the message is received and processed immediately. this means the // max number of messages retained per connection is the channel bound per // connection pub const MSG_RETAINED_PER_CONNECTION_MAX: usize = RESP_SENDER_BOUND; // the max number of messages retained outside of connections is one per // handle we read from (req and stream), in preparation for sending to any // connections pub const MSG_RETAINED_PER_WORKER_MAX: usize = 2; // run x1 // accept_task x2 // req_handle_task x1 // stream_handle_task x1 // keep_alives_task x1 const WORKER_NON_CONNECTION_TASKS_MAX: usize = 10; // note: individual tasks are not (and must not be) capped to this number. // this is because accept_task makes a registration for every connection // task, which means each instance of accept_task could end up making // thousands of registrations. 
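// [Illustrative sketch -- added for clarity, not part of the original source.]
// How the retention constants above translate into arena sizing: the socket
// managers created in mod.rs are given a message capacity of
// (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers).
// The connection and worker counts below are made-up example numbers.
fn retained_messages_example() {
    let maxconn = 10_000; // req_maxconn + stream_maxconn
    let workers = 2;

    let capacity =
        (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers);

    // one in-flight message per connection channel, plus two per worker
    assert_eq!(capacity, 10_004);
}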
however, such registrations are associated // with the spawning of connection_task, so we can still estimate // registrations relative to the number of tasks const REGISTRATIONS_PER_TASK_MAX: usize = 32; const REACTOR_BUDGET: u32 = 100; const KEEP_ALIVE_TIMEOUT_MS: usize = 45_000; const KEEP_ALIVE_BATCH_MS: usize = 100; const KEEP_ALIVE_INTERVAL: Duration = Duration::from_millis(KEEP_ALIVE_BATCH_MS as u64); const KEEP_ALIVE_BATCHES: usize = KEEP_ALIVE_TIMEOUT_MS / KEEP_ALIVE_BATCH_MS; const BULK_PACKET_SIZE_MAX: usize = 65_000; const SHUTDOWN_TIMEOUT: Duration = Duration::from_millis(10_000); fn get_addr_and_offset(msg: &[u8]) -> Result<(&str, usize), ()> { let mut pos = None; for (i, b) in msg.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = match pos { Some(pos) => pos, None => return Err(()), }; let addr = match str::from_utf8(&msg[..pos]) { Ok(addr) => addr, Err(_) => return Err(()), }; Ok((addr, pos + 1)) } fn get_key(id: &[u8]) -> Result { let mut start = None; let mut end = None; for (i, b) in id.iter().enumerate() { if *b == b'-' { if start.is_none() { start = Some(i + 1); } else { end = Some(i); break; } } } let start = match start { Some(start) => start, None => return Err(()), }; let end = match end { Some(end) => end, None => return Err(()), }; let key = match str::from_utf8(&id[start..end]) { Ok(key) => key, Err(_) => return Err(()), }; let key = match key.parse() { Ok(key) => key, Err(_) => return Err(()), }; Ok(key) } fn local_channel( bound: usize, max_senders: usize, ) -> (channel::LocalSender, channel::LocalReceiver) { let (s, r) = channel::local_channel( bound, max_senders, &Reactor::current().unwrap().local_registration_memory(), ); (s, r) } fn async_local_channel( bound: usize, max_senders: usize, ) -> (AsyncLocalSender, AsyncLocalReceiver) { let (s, r) = local_channel(bound, max_senders); let s = AsyncLocalSender::new(s); let r = AsyncLocalReceiver::new(r); (s, r) } fn gen_id(id: usize, ckey: usize, next_cid: &mut u32) -> ArrayString<32> { let mut buf = [0; 32]; let mut c = io::Cursor::new(&mut buf[..]); write!(&mut c, "{}-{}-{:x}", id, ckey, next_cid).unwrap(); let size = c.position() as usize; let s = str::from_utf8(&buf[..size]).unwrap(); *next_cid += 1; ArrayString::from_str(s).unwrap() } enum Stream { Plain(NetStream), Tls(TlsStream), } impl Identify for AsyncTcpStream { fn set_id(&mut self, _id: &str) { // do nothing } } impl Identify for AsyncUnixStream { fn set_id(&mut self, _id: &str) { // do nothing } } impl Identify for AsyncTlsStream<'_> { fn set_id(&mut self, id: &str) { // server generates ids known to always be accepted self.inner().set_id(id).unwrap(); } } enum BatchType { KeepAlive, Cancel, } struct ChannelPool { items: RefCell, channel::LocalReceiver)>>, } impl ChannelPool { fn new(capacity: usize) -> Self { Self { items: RefCell::new(VecDeque::with_capacity(capacity)), } } fn take(&self) -> Option<(channel::LocalSender, channel::LocalReceiver)> { let p = &mut *self.items.borrow_mut(); p.pop_back() } fn push(&self, pair: (channel::LocalSender, channel::LocalReceiver)) { let p = &mut *self.items.borrow_mut(); p.push_back(pair); } } struct ConnectionDone { ckey: usize, } struct ConnectionItem { id: ArrayString<32>, stop: Option, zreceiver_sender: channel::LocalSender<(arena::Rc, usize)>, shared: Option>, batch_key: Option, } struct ConnectionItems { nodes: Slab>, next_cid: u32, batch: Batch, } impl ConnectionItems { fn new(capacity: usize, batch: Batch) -> Self { Self { nodes: Slab::with_capacity(capacity), next_cid: 0, batch, } 
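// [Illustrative sketch -- added for clarity, not part of the original source.]
// The connection-id round trip used by the handle tasks: gen_id() encodes
// "<worker>-<ckey>-<hex counter>", and get_key() later recovers the ckey from
// an id carried in a zhttp packet so the message can be routed to the right
// connection. get_addr_and_offset() splits the "<instance-id> " prefix off a
// stream packet. The byte strings below are made-up example values.
fn id_routing_example() {
    let mut next_cid = 0;
    let id = gen_id(1, 42, &mut next_cid);
    assert_eq!(id.as_str(), "1-42-0");
    assert_eq!(next_cid, 1);

    // the middle component is the connection key
    assert_eq!(get_key(id.as_bytes()), Ok(42));

    // stream packets carry an "<instance-id> " prefix; the returned offset
    // points at the payload that follows the space
    let msg = b"instance-1 hello";
    assert_eq!(get_addr_and_offset(msg), Ok(("instance-1", 11)));
}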
} } struct ConnectionsInner { active: list::List, count: usize, max: usize, } struct Connections { items: Rc>, inner: RefCell, } impl Connections { fn new(items: Rc>, max: usize) -> Self { Self { items, inner: RefCell::new(ConnectionsInner { active: list::List::default(), count: 0, max, }), } } fn count(&self) -> usize { self.inner.borrow().count } fn max(&self) -> usize { self.inner.borrow().max } fn add( &self, worker_id: usize, stop: CancellationSender, zreceiver_sender: channel::LocalSender<(arena::Rc, usize)>, shared: Option>, ) -> Result<(usize, ArrayString<32>), ()> { let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); if items.nodes.len() == items.nodes.capacity() { return Err(()); } let nkey = items.nodes.insert(list::Node::new(ConnectionItem { id: ArrayString::new(), stop: Some(stop), zreceiver_sender, shared, batch_key: None, })); items.nodes[nkey].value.id = gen_id(worker_id, nkey, &mut items.next_cid); c.active.push_back(&mut items.nodes, nkey); c.count += 1; Ok((nkey, items.nodes[nkey].value.id)) } // return zreceiver_sender fn remove( &self, ckey: usize, ) -> channel::LocalSender<(arena::Rc, usize)> { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let c = &mut *self.inner.borrow_mut(); let ci = &mut items.nodes[nkey].value; // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } c.active.remove(&mut items.nodes, nkey); c.count -= 1; let ci = items.nodes.remove(nkey).value; ci.zreceiver_sender } fn regen_id(&self, worker_id: usize, ckey: usize) -> ArrayString<32> { let nkey = ckey; let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[nkey].value; // clear active keep alive if let Some(bkey) = ci.batch_key.take() { items.batch.remove(bkey); } ci.id = gen_id(worker_id, nkey, &mut items.next_cid); ci.id } fn check_id(&self, ckey: usize, id: &[u8]) -> bool { let nkey = ckey; let items = &*self.items.borrow(); let ci = match items.nodes.get(nkey) { Some(n) => &n.value, None => return false, }; ci.id.as_bytes() == id } fn try_send( &self, ckey: usize, value: (arena::Rc, usize), ) -> Result<(), mpsc::TrySendError<(arena::Rc, usize)>> { let nkey = ckey; let items = &*self.items.borrow(); let ci = &items.nodes[nkey].value; ci.zreceiver_sender.try_send(value) } fn stop_all(&self, about_to_stop: F) where F: Fn(usize), { let items = &mut *self.items.borrow_mut(); let cinner = &*self.inner.borrow_mut(); let mut next = cinner.active.head; while let Some(nkey) = next { let n = &mut items.nodes[nkey]; let ci = &mut n.value; about_to_stop(nkey); ci.stop = None; next = n.next; } } fn items_capacity(&self) -> usize { self.items.borrow().nodes.capacity() } fn is_item_stream(&self, ckey: usize) -> bool { let items = &*self.items.borrow(); match items.nodes.get(ckey) { Some(n) => { let ci = &n.value; ci.shared.is_some() } None => false, } } fn batch_is_empty(&self) -> bool { let items = &*self.items.borrow(); items.batch.is_empty() } fn batch_len(&self) -> usize { let items = &*self.items.borrow(); items.batch.len() } fn batch_capacity(&self) -> usize { let items = &*self.items.borrow(); items.batch.capacity() } fn batch_clear(&self) { let items = &mut *self.items.borrow_mut(); items.batch.clear(); } fn batch_add(&self, ckey: usize) -> Result<(), ()> { let items = &mut *self.items.borrow_mut(); let ci = &mut items.nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); // only batch connections with known handler addresses let addr_ref = cshared.to_addr(); let addr = match addr_ref.get() 
{ Some(addr) => addr, None => return Err(()), }; let bkey = items.batch.add(addr, false, ckey)?; ci.batch_key = Some(bkey); Ok(()) } fn next_batch_message( &self, from: &str, btype: BatchType, ) -> Option<(usize, ArrayVec, zmq::Message)> { let items = &mut *self.items.borrow_mut(); let nodes = &mut items.nodes; let batch = &mut items.batch; while !batch.is_empty() { let group = { let group = batch.take_group(|ckey| { let ci = &nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); // addr could have been removed after adding to the batch cshared.to_addr().get()?; Some((ci.id.as_bytes(), cshared.out_seq())) }); match group { Some(group) => group, None => continue, } }; let count = group.ids().len(); assert!(count <= zhttppacket::IDS_MAX); let zreq = zhttppacket::Request { from: from.as_bytes(), ids: group.ids(), multi: true, ptype: match btype { BatchType::KeepAlive => zhttppacket::RequestPacket::KeepAlive, BatchType::Cancel => zhttppacket::RequestPacket::Cancel, }, ptype_str: "", }; let mut data = [0; BULK_PACKET_SIZE_MAX]; let size = match zreq.serialize(&mut data) { Ok(size) => size, Err(e) => { error!( "failed to serialize keep-alive packet with {} ids: {}", zreq.ids.len(), e ); continue; } }; let data = &data[..size]; let mut addr = ArrayVec::::new(); if addr.try_extend_from_slice(group.addr()).is_err() { error!("failed to prepare addr"); continue; } let msg = zmq::Message::from(data); drop(group); for &ckey in batch.last_group_ckeys() { let ci = &mut nodes[ckey].value; let cshared = ci.shared.as_ref().unwrap().get(); cshared.inc_out_seq(); ci.batch_key = None; } return Some((count, addr, msg)); } None } } struct ConnectionCid<'a> { worker_id: usize, ckey: usize, conns: &'a Connections, } impl<'a> ConnectionCid<'a> { fn new(worker_id: usize, ckey: usize, conns: &'a Connections) -> Self { Self { worker_id, ckey, conns, } } } impl CidProvider for ConnectionCid<'_> { fn get_new_assigned_cid(&mut self) -> ArrayString<32> { self.conns.regen_id(self.worker_id, self.ckey) } } #[derive(Clone)] struct ConnectionOpts { instance_id: Rc, buffer_size: usize, timeout: Duration, rb_tmp: Rc, packet_buf: Rc>>, tmp_buf: Rc>>, } struct ConnectionReqOpts { body_buffer_size: usize, sender: channel::LocalSender, } struct ConnectionStreamOpts { blocks_max: usize, blocks_avail: Arc, messages_max: usize, allow_compression: bool, sender: channel::LocalSender, sender_stream: channel::LocalSender<(ArrayVec, zmq::Message)>, stream_shared_mem: Rc>, } enum ConnectionModeOpts { Req(ConnectionReqOpts), Stream(ConnectionStreamOpts), } struct Worker { thread: Option>, stop: Option>, } impl Worker { #[allow(clippy::too_many_arguments)] fn new( instance_id: &str, id: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, connection_blocks_max: usize, blocks_avail: &Arc, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, req_acceptor: channel::Receiver<(usize, NetStream, SocketAddr)>, stream_acceptor: channel::Receiver<(usize, NetStream, SocketAddr)>, req_acceptor_tls: &[(bool, Option)], stream_acceptor_tls: &[(bool, Option)], identities: &Arc, zsockman: &Arc, handle_bound: usize, ) -> Self { debug!("server-worker {}: starting", id); let (stop, r_stop) = channel::channel(1); let (s_ready, ready) = channel::channel(1); let instance_id = String::from(instance_id); let blocks_avail = Arc::clone(blocks_avail); let req_acceptor_tls = req_acceptor_tls.to_owned(); let stream_acceptor_tls = stream_acceptor_tls.to_owned(); let 
identities = Arc::clone(identities); let zsockman = Arc::clone(zsockman); let thread = thread::Builder::new() .name(format!("server-worker-{}", id)) .spawn(move || { let maxconn = req_maxconn + stream_maxconn; // 1 task per connection, plus a handful of supporting tasks let tasks_max = maxconn + WORKER_NON_CONNECTION_TASKS_MAX; let registrations_max = REGISTRATIONS_PER_TASK_MAX * tasks_max; let reactor = Reactor::new(registrations_max); let executor = Executor::new(tasks_max); { let reactor = reactor.clone(); executor.set_pre_poll(move || { reactor.set_budget(Some(REACTOR_BUDGET)); }); } executor .spawn(Self::run( r_stop, s_ready, instance_id, id, req_maxconn, stream_maxconn, buffer_size, body_buffer_size, connection_blocks_max, blocks_avail, messages_max, req_timeout, stream_timeout, allow_compression, req_acceptor, stream_acceptor, req_acceptor_tls, stream_acceptor_tls, identities, zsockman, handle_bound, )) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); debug!("server-worker {}: stopped", id); }) .unwrap(); ready.recv().unwrap(); Self { thread: Some(thread), stop: Some(stop), } } fn stop(&mut self) { self.stop = None; } #[allow(clippy::too_many_arguments)] async fn run( stop: channel::Receiver<()>, ready: channel::Sender<()>, instance_id: String, id: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, connection_blocks_max: usize, blocks_avail: Arc, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, allow_compression: bool, req_acceptor: channel::Receiver<(usize, NetStream, SocketAddr)>, stream_acceptor: channel::Receiver<(usize, NetStream, SocketAddr)>, req_acceptor_tls: Vec<(bool, Option)>, stream_acceptor_tls: Vec<(bool, Option)>, identities: Arc, zsockman: Arc, handle_bound: usize, ) { let executor = Executor::current().unwrap(); let reactor = Reactor::current().unwrap(); let stop = AsyncReceiver::new(stop); let req_acceptor = AsyncReceiver::new(req_acceptor); let stream_acceptor = AsyncReceiver::new(stream_acceptor); debug!("server-worker {}: allocating buffers", id); let rb_tmp = Rc::new(TmpBuffer::new(buffer_size * connection_blocks_max)); // large enough to fit anything let packet_buf = Rc::new(RefCell::new(vec![0; buffer_size + body_buffer_size + 4096])); // same size as working buffers let tmp_buf = Rc::new(RefCell::new(vec![0; buffer_size])); let instance_id = Rc::new(instance_id); let ka_batch = stream_maxconn.div_ceil(KEEP_ALIVE_BATCHES); let batch = Batch::new(ka_batch); let maxconn = req_maxconn + stream_maxconn; let conn_items = Rc::new(RefCell::new(ConnectionItems::new(maxconn, batch))); let req_conns = Rc::new(Connections::new(conn_items.clone(), req_maxconn)); let stream_conns = Rc::new(Connections::new(conn_items.clone(), stream_maxconn)); let (req_accept_stop, r_req_accept_stop) = async_local_channel(1, 1); let (stream_accept_stop, r_stream_accept_stop) = async_local_channel(1, 1); let (req_handle_stop, r_req_handle_stop) = async_local_channel(1, 1); let (stream_handle_stop, r_stream_handle_stop) = async_local_channel(1, 1); let (keep_alives_stop, r_keep_alives_stop) = async_local_channel(1, 1); let (s_req_accept_done, req_accept_done) = async_local_channel(1, 1); let (s_stream_accept_done, stream_accept_done) = async_local_channel(1, 1); let (s_req_handle_done, req_handle_done) = async_local_channel(1, 1); let (s_stream_handle_done, stream_handle_done) = async_local_channel(1, 1); let (s_keep_alives_done, keep_alives_done) = async_local_channel(1, 1); // max_senders is 1 per 
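// [Illustrative sketch -- added for clarity, not part of the original source.]
// How the per-worker executor and reactor above are sized. The per-worker
// connection limits below are made-up example numbers.
fn worker_sizing_example() {
    let req_maxconn = 1_000;
    let stream_maxconn = 9_000;

    let maxconn = req_maxconn + stream_maxconn;
    let tasks_max = maxconn + WORKER_NON_CONNECTION_TASKS_MAX; // 10_010 tasks
    let registrations_max = REGISTRATIONS_PER_TASK_MAX * tasks_max; // 320_320 registrations
    assert_eq!((tasks_max, registrations_max), (10_010, 320_320));

    // keep-alive batching: the stream connections are spread over
    // KEEP_ALIVE_BATCHES (450) ticks of KEEP_ALIVE_BATCH_MS (100 ms) each
    let ka_batch = stream_maxconn.div_ceil(KEEP_ALIVE_BATCHES);
    assert_eq!(ka_batch, 20);
}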
connection + 1 for the accept task let (zreq_sender, zreq_receiver) = local_channel(handle_bound, req_maxconn + 1); // max_senders is 1 per connection + 1 for the accept task let (zstream_out_sender, zstream_out_receiver) = local_channel(handle_bound, stream_maxconn + 1); // max_senders is 1 per connection + 1 for the accept task + 1 for the keep alive task let (zstream_out_stream_sender, zstream_out_stream_receiver) = local_channel(handle_bound, stream_maxconn + 2); let zreq_receiver = AsyncLocalReceiver::new(zreq_receiver); let zstream_out_receiver = AsyncLocalReceiver::new(zstream_out_receiver); let zstream_out_stream_receiver = AsyncLocalReceiver::new(zstream_out_stream_receiver); let req_handle = zhttpsocket::AsyncClientReqHandle::new( zsockman.client_req_handle(format!("{}-", id).as_bytes()), ); let stream_handle = zhttpsocket::AsyncClientStreamHandle::new( zsockman.client_stream_handle(format!("{}-", id).as_bytes()), ); let stream_shared_mem = Rc::new(arena::RcMemory::new(stream_maxconn)); let zreceiver_pool = Rc::new(ChannelPool::new(maxconn)); for _ in 0..maxconn { zreceiver_pool.push(local_channel(RESP_SENDER_BOUND, 1)); } let (s_req_cdone, r_req_cdone) = { let (s_from_handle, r_from_handle) = channel::local_channel( HANDLE_ACCEPT_BOUND, 1, &reactor.local_registration_memory(), ); // bound is 1 per connection, so all connections can indicate done at once // max_senders is 1 per connection + 1 for the accept task let (s_from_conn, r_from_conn) = channel::local_channel( req_conns.max(), req_conns.max() + 1, &reactor.local_registration_memory(), ); executor .spawn(Self::accept_task( "req_accept", id, r_req_accept_stop, s_req_accept_done, req_acceptor, req_acceptor_tls, identities.clone(), executor.spawner(), zreceiver_pool.clone(), AsyncLocalReceiver::new(r_from_handle), s_from_conn, req_conns.clone(), ConnectionOpts { instance_id: instance_id.clone(), buffer_size, timeout: req_timeout, rb_tmp: rb_tmp.clone(), packet_buf: packet_buf.clone(), tmp_buf: tmp_buf.clone(), }, ConnectionModeOpts::Req(ConnectionReqOpts { body_buffer_size, sender: zreq_sender, }), )) .unwrap(); (s_from_handle, r_from_conn) }; let (s_stream_cdone, r_stream_cdone) = { let (s_from_handle, r_from_handle) = channel::local_channel( HANDLE_ACCEPT_BOUND, 1, &reactor.local_registration_memory(), ); // bound is 1 per connection, so all connections can indicate done at once // max_senders is 1 per connection + 1 for the accept task let (s_from_conn, r_from_conn) = channel::local_channel( stream_conns.max(), stream_conns.max() + 1, &reactor.local_registration_memory(), ); let zstream_out_stream_sender = zstream_out_stream_sender .try_clone(&reactor.local_registration_memory()) .unwrap(); executor .spawn(Self::accept_task( "stream_accept", id, r_stream_accept_stop, s_stream_accept_done, stream_acceptor, stream_acceptor_tls, identities.clone(), executor.spawner(), zreceiver_pool.clone(), AsyncLocalReceiver::new(r_from_handle), s_from_conn, stream_conns.clone(), ConnectionOpts { instance_id: instance_id.clone(), buffer_size, timeout: stream_timeout, rb_tmp: rb_tmp.clone(), packet_buf: packet_buf.clone(), tmp_buf: tmp_buf.clone(), }, ConnectionModeOpts::Stream(ConnectionStreamOpts { blocks_max: connection_blocks_max, blocks_avail, messages_max, allow_compression, sender: zstream_out_sender, sender_stream: zstream_out_stream_sender, stream_shared_mem, }), )) .unwrap(); (s_from_handle, r_from_conn) }; executor .spawn(Self::req_handle_task( id, r_req_handle_stop, s_req_handle_done, zreq_receiver, 
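// [Illustrative sketch -- added for clarity, not part of the original source.]
// The zreceiver_pool above pre-creates one bounded channel pair per possible
// connection so that accepting a connection never has to allocate one; the
// pair goes back into the pool when the connection finishes. The standalone
// analogy below uses std::sync::mpsc instead of this crate's local channels.
fn channel_pool_sketch() {
    use std::collections::VecDeque;
    use std::sync::mpsc;

    let maxconn = 4;
    let mut pool: VecDeque<(mpsc::SyncSender<u32>, mpsc::Receiver<u32>)> =
        (0..maxconn).map(|_| mpsc::sync_channel(1)).collect();

    // accept: take a pre-built pair instead of allocating a new one
    let (tx, rx) = pool.pop_back().unwrap();
    tx.send(7).unwrap();
    assert_eq!(rx.recv().unwrap(), 7);

    // connection done: recycle the pair
    pool.push_back((tx, rx));
    assert_eq!(pool.len(), maxconn);
}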
AsyncLocalReceiver::new(r_req_cdone), AsyncLocalSender::new(s_req_cdone), req_handle, req_maxconn, req_conns.clone(), )) .unwrap(); executor .spawn(Self::stream_handle_task( id, r_stream_handle_stop, s_stream_handle_done, instance_id.clone(), zstream_out_receiver, zstream_out_stream_receiver, AsyncLocalReceiver::new(r_stream_cdone), AsyncLocalSender::new(s_stream_cdone), stream_handle, stream_maxconn, stream_conns.clone(), )) .unwrap(); executor .spawn(Self::keep_alives_task( id, r_keep_alives_stop, s_keep_alives_done, instance_id.clone(), zstream_out_stream_sender, stream_conns.clone(), )) .unwrap(); debug!("server-worker {}: started", id); ready.send(()).unwrap(); drop(ready); // wait for stop let _ = stop.recv().await; // stop keep alives drop(keep_alives_stop); let _ = keep_alives_done.recv().await; // stop connections drop(req_accept_stop); drop(stream_accept_stop); let _ = req_accept_done.recv().await; let _ = stream_accept_done.recv().await; // stop remaining tasks drop(req_handle_stop); drop(stream_handle_stop); let _ = req_handle_done.recv().await; let stream_handle = stream_handle_done.recv().await.unwrap(); // send cancels stream_conns.batch_clear(); let now = reactor.now(); let shutdown_timeout = Timeout::new(now + SHUTDOWN_TIMEOUT); let mut next_cancel_index = 0; 'outer: while next_cancel_index < stream_conns.items_capacity() { while stream_conns.batch_len() < stream_conns.batch_capacity() && next_cancel_index < stream_conns.items_capacity() { let key = next_cancel_index; next_cancel_index += 1; if stream_conns.is_item_stream(key) { // ignore errors let _ = stream_conns.batch_add(key); } } while let Some((count, addr, msg)) = stream_conns.next_batch_message(&instance_id, BatchType::Cancel) { debug!( "server-worker {}: sending cancels for {} sessions", id, count ); match select_2( pin!(stream_handle.send_to_addr(addr, msg)), shutdown_timeout.elapsed(), ) .await { Select2::R1(r) => r.unwrap(), Select2::R2(_) => break 'outer, } } } } #[allow(clippy::too_many_arguments)] async fn accept_task( name: &str, id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, acceptor: AsyncReceiver<(usize, NetStream, SocketAddr)>, acceptor_tls: Vec<(bool, Option)>, identities: Arc, spawner: Spawner, zreceiver_pool: Rc, usize)>>, cdone: AsyncLocalReceiver, s_cdone: channel::LocalSender, conns: Rc, opts: ConnectionOpts, mode_opts: ConnectionModeOpts, ) { let mut tls_acceptors = Vec::new(); for config in acceptor_tls { if config.0 { let default_cert = config.1.as_deref(); tls_acceptors.push(Some(TlsAcceptor::new(&identities, default_cert))); } else { tls_acceptors.push(None); } } let reactor = Reactor::current().unwrap(); debug!("server-worker {}: task started: {}", id, name); loop { let acceptor_recv = if conns.count() < conns.max() { Some(acceptor.recv()) } else { None }; let (pos, mut stream, peer_addr) = match select_3(stop.recv(), cdone.recv(), select_option(acceptor_recv)).await { // stop.recv Select3::R1(_) => break, // cdone.recv Select3::R2(result) => match result { Ok(done) => { let zreceiver_sender = conns.remove(done.ckey); let zreceiver = zreceiver_sender .make_receiver(&reactor.local_registration_memory()) .unwrap(); zreceiver.clear(); zreceiver_pool.push((zreceiver_sender, zreceiver)); continue; } Err(e) => panic!("cdone channel error: {}", e), }, // acceptor_recv Select3::R3(result) => match result { Ok(ret) => ret, Err(_) => continue, // ignore errors }, }; if let NetStream::Tcp(stream) = &mut stream { set_socket_opts(stream); } let stream = match stream { 
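// [Illustrative sketch -- added for clarity, not part of the original source.]
// The shutdown handshake used above, in miniature: each task holds a "stop"
// receiver and a "done" sender; the worker asks a task to stop by dropping its
// stop sender and learns the task has finished when the done channel
// disconnects. Tasks are stopped in order: keep-alives, then accept, then the
// handle tasks. This standalone version uses std channels and a thread.
fn shutdown_order_sketch() {
    use std::sync::mpsc;
    use std::thread;

    let (stop_tx, stop_rx) = mpsc::channel::<()>();
    let (done_tx, done_rx) = mpsc::channel::<()>();

    let task = thread::spawn(move || {
        // run until the stop sender is dropped
        let _ = stop_rx.recv();
        drop(done_tx); // signals "done" by disconnecting
    });

    drop(stop_tx); // request stop
    assert!(done_rx.recv().is_err()); // Err(RecvError) means the task finished
    task.join().unwrap();
}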
NetStream::Tcp(stream) => match &tls_acceptors[pos] { Some(tls_acceptor) => match tls_acceptor.accept(stream) { Ok(stream) => { debug!("server-worker {}: tls accept", id); Stream::Tls(stream) } Err(e) => { error!("server-worker {}: tls accept: {}", id, e); continue; } }, None => { debug!("server-worker {}: plain accept", id); Stream::Plain(NetStream::Tcp(stream)) } }, NetStream::Unix(stream) => Stream::Plain(NetStream::Unix(stream)), }; let (cstop, r_cstop) = CancellationToken::new(&reactor.local_registration_memory()); let s_cdone = s_cdone .try_clone(&reactor.local_registration_memory()) .unwrap(); let (ckey, conn_id, zreceiver, mode_opts, shared) = match &mode_opts { ConnectionModeOpts::Req(req_opts) => { let zreq_sender = req_opts .sender .try_clone(&reactor.local_registration_memory()) .unwrap(); let (zreq_receiver_sender, zreq_receiver) = zreceiver_pool.take().unwrap(); let (ckey, conn_id) = conns.add(id, cstop, zreq_receiver_sender, None).unwrap(); debug!( "server-worker {}: req conn starting {} {}/{}", id, ckey, conns.count(), conns.max(), ); let mode_opts = ConnectionModeOpts::Req(ConnectionReqOpts { body_buffer_size: req_opts.body_buffer_size, sender: zreq_sender, }); (ckey, conn_id, zreq_receiver, mode_opts, None) } ConnectionModeOpts::Stream(stream_opts) => { let zstream_out_sender = stream_opts .sender .try_clone(&reactor.local_registration_memory()) .unwrap(); let zstream_out_stream_sender = stream_opts .sender_stream .try_clone(&reactor.local_registration_memory()) .unwrap(); let (zstream_receiver_sender, zstream_receiver) = zreceiver_pool.take().unwrap(); let shared = arena::Rc::new(StreamSharedData::new(), &stream_opts.stream_shared_mem) .unwrap(); let (ckey, conn_id) = conns .add( id, cstop, zstream_receiver_sender, Some(arena::Rc::clone(&shared)), ) .unwrap(); debug!( "server-worker {}: stream conn starting {} {}/{}", id, ckey, conns.count(), conns.max(), ); let mode_opts = ConnectionModeOpts::Stream(ConnectionStreamOpts { blocks_max: stream_opts.blocks_max, blocks_avail: Arc::clone(&stream_opts.blocks_avail), messages_max: stream_opts.messages_max, allow_compression: stream_opts.allow_compression, sender: zstream_out_sender, sender_stream: zstream_out_stream_sender, stream_shared_mem: stream_opts.stream_shared_mem.clone(), }); (ckey, conn_id, zstream_receiver, mode_opts, Some(shared)) } }; match mode_opts { ConnectionModeOpts::Req(req_opts) => { if spawner .spawn(Self::req_connection_task( r_cstop, s_cdone, id, ckey, conn_id, stream, peer_addr, zreceiver, conns.clone(), opts.clone(), req_opts, )) .is_err() { // this should never happen. we only accept a connection if // we know we can spawn panic!("failed to spawn req_connection_task"); } } ConnectionModeOpts::Stream(stream_opts) => { if spawner .spawn(Self::stream_connection_task( r_cstop, s_cdone, id, ckey, conn_id, stream, peer_addr, zreceiver, conns.clone(), opts.clone(), stream_opts, shared.unwrap(), )) .is_err() { // this should never happen. 
we only accept a connection if // we know we can spawn panic!("failed to spawn stream_connection_task"); } } } } drop(s_cdone); conns.stop_all(|ckey| debug!("server-worker {}: stopping {}", id, ckey)); while cdone.recv().await.is_ok() {} debug!("server-worker {}: task stopped: {}", id, name); } #[allow(clippy::too_many_arguments)] async fn req_handle_task( id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, zreq_receiver: AsyncLocalReceiver, r_cdone: AsyncLocalReceiver, s_cdone: AsyncLocalSender, req_handle: zhttpsocket::AsyncClientReqHandle, req_maxconn: usize, conns: Rc, ) { let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * req_maxconn); let req_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let req_resp_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let resume_waker = task::create_resume_waker(); debug!("server-worker {}: task started: req_handle", id); let mut handle_send = pin!(None); let mut done_send = None; loop { let receiver_recv = if handle_send.is_none() { Some(zreq_receiver.recv()) } else { None }; let done_recv = if done_send.is_none() { Some(r_cdone.recv()) } else { None }; match select_6( stop.recv(), select_option(receiver_recv), select_option(handle_send.as_mut().as_pin_mut()), select_option(done_recv), select_option(done_send.as_mut()), pin!(req_handle.recv()), ) .await { // stop.recv Select6::R1(_) => break, // receiver_recv Select6::R2(result) => match result { Ok(msg) => handle_send.set(Some(req_handle.send(msg))), Err(mpsc::RecvError) => break, // this can happen if accept+conns end first }, // handle_send Select6::R3(result) => { handle_send.set(None); if let Err(e) = result { error!("req send error: {}", e); } } // done_recv Select6::R4(result) => match result { Ok(msg) => done_send = Some(s_cdone.send(msg)), Err(mpsc::RecvError) => break, // this can happen if accept+conns end first }, // done send Select6::R5(result) => { done_send = None; if let Err(mpsc::SendError(_)) = result { // this can happen if accept ends first break; } } // req_handle.recv Select6::R6(result) => match result { Ok(msg) => { let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &req_scratch_mem, ) .unwrap(); let zresp = match zhttppacket::OwnedResponse::parse(msg, 0, scratch) { Ok(zresp) => zresp, Err(e) => { warn!("server-worker {}: zhttp parse error: {}", id, e); continue; } }; let zresp = arena::Rc::new(zresp, &req_resp_mem).unwrap(); let mut count = 0; for (i, rid) in zresp.get().get().ids.iter().enumerate() { let key = match get_key(rid.id) { Ok(key) => key, Err(_) => continue, }; if !conns.check_id(key, rid.id) { continue; } // this should always succeed, since afterwards we yield // to let the connection receive the message match conns.try_send(key, (arena::Rc::clone(&zresp), i)) { Ok(()) => count += 1, Err(mpsc::TrySendError::Full(_)) => error!( "server-worker {}: connection-{} cannot receive message", id, key ), Err(mpsc::TrySendError::Disconnected(_)) => {} // conn task ended } } debug!( "server-worker {}: queued zmq message for {} conns", id, count ); if count > 0 { yield_to_local_events(&resume_waker).await; } } Err(e) => panic!("server-worker {}: handle read error {}", id, e), }, } } debug!("server-worker {}: task stopped: req_handle", id); } #[allow(clippy::too_many_arguments)] async fn stream_handle_task( id: usize, stop: AsyncLocalReceiver<()>, done: AsyncLocalSender, instance_id: Rc, zstream_out_receiver: AsyncLocalReceiver, zstream_out_stream_receiver: AsyncLocalReceiver<(ArrayVec, 
zmq::Message)>, r_cdone: AsyncLocalReceiver, s_cdone: AsyncLocalSender, stream_handle: zhttpsocket::AsyncClientStreamHandle, stream_maxconn: usize, conns: Rc, ) { let msg_retained_max = 1 + (MSG_RETAINED_PER_CONNECTION_MAX * stream_maxconn); let stream_scratch_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let stream_resp_mem = Rc::new(arena::RcMemory::new(msg_retained_max)); let resume_waker = task::create_resume_waker(); debug!("server-worker {}: task started: stream_handle", id); { let mut handle_send_to_any = pin!(None); let mut handle_send_to_addr = pin!(None); let mut done_send = None; loop { let receiver_recv = if handle_send_to_any.is_none() { Some(zstream_out_receiver.recv()) } else { None }; let stream_receiver_recv = if handle_send_to_addr.is_none() { Some(zstream_out_stream_receiver.recv()) } else { None }; let done_recv = if done_send.is_none() { Some(r_cdone.recv()) } else { None }; match select_8( stop.recv(), select_option(receiver_recv), select_option(handle_send_to_any.as_mut().as_pin_mut()), select_option(stream_receiver_recv), select_option(handle_send_to_addr.as_mut().as_pin_mut()), select_option(done_recv), select_option(done_send.as_mut()), pin!(stream_handle.recv()), ) .await { // stop.recv Select8::R1(_) => break, // receiver_recv Select8::R2(result) => match result { Ok(msg) => handle_send_to_any.set(Some(stream_handle.send_to_any(msg))), Err(mpsc::RecvError) => break, // this can happen if accept+conns end first }, // handle_send_to_any Select8::R3(result) => { handle_send_to_any.set(None); if let Err(e) = result { error!("stream out send error: {}", e); } } // stream_receiver_recv Select8::R4(result) => match result { Ok((addr, msg)) => { handle_send_to_addr.set(Some(stream_handle.send_to_addr(addr, msg))) } Err(mpsc::RecvError) => break, // this can happen if accept+conns end first }, // handle_send_to_addr Select8::R5(result) => { handle_send_to_addr.set(None); if let Err(e) = result { error!("stream out stream send error: {}", e); } } // done_recv Select8::R6(result) => match result { Ok(msg) => done_send = Some(s_cdone.send(msg)), Err(mpsc::RecvError) => break, // this can happen if accept+conns end first }, // done send Select8::R7(result) => { done_send = None; if let Err(mpsc::SendError(_)) = result { // this can happen if accept ends first break; } } // stream_handle.recv Select8::R8(result) => match result { Ok((msg, from_router)) => { let msg_data = &msg.get()[..]; let offset = if from_router { 0 } else { let (addr, offset) = match get_addr_and_offset(msg_data) { Ok(ret) => ret, Err(_) => { warn!("server-worker {}: packet has unexpected format", id); continue; } }; if addr != *instance_id { warn!("server-worker {}: packet not for us", id); continue; } offset }; let scratch = arena::Rc::new( RefCell::new(zhttppacket::ParseScratch::new()), &stream_scratch_mem, ) .unwrap(); let zresp = match zhttppacket::OwnedResponse::parse(msg, offset, scratch) { Ok(zresp) => zresp, Err(e) => { warn!("server-worker {}: zhttp parse error: {}", id, e); continue; } }; let zresp = arena::Rc::new(zresp, &stream_resp_mem).unwrap(); let mut count = 0; for (i, rid) in zresp.get().get().ids.iter().enumerate() { let key = match get_key(rid.id) { Ok(key) => key, Err(_) => continue, }; if !conns.check_id(key, rid.id) { continue; } // this should always succeed, since afterwards we yield // to let the connection receive the message match conns.try_send(key, (arena::Rc::clone(&zresp), i)) { Ok(()) => count += 1, Err(mpsc::TrySendError::Full(_)) => error!( "server-worker {}: 
connection-{} cannot receive message", id, key ), Err(mpsc::TrySendError::Disconnected(_)) => {} // conn task ended } } debug!( "server-worker {}: queued zmq message for {} conns", id, count ); if count > 0 { yield_to_local_events(&resume_waker).await; } } Err(e) => panic!("server-worker {}: handle read error {}", id, e), }, } } } // give the handle back done.send(stream_handle).await.unwrap(); debug!("server-worker {}: task stopped: stream_handle", id); } #[allow(clippy::too_many_arguments)] async fn req_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: ArrayString<32>, stream: Stream, peer_addr: SocketAddr, zreceiver: channel::LocalReceiver<(arena::Rc, usize)>, conns: Rc, opts: ConnectionOpts, req_opts: ConnectionReqOpts, ) { let done = AsyncLocalSender::new(done); let zreceiver = AsyncLocalReceiver::new(zreceiver); let mut cid_provider = ConnectionCid::new(worker_id, ckey, &conns); debug!( "server-worker {}: task started: connection-{}", worker_id, ckey ); match stream { Stream::Plain(stream) => match stream { NetStream::Tcp(stream) => { server_req_connection( token, cid, &mut cid_provider, AsyncTcpStream::new(stream), Some(&peer_addr), false, opts.buffer_size, req_opts.body_buffer_size, &opts.rb_tmp, opts.packet_buf, opts.timeout, AsyncLocalSender::new(req_opts.sender), zreceiver, ) .await } NetStream::Unix(stream) => { server_req_connection( token, cid, &mut cid_provider, AsyncUnixStream::new(stream), Some(&peer_addr), false, opts.buffer_size, req_opts.body_buffer_size, &opts.rb_tmp, opts.packet_buf, opts.timeout, AsyncLocalSender::new(req_opts.sender), zreceiver, ) .await } }, Stream::Tls(stream) => { let tls_waker_data = RefWakerData::new(TlsWaker::new()); server_req_connection( token, cid, &mut cid_provider, AsyncTlsStream::new(stream, &tls_waker_data), Some(&peer_addr), true, opts.buffer_size, req_opts.body_buffer_size, &opts.rb_tmp, opts.packet_buf, opts.timeout, AsyncLocalSender::new(req_opts.sender), zreceiver, ) .await } } done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "server-worker {}: task stopped: connection-{}", worker_id, ckey ); } #[allow(clippy::too_many_arguments)] async fn stream_connection_task( token: CancellationToken, done: channel::LocalSender, worker_id: usize, ckey: usize, cid: ArrayString<32>, stream: Stream, peer_addr: SocketAddr, zreceiver: channel::LocalReceiver<(arena::Rc, usize)>, conns: Rc, opts: ConnectionOpts, stream_opts: ConnectionStreamOpts, shared: arena::Rc, ) { let done = AsyncLocalSender::new(done); let zreceiver = AsyncLocalReceiver::new(zreceiver); let mut cid_provider = ConnectionCid::new(worker_id, ckey, &conns); debug!( "server-worker {}: task started: connection-{}", worker_id, ckey ); match stream { Stream::Plain(stream) => match stream { NetStream::Tcp(stream) => { server_stream_connection( token, cid, &mut cid_provider, AsyncTcpStream::new(stream), Some(&peer_addr), false, opts.buffer_size, stream_opts.blocks_max, &stream_opts.blocks_avail, stream_opts.messages_max, &opts.rb_tmp, opts.packet_buf, opts.tmp_buf, opts.timeout, stream_opts.allow_compression, &opts.instance_id, AsyncLocalSender::new(stream_opts.sender), AsyncLocalSender::new(stream_opts.sender_stream), zreceiver, shared, ) .await } NetStream::Unix(stream) => { server_stream_connection( token, cid, &mut cid_provider, AsyncUnixStream::new(stream), Some(&peer_addr), false, opts.buffer_size, stream_opts.blocks_max, &stream_opts.blocks_avail, stream_opts.messages_max, &opts.rb_tmp, opts.packet_buf, 
opts.tmp_buf, opts.timeout, stream_opts.allow_compression, &opts.instance_id, AsyncLocalSender::new(stream_opts.sender), AsyncLocalSender::new(stream_opts.sender_stream), zreceiver, shared, ) .await } }, Stream::Tls(stream) => { let tls_waker_data = RefWakerData::new(TlsWaker::new()); server_stream_connection( token, cid, &mut cid_provider, AsyncTlsStream::new(stream, &tls_waker_data), Some(&peer_addr), true, opts.buffer_size, stream_opts.blocks_max, &stream_opts.blocks_avail, stream_opts.messages_max, &opts.rb_tmp, opts.packet_buf, opts.tmp_buf, opts.timeout, stream_opts.allow_compression, &opts.instance_id, AsyncLocalSender::new(stream_opts.sender), AsyncLocalSender::new(stream_opts.sender_stream), zreceiver, shared, ) .await } } done.send(ConnectionDone { ckey }).await.unwrap(); debug!( "server-worker {}: task stopped: connection-{}", worker_id, ckey ); } async fn keep_alives_task( id: usize, stop: AsyncLocalReceiver<()>, _done: AsyncLocalSender<()>, instance_id: Rc, sender: channel::LocalSender<(ArrayVec, zmq::Message)>, conns: Rc, ) { debug!("server-worker {}: task started: keep_alives", id); let reactor = Reactor::current().unwrap(); let mut keep_alive_count = 0; let mut next_keep_alive_time = reactor.now() + KEEP_ALIVE_INTERVAL; let next_keep_alive_timeout = Timeout::new(next_keep_alive_time); let mut next_keep_alive_index = 0; let sender = AsyncLocalSender::new(sender); 'main: loop { // wait for next keep alive time match select_2(stop.recv(), next_keep_alive_timeout.elapsed()).await { Select2::R1(_) => break, Select2::R2(_) => {} } for _ in 0..conns.batch_capacity() { if next_keep_alive_index >= conns.items_capacity() { break; } let key = next_keep_alive_index; next_keep_alive_index += 1; if conns.is_item_stream(key) { // ignore errors let _ = conns.batch_add(key); } } keep_alive_count += 1; if keep_alive_count >= KEEP_ALIVE_BATCHES { keep_alive_count = 0; next_keep_alive_index = 0; } // keep steady pace next_keep_alive_time += KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); while !conns.batch_is_empty() { let send = match select_2(stop.recv(), sender.wait_sendable()).await { Select2::R1(_) => break 'main, Select2::R2(send) => send, }; // there could be no message if items removed or message construction failed let (count, addr, msg) = match conns.next_batch_message(&instance_id, BatchType::KeepAlive) { Some(ret) => ret, None => continue, }; debug!( "server-worker {}: sending keep alives for {} sessions", id, count ); if let Err(e) = send.try_send((addr, msg)) { error!("zhttp write error: {}", e); } } let now = reactor.now(); if now >= next_keep_alive_time + KEEP_ALIVE_INTERVAL { // got really behind somehow. 
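// [Illustrative sketch -- added for clarity, not part of the original source.]
// The pacing arithmetic behind keep_alives_task above: the full set of stream
// connections is swept once per KEEP_ALIVE_TIMEOUT_MS, in KEEP_ALIVE_BATCHES
// ticks spaced KEEP_ALIVE_BATCH_MS apart, so every session is touched roughly
// every 45 seconds. The per-worker connection count is a made-up example.
fn keep_alive_pacing_example() {
    assert_eq!(KEEP_ALIVE_BATCHES, 450); // 45_000 ms / 100 ms

    // with 9_000 stream connections per worker, each 100 ms tick batches
    // keep-alives for about 20 of them
    let stream_maxconn: usize = 9_000;
    let per_tick = stream_maxconn.div_ceil(KEEP_ALIVE_BATCHES);
    assert_eq!(per_tick, 20);

    // a full sweep is therefore 450 ticks * 100 ms = 45 s, matching the
    // session keep-alive timeout
}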
just skip ahead next_keep_alive_time = now + KEEP_ALIVE_INTERVAL; next_keep_alive_timeout.set_deadline(next_keep_alive_time); } } debug!("server-worker {}: task stopped: keep_alives", id); } } impl Drop for Worker { fn drop(&mut self) { self.stop(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } pub struct Server { addrs: Vec, workers: Vec, // underscore-prefixed because we never reference after construction _req_listener: Listener, _stream_listener: Listener, } impl Server { #[allow(clippy::too_many_arguments)] pub fn new( instance_id: &str, worker_count: usize, req_maxconn: usize, stream_maxconn: usize, buffer_size: usize, body_buffer_size: usize, blocks_max: usize, connection_blocks_max: usize, messages_max: usize, req_timeout: Duration, stream_timeout: Duration, listen_addrs: &[ListenConfig], certs_dir: &Path, allow_compression: bool, zsockman: zhttpsocket::ClientSocketManager, handle_bound: usize, ) -> Result { assert!(blocks_max >= stream_maxconn * 2); let identities = Arc::new(IdentityCache::new(certs_dir)); let mut req_listeners = Vec::new(); let mut stream_listeners = Vec::new(); let mut req_acceptor_tls = Vec::new(); let mut stream_acceptor_tls = Vec::new(); let zsockman = Arc::new(zsockman); let mut addrs = Vec::new(); for lc in listen_addrs.iter() { match &lc.spec { ListenSpec::Tcp { addr, tls, default_cert, } => { let l = match TcpListener::bind(*addr) { Ok(l) => l, Err(e) => return Err(format!("failed to bind {}: {}", addr, e)), }; let addr = l.local_addr().unwrap(); info!("listening on {}", addr); addrs.push(SocketAddr::Ip(addr)); if lc.stream { stream_listeners.push(NetListener::Tcp(l)); stream_acceptor_tls.push((*tls, default_cert.clone())); } else { req_listeners.push(NetListener::Tcp(l)); req_acceptor_tls.push((*tls, default_cert.clone())); }; } ListenSpec::Local { path, mode, user, group, } => { // ensure pipe file doesn't exist match fs::remove_file(path) { Ok(()) => {} Err(e) if e.kind() == io::ErrorKind::NotFound => {} Err(e) => panic!("{}", e), } let l = match UnixListener::bind(path) { Ok(l) => l, Err(e) => return Err(format!("failed to bind {:?}: {}", path, e)), }; if let Some(mode) = mode { let perms = fs::Permissions::from_mode(*mode); if let Err(e) = fs::set_permissions(path, perms) { return Err(format!("failed to set mode on {:?}: {}", path, e)); } } if let Some(user) = user { if let Err(e) = set_user(path, user) { return Err(format!( "failed to set user {:?} on {:?}: {}", user, path, e )); } } if let Some(group) = group { if let Err(e) = set_group(path, group) { return Err(format!( "failed to set group {:?} on {:?}: {}", group, path, e )); } } let addr = l.local_addr().unwrap(); info!("listening on {:?}", addr); addrs.push(SocketAddr::Unix(addr)); if lc.stream { stream_listeners.push(NetListener::Unix(l)); stream_acceptor_tls.push((false, None)); } else { req_listeners.push(NetListener::Unix(l)); req_acceptor_tls.push((false, None)); }; } } } let blocks_avail = Arc::new(Counter::new(blocks_max - (stream_maxconn * 2))); let mut workers = Vec::new(); let mut req_lsenders = Vec::new(); let mut stream_lsenders = Vec::new(); for i in 0..worker_count { // rendezvous channels let (s, req_r) = channel::channel(0); req_lsenders.push(s); let (s, stream_r) = channel::channel(0); stream_lsenders.push(s); let w = Worker::new( instance_id, i, req_maxconn / worker_count, stream_maxconn / worker_count, buffer_size, body_buffer_size, connection_blocks_max, &blocks_avail, messages_max, req_timeout, stream_timeout, allow_compression, req_r, stream_r, 
&req_acceptor_tls, &stream_acceptor_tls, &identities, &zsockman, handle_bound, ); workers.push(w); } let req_listener = Listener::new("listener-req", req_listeners, req_lsenders); let stream_listener = Listener::new("listener-stream", stream_listeners, stream_lsenders); Ok(Self { addrs, workers, _req_listener: req_listener, _stream_listener: stream_listener, }) } pub fn addrs(&self) -> &[SocketAddr] { &self.addrs } pub fn task_sizes() -> Vec<(String, usize)> { let req_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (_, zreceiver) = local_channel(1, 1); let (sender, _) = local_channel(1, 1); let batch = Batch::new(1); let conn_items = Rc::new(RefCell::new(ConnectionItems::new(1, batch))); let conns = Rc::new(Connections::new(conn_items, 1)); let stream = { let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); let stream = unsafe { TcpStream::from_raw_fd(socket.into_raw_fd()) }; Stream::Plain(NetStream::Tcp(stream)) }; let peer_addr = SocketAddr::Ip(std::net::SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10000, )); let fut = Worker::req_connection_task( stop, done, 0, 0, ArrayString::from("0-0-0").unwrap(), stream, peer_addr, zreceiver, conns, ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionReqOpts { body_buffer_size: 0, sender, }, ); mem::size_of_val(&fut) }; let stream_task_size = { let reactor = Reactor::new(10); let (_, stop) = CancellationToken::new(&reactor.local_registration_memory()); let (done, _) = local_channel(1, 1); let (_, zreceiver) = local_channel(1, 1); let (sender, _) = local_channel(1, 1); let (sender_stream, _) = local_channel(1, 1); let batch = Batch::new(1); let conn_items = Rc::new(RefCell::new(ConnectionItems::new(1, batch))); let conns = Rc::new(Connections::new(conn_items, 1)); let stream = { let socket = Socket::new(Domain::IPV4, Type::STREAM, None).unwrap(); let stream = unsafe { TcpStream::from_raw_fd(socket.into_raw_fd()) }; Stream::Plain(NetStream::Tcp(stream)) }; let peer_addr = SocketAddr::Ip(std::net::SocketAddr::new( IpAddr::V4(Ipv4Addr::new(127, 0, 0, 1)), 10000, )); let stream_shared_mem = Rc::new(arena::RcMemory::new(1)); let shared = arena::Rc::new(StreamSharedData::new(), &stream_shared_mem).unwrap(); let fut = Worker::stream_connection_task( stop, done, 0, 0, ArrayString::from("0-0-0").unwrap(), stream, peer_addr, zreceiver, conns, ConnectionOpts { instance_id: Rc::new("".to_string()), buffer_size: 0, timeout: Duration::from_millis(0), rb_tmp: Rc::new(TmpBuffer::new(1)), packet_buf: Rc::new(RefCell::new(Vec::new())), tmp_buf: Rc::new(RefCell::new(Vec::new())), }, ConnectionStreamOpts { blocks_max: 2, blocks_avail: Arc::new(Counter::new(0)), messages_max: 0, allow_compression: false, sender, sender_stream, stream_shared_mem, }, shared, ); mem::size_of_val(&fut) }; vec![ ("server_req_connection_task".to_string(), req_task_size), ( "server_stream_connection_task".to_string(), stream_task_size, ), ] } } impl Drop for Server { fn drop(&mut self) { for w in self.workers.iter_mut() { w.stop(); } } } pub struct TestServer { server: Server, thread: Option>, stop: channel::Sender<()>, } impl TestServer { pub fn new(workers: usize) -> Self { let zmq_context = Arc::new(zmq::Context::new()); let req_maxconn = 100; let 
stream_maxconn = 100; let maxconn = req_maxconn + stream_maxconn; let mut zsockman = zhttpsocket::ClientSocketManager::new( Arc::clone(&zmq_context), "test", (MSG_RETAINED_PER_CONNECTION_MAX * maxconn) + (MSG_RETAINED_PER_WORKER_MAX * workers), 100, 100, 100, ); zsockman .set_client_req_specs(&[SpecInfo { spec: String::from("inproc://server-test"), bind: true, ipc_file_mode: 0, }]) .unwrap(); zsockman .set_client_stream_specs( &[SpecInfo { spec: String::from("inproc://server-test-out"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://server-test-out-stream"), bind: true, ipc_file_mode: 0, }], &[SpecInfo { spec: String::from("inproc://server-test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let addr1 = "127.0.0.1:0".parse().unwrap(); let addr2 = "127.0.0.1:0".parse().unwrap(); let server = Server::new( "test", workers, req_maxconn, stream_maxconn, 1024, 1024, stream_maxconn * 2, 2, 10, Duration::from_secs(5), Duration::from_secs(5), &[ ListenConfig { spec: ListenSpec::Tcp { addr: addr1, tls: false, default_cert: None, }, stream: false, }, ListenConfig { spec: ListenSpec::Tcp { addr: addr2, tls: false, default_cert: None, }, stream: true, }, ], Path::new("."), false, zsockman, 100, ) .unwrap(); let (started_s, started_r) = channel::channel(1); let (stop_s, stop_r) = channel::channel(1); let thread = thread::Builder::new() .name("test-server".to_string()) .spawn(move || { Self::run(started_s, stop_r, zmq_context); }) .unwrap(); // wait for handler thread to start started_r.recv().unwrap(); Self { server, thread: Some(thread), stop: stop_s, } } pub fn req_addr(&self) -> std::net::SocketAddr { match self.server.addrs()[0] { SocketAddr::Ip(a) => a, _ => unimplemented!("test server doesn't implement unix sockets"), } } pub fn stream_addr(&self) -> std::net::SocketAddr { match self.server.addrs()[1] { SocketAddr::Ip(a) => a, _ => unimplemented!("test server doesn't implement unix sockets"), } } fn respond(id: &[u8]) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"code")?; w.write_int(200)?; w.write_string(b"reason")?; w.write_string(b"OK")?; w.write_string(b"body")?; w.write_string(b"world\n")?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn respond_stream(prefix_addr: bool, id: &[u8]) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); if prefix_addr { cursor.write_all(b"test ")?; } cursor.write_all(b"T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(0)?; w.write_string(b"code")?; w.write_int(200)?; w.write_string(b"reason")?; w.write_string(b"OK")?; w.write_string(b"headers")?; w.start_array()?; if !prefix_addr { w.start_array()?; w.write_string(b"Response-Path")?; w.write_string(b"router")?; w.end_array()?; } w.start_array()?; w.write_string(b"Content-Length")?; w.write_string(b"6")?; w.end_array()?; w.end_array()?; w.write_string(b"body")?; w.write_string(b"world\n")?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn respond_ws(id: &[u8]) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"test T")?; let mut 
w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(0)?; w.write_string(b"code")?; w.write_int(101)?; w.write_string(b"reason")?; w.write_string(b"Switching Protocols")?; w.write_string(b"credits")?; w.write_int(1024)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn respond_msg( id: &[u8], seq: u32, ptype: &str, content_type: &str, body: &[u8], code: Option, ) -> Result { let mut dest = [0; 1024]; let mut cursor = io::Cursor::new(&mut dest[..]); cursor.write_all(b"test T")?; let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; w.write_string(b"from")?; w.write_string(b"handler")?; w.write_string(b"id")?; w.write_string(id)?; w.write_string(b"seq")?; w.write_int(seq as isize)?; if ptype.is_empty() { w.write_string(b"content-type")?; w.write_string(content_type.as_bytes())?; } else { w.write_string(b"type")?; w.write_string(ptype.as_bytes())?; } if let Some(x) = code { w.write_string(b"code")?; w.write_int(x as isize)?; } w.write_string(b"body")?; w.write_string(body)?; w.end_map()?; w.flush()?; let size = cursor.position() as usize; Ok(zmq::Message::from(&dest[..size])) } fn run( started: channel::Sender<()>, stop: channel::Receiver<()>, zmq_context: Arc, ) { let rep_sock = zmq_context.socket(zmq::REP).unwrap(); rep_sock.connect("inproc://server-test").unwrap(); let in_sock = zmq_context.socket(zmq::PULL).unwrap(); in_sock.connect("inproc://server-test-out").unwrap(); let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); in_stream_sock.set_identity(b"handler").unwrap(); in_stream_sock .connect("inproc://server-test-out-stream") .unwrap(); let out_sock = zmq_context.socket(zmq::XPUB).unwrap(); out_sock.connect("inproc://server-test-in").unwrap(); // ensure zsockman is subscribed let msg = out_sock.recv_msg(0).unwrap(); assert_eq!(&msg[..], b"\x01test "); started.send(()).unwrap(); let mut poller = event::Poller::new(1).unwrap(); poller .register_custom( stop.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&rep_sock.get_fd().unwrap()), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_sock.get_fd().unwrap()), mio::Token(3), mio::Interest::READABLE, ) .unwrap(); poller .register( &mut SourceFd(&in_stream_sock.get_fd().unwrap()), mio::Token(4), mio::Interest::READABLE, ) .unwrap(); let mut rep_events = rep_sock.get_events().unwrap(); let mut in_events = in_sock.get_events().unwrap(); let mut in_stream_events = in_stream_sock.get_events().unwrap(); loop { while rep_events.contains(zmq::POLLIN) { let parts = match rep_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; assert_eq!(parts.len(), 1); let msg = &parts[0]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut method = ""; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "method" => { let s = tnetstring::parse_string(f.data).unwrap(); method = str::from_utf8(s).unwrap(); } _ => {} } } assert_eq!(method, "GET"); let msg = Self::respond(id.as_bytes()).unwrap(); rep_sock.send(msg, 0).unwrap(); rep_events = rep_sock.get_events().unwrap(); } while in_events.contains(zmq::POLLIN) { let parts = match 
in_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; in_events = in_sock.get_events().unwrap(); assert_eq!(parts.len(), 1); let msg = &parts[0]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut method = ""; let mut uri = ""; let mut router_resp = false; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "method" => { let s = tnetstring::parse_string(f.data).unwrap(); method = str::from_utf8(s).unwrap(); } "uri" => { let s = tnetstring::parse_string(f.data).unwrap(); uri = str::from_utf8(s).unwrap(); } "router-resp" => { router_resp = tnetstring::parse_bool(f.data).unwrap(); } _ => {} } } if !uri.contains("router-resp") { router_resp = false; } assert_eq!(method, "GET"); if uri.starts_with("ws:") { let msg = Self::respond_ws(id.as_bytes()).unwrap(); out_sock.send(msg, 0).unwrap(); } else { let msg = Self::respond_stream(!router_resp, id.as_bytes()).unwrap(); if router_resp { in_stream_sock .send_multipart([b"test".as_slice(), &[], &msg], 0) .unwrap(); } else { out_sock.send(msg, 0).unwrap(); } } } while in_stream_events.contains(zmq::POLLIN) { let parts = match in_stream_sock.recv_multipart(zmq::DONTWAIT) { Ok(parts) => parts, Err(zmq::Error::EAGAIN) => { break; } Err(e) => panic!("recv error: {:?}", e), }; in_stream_events = in_stream_sock.get_events().unwrap(); assert_eq!(parts.len(), 3); assert_eq!(parts[1].len(), 0); let msg = &parts[2]; assert_eq!(msg[0], b'T'); let mut id = ""; let mut seq = None; let mut ptype = ""; let mut content_type = ""; let mut body = &b""[..]; let mut code = None; for f in tnetstring::parse_map(&msg[1..]).unwrap() { let f = f.unwrap(); match f.key { "id" => { let s = tnetstring::parse_string(f.data).unwrap(); id = str::from_utf8(s).unwrap(); } "seq" => { seq = Some(tnetstring::parse_int(f.data).unwrap() as u32); } "type" => { let s = tnetstring::parse_string(f.data).unwrap(); ptype = str::from_utf8(s).unwrap(); } "content-type" => { let s = tnetstring::parse_string(f.data).unwrap(); content_type = str::from_utf8(s).unwrap(); } "body" => { body = tnetstring::parse_string(f.data).unwrap(); } "code" => { code = Some(tnetstring::parse_int(f.data).unwrap() as u16); } _ => {} } } let seq = seq.unwrap(); // as a hack to make the test server stateless, respond to every message // using the received sequence number. 
for messages we don't care about, // respond with keep-alive in order to keep the sequencing going if ptype.is_empty() || ptype == "ping" || ptype == "pong" || ptype == "close" { if ptype == "ping" { ptype = "pong"; } let msg = Self::respond_msg(id.as_bytes(), seq, ptype, content_type, body, code) .unwrap(); out_sock.send(msg, 0).unwrap(); } else { let msg = Self::respond_msg(id.as_bytes(), seq, "keep-alive", "", &b""[..], None) .unwrap(); out_sock.send(msg, 0).unwrap(); } } poller.poll(None).unwrap(); let mut done = false; for event in poller.iter_events() { match event.token() { mio::Token(1) => { if stop.try_recv().is_ok() { done = true; break; } } mio::Token(2) => { rep_events = rep_sock.get_events().unwrap(); } mio::Token(3) => { in_events = in_sock.get_events().unwrap(); } mio::Token(4) => { in_stream_events = in_stream_sock.get_events().unwrap(); } _ => unreachable!(), } } if done { break; } } } } impl Drop for TestServer { fn drop(&mut self) { self.stop.try_send(()).unwrap(); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } #[cfg(test)] pub mod tests { use super::*; use crate::connmgr::websocket; use std::io::Read; use test_log::test; fn recv_frame( stream: &mut R, buf: &mut Vec, ) -> Result<(bool, u8, Vec), io::Error> { loop { let fi = match websocket::read_header(buf) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk)?; if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } buf.extend_from_slice(&chunk[..size]); continue; } Err(e) => return Err(e), }; while buf.len() < fi.payload_offset + fi.payload_size { let mut chunk = [0; 1024]; let size = stream.read(&mut chunk)?; if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } buf.extend_from_slice(&chunk[..size]); } let content = Vec::from(&buf[fi.payload_offset..(fi.payload_offset + fi.payload_size)]); *buf = buf.split_off(fi.payload_offset + fi.payload_size); return Ok((fi.fin, fi.opcode, content)); } } #[test] fn test_server() { let server = TestServer::new(1); // req let mut client = std::net::TcpStream::connect(&server.req_addr()).unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n" ); // stream (http) let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); client .write(b"GET /hello HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n" ); // stream (http) with responses via router let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); client .write(b"GET /hello?router-resp HTTP/1.0\r\nHost: example.com\r\n\r\n") .unwrap(); let mut buf = Vec::new(); client.read_to_end(&mut buf).unwrap(); assert_eq!( str::from_utf8(&buf).unwrap(), "HTTP/1.0 200 OK\r\nResponse-Path: router\r\nContent-Length: 6\r\n\r\nworld\n" ); // stream (ws) let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); let req = concat!( "GET /hello HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n", ); client.write(req.as_bytes()).unwrap(); let mut buf = Vec::new(); let mut resp_end = 0; loop { let mut chunk = [0; 
1024]; let size = client.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { resp_end = i + 4; break; } } if resp_end > 0 { break; } } let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); assert_eq!(str::from_utf8(&buf[..resp_end]).unwrap(), expected); buf = buf.split_off(resp_end); // send message let mut data = vec![0; 1024]; let body = &b"hello"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_TEXT, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv message let (fin, opcode, content) = recv_frame(&mut client, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_TEXT); assert_eq!(str::from_utf8(&content).unwrap(), "hello"); } #[test] fn test_ws() { let server = TestServer::new(1); let mut client = std::net::TcpStream::connect(&server.stream_addr()).unwrap(); let req = concat!( "GET /hello HTTP/1.1\r\n", "Host: example.com\r\n", "Upgrade: websocket\r\n", "Sec-WebSocket-Version: 13\r\n", "Sec-WebSocket-Key: abcde\r\n", "\r\n", ); client.write(req.as_bytes()).unwrap(); let mut buf = Vec::new(); let mut resp_end = 0; loop { let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); buf.extend_from_slice(&chunk[..size]); for i in 0..(buf.len() - 3) { if &buf[i..(i + 4)] == b"\r\n\r\n" { resp_end = i + 4; break; } } if resp_end > 0 { break; } } let expected = concat!( "HTTP/1.1 101 Switching Protocols\r\n", "Upgrade: websocket\r\n", "Connection: Upgrade\r\n", "Sec-WebSocket-Accept: 8m4i+0BpIKblsbf+VgYANfQKX4w=\r\n", "\r\n", ); assert_eq!(str::from_utf8(&buf[..resp_end]).unwrap(), expected); buf = buf.split_off(resp_end); // send binary let mut data = vec![0; 1024]; let body = &[1, 2, 3][..]; let size = websocket::write_header( true, false, websocket::OPCODE_BINARY, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv binary let (fin, opcode, content) = recv_frame(&mut client, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_BINARY); assert_eq!(content, &[1, 2, 3][..]); buf.clear(); // send ping let mut data = vec![0; 1024]; let body = &b""[..]; let size = websocket::write_header( true, false, websocket::OPCODE_PING, body.len(), None, &mut data, ) .unwrap(); client.write(&data[..size]).unwrap(); // recv pong let (fin, opcode, content) = recv_frame(&mut client, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_PONG); assert_eq!(str::from_utf8(&content).unwrap(), ""); buf.clear(); // send close let mut data = vec![0; 1024]; let body = &b"\x03\xf0gone"[..]; let size = websocket::write_header( true, false, websocket::OPCODE_CLOSE, body.len(), None, &mut data, ) .unwrap(); data[size..(size + body.len())].copy_from_slice(body); client.write(&data[..(size + body.len())]).unwrap(); // recv close let (fin, opcode, content) = recv_frame(&mut client, &mut buf).unwrap(); assert_eq!(fin, true); assert_eq!(opcode, websocket::OPCODE_CLOSE); assert_eq!(&content, &b"\x03\xf0gone"[..]); // expect tcp close let mut chunk = [0; 1024]; let size = client.read(&mut chunk).unwrap(); assert_eq!(size, 0); } #[cfg(target_arch = "x86_64")] #[cfg(debug_assertions)] #[test] fn 
test_task_sizes() { // sizes in debug mode at commit 4c1b0bb177314051405ef5be3cde023e9d1ad635 const REQ_TASK_SIZE_BASE: usize = 5824; const STREAM_TASK_SIZE_BASE: usize = 7760; // cause tests to fail if sizes grow too much const GROWTH_LIMIT: usize = 1000; const REQ_TASK_SIZE_MAX: usize = REQ_TASK_SIZE_BASE + GROWTH_LIMIT; const STREAM_TASK_SIZE_MAX: usize = STREAM_TASK_SIZE_BASE + GROWTH_LIMIT; let sizes = Server::task_sizes(); assert_eq!(sizes[0].0, "server_req_connection_task"); assert!(sizes[0].1 <= REQ_TASK_SIZE_MAX); assert_eq!(sizes[1].0, "server_stream_connection_task"); assert!(sizes[1].1 <= STREAM_TASK_SIZE_MAX); } } pushpin-1.41.0/src/connmgr/tls.rs000066400000000000000000001305731504671364300167070ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::event::{self, ReadinessExt}; use crate::core::io::{AsyncRead, AsyncWrite}; use crate::core::net::AsyncTcpStream; use crate::core::reactor::Registration; use crate::core::task::get_reactor; use crate::core::waker::{RefWake, RefWaker, RefWakerData}; use arrayvec::ArrayString; use log::debug; use mio::net::TcpStream; use openssl::error::ErrorStack; use openssl::pkey::PKey; use openssl::ssl::{ self, HandshakeError, MidHandshakeSslStream, NameType, SniError, SslAcceptor, SslConnector, SslContext, SslContextBuilder, SslFiletype, SslMethod, SslStream, SslVerifyMode, }; use openssl::x509::X509; use std::any::Any; use std::cell::{Ref, RefCell}; use std::cmp; use std::collections::HashMap; use std::fmt; use std::fs; use std::future::Future; use std::io; use std::io::{Read, Write}; use std::marker::PhantomData; use std::mem; use std::os::fd::{FromRawFd, IntoRawFd}; use std::path; use std::path::{Path, PathBuf}; use std::pin::Pin; use std::ptr; use std::str::FromStr; use std::sync::{Arc, Mutex, MutexGuard}; use std::task::{Context, Poll, Waker}; use std::time::{Duration, Instant, SystemTime}; const DOMAIN_LEN_MAX: usize = 253; const CONFIG_CACHE_TTL: Duration = Duration::from_secs(60); enum IdentityError { InvalidName, CertMetadata(PathBuf, io::Error), KeyMetadata(PathBuf, io::Error), SslContext(ErrorStack), CertContent(PathBuf, ErrorStack), KeyContent(PathBuf, ErrorStack), CertCheck(ErrorStack), } impl fmt::Display for IdentityError { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::InvalidName => write!(f, "invalid name"), Self::CertMetadata(fname, e) => { write!(f, "failed to read cert file metadata {:?}: {}", fname, e) } Self::KeyMetadata(fname, e) => { write!(f, "failed to read key file metadata {:?}: {}", fname, e) } Self::SslContext(e) => write!(f, "failed to create SSL context: {}", e), Self::CertContent(fname, e) => { write!(f, "failed to read cert content {:?}: {}", fname, e) } Self::KeyContent(fname, e) => { write!(f, "failed to read key content {:?}: {}", fname, e) } Self::CertCheck(e) => write!(f, "failed to check private key: {}", e), } } } struct Identity { ssl_context: SslContext, cert_fname: PathBuf, key_fname: PathBuf, modified: 
Option, } impl Identity { fn from_name(dir: &Path, name: &str) -> Result { // forbid long names if name.len() > DOMAIN_LEN_MAX { return Err(IdentityError::InvalidName); } // forbid control chars and '/', for safe filesystem usage for c in name.chars() { if (c as u32) < 0x20 || path::is_separator(c) { return Err(IdentityError::InvalidName); } } let cert_fname = dir.join(Path::new(&format!("{}.crt", name))); let cert_metadata = match fs::metadata(&cert_fname) { Ok(md) => md, Err(e) => return Err(IdentityError::CertMetadata(cert_fname, e)), }; let key_fname = dir.join(Path::new(&format!("{}.key", name))); let key_metadata = match fs::metadata(&key_fname) { Ok(md) => md, Err(e) => return Err(IdentityError::KeyMetadata(key_fname, e)), }; let cert_modified = cert_metadata.modified(); let key_modified = key_metadata.modified(); #[allow(clippy::unnecessary_unwrap)] let modified = if cert_modified.is_ok() && key_modified.is_ok() { Some(cmp::max(cert_modified.unwrap(), key_modified.unwrap())) } else { None }; let mut ctx = match SslContextBuilder::new(SslMethod::tls()) { Ok(ctx) => ctx, Err(e) => return Err(IdentityError::SslContext(e)), }; if let Err(e) = ctx.set_certificate_chain_file(&cert_fname) { return Err(IdentityError::CertContent(cert_fname, e)); } if let Err(e) = ctx.set_private_key_file(&key_fname, SslFiletype::PEM) { return Err(IdentityError::KeyContent(key_fname, e)); } if let Err(e) = ctx.check_private_key() { return Err(IdentityError::CertCheck(e)); } Ok(Self { ssl_context: ctx.build(), cert_fname, key_fname, modified, }) } } fn modified_after(fnames: &[&Path], t: SystemTime) -> Result { for fname in fnames { match fs::metadata(fname)?.modified() { Ok(modified) if modified > t => return Ok(true), _ => {} } } Ok(false) } struct IdentityRef<'a> { _data: MutexGuard<'a, HashMap>, name: &'a str, value: &'a Identity, } pub struct IdentityCache { dir: PathBuf, data: Mutex>, } impl IdentityCache { pub fn new(certs_dir: &Path) -> Self { Self { dir: certs_dir.to_path_buf(), data: Mutex::new(HashMap::new()), } } fn get_by_domain<'a>(&'a self, domain: &str) -> Option> { let name = domain.to_lowercase(); // try to find a file named after the exact host, then try with a // wildcard pattern at the same subdomain level. the filename // format uses underscores instead of asterisks. so, a domain of // www.example.com will attempt to be matched against a file named // www.example.com.crt and _.example.com.crt. 
wildcards at other // levels are not supported if let Some(identity) = self.get_by_name(&name) { return Some(identity); } let pos = name.find('.')?; let name = format!("_{}", &name[pos..]); if let Some(identity) = self.get_by_name(&name) { return Some(identity); } None } fn get_by_name<'a>(&'a self, name: &str) -> Option> { self.ensure_updated(name); let data = self.data.lock().unwrap(); if let Some((name, value)) = data.get_key_value(name) { // extending the lifetimes is safe because we keep the owning MutexGuard let name = unsafe { mem::transmute::<&String, &'a String>(name) }; let value = unsafe { mem::transmute::<&Identity, &'a Identity>(value) }; Some(IdentityRef { _data: data, name: name.as_str(), value, }) } else { None } } fn ensure_updated(&self, name: &str) { let mut data = self.data.lock().unwrap(); let mut update = false; if let Some(value) = data.get(name) { if let Some(modified) = value.modified { update = modified_after(&[&value.cert_fname, &value.key_fname], modified) .unwrap_or(true); } } else { update = true; } if update { let identity = match Identity::from_name(&self.dir, name) { Ok(identity) => identity, Err(e) => { debug!("failed to load cert {}: {}", name, e); return; } }; data.insert(String::from(name), identity); debug!("loaded cert: {}", name); } } } trait ReadWrite: Read + Write + Any + Send { fn as_any(&mut self) -> &mut dyn Any; fn into_any(self: Box) -> Box; } impl ReadWrite for T { fn as_any(&mut self) -> &mut dyn Any { self } fn into_any(self: Box) -> Box { self } } enum Stream { Ssl(SslStream), MidHandshakeSsl(MidHandshakeSslStream), NoSsl, } pub struct TlsAcceptor { acceptor: SslAcceptor, } impl TlsAcceptor { pub fn new(cache: &Arc, default_cert: Option<&str>) -> Self { let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let cache = Arc::clone(cache); let default_cert: Option = default_cert.map(|s| s.to_owned()); acceptor.set_servername_callback(move |ssl, _| { let identity = match ssl.servername(NameType::HOST_NAME) { Some(name) => { debug!("tls server name: {}", name); match cache.get_by_domain(name) { Some(ctx) => ctx, None => match &default_cert { Some(default_cert) => match cache.get_by_name(default_cert) { Some(ctx) => ctx, None => return Err(SniError::ALERT_FATAL), }, None => return Err(SniError::ALERT_FATAL), }, } } None => match &default_cert { Some(default_cert) => match cache.get_by_name(default_cert) { Some(ctx) => ctx, None => return Err(SniError::ALERT_FATAL), }, None => return Err(SniError::ALERT_FATAL), }, }; debug!("using cert: {}", identity.name); if ssl.set_ssl_context(&identity.value.ssl_context).is_err() { return Err(SniError::ALERT_FATAL); } Ok(()) }); Self { acceptor: acceptor.build(), } } pub fn new_self_signed() -> Self { let mut acceptor = SslAcceptor::mozilla_intermediate(SslMethod::tls()).unwrap(); let cert_pem = concat!( "-----BEGIN CERTIFICATE-----\n", "MIICpDCCAYwCCQDkzIPOmEje1DANBgkqhkiG9w0BAQsFADAUMRIwEAYDVQQDDAls\n", "b2NhbGhvc3QwHhcNMjMwNjA4MjIxMjE3WhcNMjMwNjA5MjIxMjE3WjAUMRIwEAYD\n", "VQQDDAlsb2NhbGhvc3QwggEiMA0GCSqGSIb3DQEBAQUAA4IBDwAwggEKAoIBAQC7\n", "Lj9eFGJ0hsbtn1ebNaakK/f3tktLbYhT7eZ547T1OYfPs9stk7ZMaNPXv/CPbz4x\n", "5NZC89rghUScZYFGAfQE5Rxrso8vUzUSAzRebSm5LG3BYsHyKf7lZkD3cK1kBPtl\n", "lRMQ0/Jg6WkUglYWV8/2Cm8SoJpdllBbbl0bOu1S8QMswb4IrZ1UE130tbP5SnSb\n", "bke2ahVrnJ2lC63sD64rBedYWm5FSHlJ2ciRPe1tr+owqSVrHrjZjrTHovyMVsff\n", "BFJ1iVfnzkxR/tyGFlHHngkRdwtO81Orc9yAIe8v1U3y6F+Tk2LIwW4PYh/xqj4W\n", "ijPttBqrybO5T+jDV/PNAgMBAAEwDQYJKoZIhvcNAQELBQADggEBADQmWrdkwdtR\n", 
"Fu+9GBjXsmjPNvN72Da4UtLf8Y+LgA/XYKGCFaGxpFm+61DOpbjpUR3B8MRQzn45\n", "x4/ZcNmRrYj7yiBlj/Y/bQKfBLaTG2JCJ2ffdBgZMPG3U9wLQKsUbOsdznkSYG18\n", "CGTM3btznIlW7pkDsw3CRkKoYWNRd0STzifa2ASCEgRAFemYIj/YysVw6nWTtIHY\n", "5Ez+TDwOpUkuk2haE6UvaxR0+q3r+10907HqZejyLmSY+FQk1ylAfJtJcJvpbrB+\n", "kQa8kPmOm+hnLGDXFI0qfBHfuiKDX7yi39aFgWI/Mbz5wKHr0IIoJmncayYacnGX\n", "coUhiF2hpf0=\n", "-----END CERTIFICATE-----\n" ); let key_pem = concat!( "-----BEGIN PRIVATE KEY-----\n", "MIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC7Lj9eFGJ0hsbt\n", "n1ebNaakK/f3tktLbYhT7eZ547T1OYfPs9stk7ZMaNPXv/CPbz4x5NZC89rghUSc\n", "ZYFGAfQE5Rxrso8vUzUSAzRebSm5LG3BYsHyKf7lZkD3cK1kBPtllRMQ0/Jg6WkU\n", "glYWV8/2Cm8SoJpdllBbbl0bOu1S8QMswb4IrZ1UE130tbP5SnSbbke2ahVrnJ2l\n", "C63sD64rBedYWm5FSHlJ2ciRPe1tr+owqSVrHrjZjrTHovyMVsffBFJ1iVfnzkxR\n", "/tyGFlHHngkRdwtO81Orc9yAIe8v1U3y6F+Tk2LIwW4PYh/xqj4WijPttBqrybO5\n", "T+jDV/PNAgMBAAECggEAB1lIeZwZRXPpKXkhCmHv2fAz+xC4Igz51jm327854orQ\n", "rzHjgAWVmahf8M+DVU5Lxc+zLcu/IyN4Tx+ZFLOM7ghEtmG7R2Nf6QYhLzff9Hov\n", "EPGcpbJKZJ1AHbbZx9x+Nj3FEtsPYAip7Hk1ggkOjB1awQN3LAdzvjM2CpSkrqXg\n", "c4GQ4hK3tkyIZxPiC6pr6246+UjakzFGXT5zzQajbkFHrM8s4Wn42tbdd6N14jgv\n", "5mdR6bAzusG8P3IRlO4zQ/NQTCXI6kz4SdTlZERaxt35pThXRkcifMPcGRTageax\n", "l1ZxBIRjTSp60tPR6fcH8std8hEcRExcOeCmOld4gQKBgQDwWz5vQCUyvza6l/O3\n", "G6huXmQcpFea5PpWtII55bp3DTen6SrB3cGGtKZZqfN7IXFODUIUIvQEf4bI8r0y\n", "Vu6Sypnq+CIbRN5aul7X+do5gEpFEZW+BdbBN+mCBaf16xaxS9GWZj1wCWSjyE4s\n", "PE7jEbLgVPwd+8FmK3XemaF7bQKBgQDHXQC7XjZ0OxfeAOVLz1vShBBlbDtJEonY\n", "cuSveZqEiLEaUFuU3XFuExbyfCRjNNsz6JROXvCO2KQ6HbI/tkZCmJYoQ8mhhAF+\n", "5QN9hGZgMPcvPEZW4AEih5qVrwO3IQGF3YJnYLvyyroEjQ7nSwCf/HPCF5Gl/K41\n", "QPRlM5e94QKBgFyhPYGQfgV9rbDhqLpTvWizle934o8+WcAalumLQH5rKJzcfm7y\n", "cIfijQ2XMs+sRsdm0qWCBvrIzwAYlJOW7yDBVeo5MKPDudHLa4verZxldboCmev+\n", "whH641IJrf5XWIqBhsdopZrM8+0u3/mqUFiwVHiiJ/vCL3mZnDZqjNJNAoGAFge2\n", "7v2IMuvcxVGABRKS6P5i+XIuUvLTfLGlh6Z+ZqrcNzYuCJM315wQaxdAxh2vI1tO\n", "GCLxnjdeXnWtntC7jtxhq21iOJDnwWf5LMOWtIZ0qimU9ECon3IwqN3AIVpqWqqR\n", "oG7WFgxE5f/YZ8Kn/QXenNIR7C+x6HyXBR/gYsECgYEAg6PSkpYdOxaTZzaxIxS3\n", "HUUy7H1+wzV/ZCKIMZEfH23kUiHMZXjp3xI1FTlGcbMFpOkmjwi+MFHEMcvmwzmc\n", "owdohdh7ngo60nkgMwz5TyWBWDdT+Otogi7F37qAt/fjd4xmNjsyTY4b2OwuP1/S\n", "X7Rmwy1AQ2WKrwOSy4d3xDs=\n", "-----END PRIVATE KEY-----" ); let cert = X509::from_pem(cert_pem.as_bytes()).unwrap(); let key = PKey::private_key_from_pem(key_pem.as_bytes()).unwrap(); acceptor.set_certificate(&cert).unwrap(); acceptor.set_private_key(&key).unwrap(); Self { acceptor: acceptor.build(), } } pub fn accept( &self, stream: mio::net::TcpStream, ) -> Result, ssl::Error> { let result = TlsStream::new(false, stream, |stream| { let stream = match self.acceptor.accept(stream) { Ok(stream) => Stream::Ssl(stream), Err(HandshakeError::SetupFailure(e)) => return Err(e.into()), Err(HandshakeError::Failure(stream)) => return Err(stream.into_error()), Err(HandshakeError::WouldBlock(stream)) => Stream::MidHandshakeSsl(stream), }; Ok(stream) }); match result { Ok(stream) => Ok(stream), Err((_, e)) => Err(e), } } } pub enum VerifyMode { Full, None, } #[derive(Debug)] pub enum TlsStreamError { Io(io::Error), Ssl(ErrorStack), Unusable, } impl TlsStreamError { fn into_io_error(self) -> io::Error { match self { TlsStreamError::Io(e) => e, _ => io::Error::from(io::ErrorKind::Other), } } } impl From for TlsStreamError { fn from(e: ssl::Error) -> Self { match e.into_io_error() { Ok(e) => Self::Io(e), Err(e) => match e.ssl_error() { Some(e) => Self::Ssl(e.clone()), None => Self::Io(io::Error::from(io::ErrorKind::Other)), }, } } } fn 
replace_at(value_at: &mut T, replace_fn: F) where F: FnOnce(T) -> T, { // SAFETY: we use ptr::read to get the current value and then put a new // value in its place with ptr::write before returning unsafe { let p = value_at as *mut T; ptr::write(p, replace_fn(ptr::read(p))); } } fn apply_wants(e: &ssl::Error, interests: &mut Option) { match e.code() { ssl::ErrorCode::WANT_READ => *interests = Some(mio::Interest::READABLE), ssl::ErrorCode::WANT_WRITE => *interests = Some(mio::Interest::WRITABLE), _ => {} } } struct Connector { inner: Arc, created: Instant, } struct Connectors { verify_full: Option, verify_none: Option, } // represents a cache of reusable data among sessions. internally, this data // consists of SslConnectors for the purpose of caching root certs read from // disk. the type is given a vague name in order to avoid committing to what // exactly is cached. pub struct TlsConfigCache { connectors: Mutex, } impl Default for TlsConfigCache { fn default() -> Self { Self::new() } } impl TlsConfigCache { pub fn new() -> Self { Self { connectors: Mutex::new(Connectors { verify_full: None, verify_none: None, }), } } fn get_connector(&self, verify_mode: VerifyMode) -> Result, ErrorStack> { let mut connectors = self .connectors .lock() .expect("failed to obtain lock on tls config cache"); let slot = match verify_mode { VerifyMode::Full => &mut connectors.verify_full, VerifyMode::None => &mut connectors.verify_none, }; let connector = match slot { Some(c) if c.created.elapsed() < CONFIG_CACHE_TTL => &c.inner, _ => { let mut builder = SslConnector::builder(SslMethod::tls())?; match verify_mode { VerifyMode::Full => builder.set_verify(SslVerifyMode::PEER), VerifyMode::None => builder.set_verify(SslVerifyMode::NONE), } let c = slot.insert(Connector { inner: Arc::new(builder.build()), created: Instant::now(), }); &c.inner } }; Ok(Arc::clone(connector)) } } pub struct TlsStream { stream: Stream<&'static mut Box>, plain_stream: Box>, id: ArrayString<64>, client: bool, interests_for_handshake: Option, interests_for_shutdown: Option, interests_for_read: Option, interests_for_write: Option, _marker: PhantomData, } impl TlsStream where T: Read + Write + Any + Send, { pub fn connect( domain: &str, stream: T, verify_mode: VerifyMode, config_cache: &TlsConfigCache, ) -> Result { Self::new(true, stream, |stream| { let connector = config_cache.get_connector(verify_mode)?; let stream = match connector.connect(domain, stream) { Ok(stream) => Stream::Ssl(stream), Err(HandshakeError::SetupFailure(e)) => return Err(e.into()), Err(HandshakeError::Failure(stream)) => return Err(stream.into_error()), Err(HandshakeError::WouldBlock(stream)) => Stream::MidHandshakeSsl(stream), }; Ok(stream) }) } pub fn get_inner<'a>(&'a mut self) -> &'a mut T { let plain_stream: &'a mut Box = match &mut self.stream { Stream::Ssl(stream) => stream.get_mut(), Stream::MidHandshakeSsl(stream) => stream.get_mut(), Stream::NoSsl => Box::as_mut(&mut self.plain_stream), }; let plain_stream: &mut dyn ReadWrite = Box::as_mut(plain_stream); plain_stream.as_any().downcast_mut().unwrap() } #[allow(clippy::result_unit_err)] pub fn set_id(&mut self, id: &str) -> Result<(), ()> { self.id = match ArrayString::from_str(id) { Ok(s) => s, Err(_) => return Err(()), }; Ok(()) } pub fn interests_for_handshake(&self) -> Option { self.interests_for_handshake } pub fn interests_for_shutdown(&self) -> Option { self.interests_for_shutdown } pub fn interests_for_read(&self) -> Option { self.interests_for_read } pub fn interests_for_write(&self) -> Option { 
self.interests_for_write } pub fn ensure_handshake(&mut self) -> Result<(), TlsStreamError> { self.interests_for_handshake = None; match &self.stream { Stream::Ssl(_) => Ok(()), Stream::MidHandshakeSsl(_) => match mem::replace(&mut self.stream, Stream::NoSsl) { Stream::MidHandshakeSsl(stream) => match stream.handshake() { Ok(stream) => { debug!("{} {}: tls handshake success", self.log_prefix(), self.id); self.stream = Stream::Ssl(stream); Ok(()) } Err(HandshakeError::SetupFailure(e)) => Err(TlsStreamError::Ssl(e)), Err(HandshakeError::Failure(stream)) => Err(stream.into_error().into()), Err(HandshakeError::WouldBlock(stream)) => { apply_wants(stream.error(), &mut self.interests_for_handshake); self.stream = Stream::MidHandshakeSsl(stream); Err(TlsStreamError::Io(io::Error::from( io::ErrorKind::WouldBlock, ))) } }, _ => unreachable!(), }, Stream::NoSsl => Err(TlsStreamError::Unusable), } } pub fn shutdown(&mut self) -> Result<(), io::Error> { self.interests_for_shutdown = None; let stream = match &mut self.stream { Stream::Ssl(stream) => stream, _ => return Err(io::Error::from(io::ErrorKind::Other)), }; if let Err(e) = stream.shutdown() { apply_wants(&e, &mut self.interests_for_shutdown); match e.into_io_error() { Ok(e) => return Err(e), Err(_) => return Err(io::Error::from(io::ErrorKind::Other)), } } debug!("{} {}: tls shutdown sent", self.log_prefix(), self.id); Ok(()) } pub fn change_inner(mut self, change_fn: F) -> TlsStream where F: FnOnce(T) -> U, U: Read + Write + Any + Send, { let plain_stream: &mut Box = Box::as_mut(&mut self.plain_stream); replace_at(plain_stream, |plain_stream: Box| { let plain_stream: Box = plain_stream.into_any().downcast().unwrap(); let plain_stream: U = change_fn(*plain_stream); Box::new(plain_stream) }); // SAFETY: nothing is changing except the phantom data type unsafe { mem::transmute(self) } } fn new(client: bool, stream: T, init_fn: F) -> Result where F: FnOnce( &'static mut Box, ) -> Result>, ssl::Error>, { // box the stream, casting to ReadWrite let inner_box: Box = Box::new(stream); // box it again. this way we have a pointer-to-a-pointer on the heap, // allowing us to change where the outer pointer points to later on // without changing its location let mut outer_box: Box> = Box::new(inner_box); // get the outer pointer let outer: &mut Box = Box::as_mut(&mut outer_box); // safety: TlsStream will take ownership of outer_box, and the value // referred to by outer_box is on the heap, and outer_box will not // be dropped until TlsStream is dropped, so the value referred to // by outer_box will remain valid for the lifetime of TlsStream. 
// further, outer is a mutable reference, and will only ever be // exclusively mutably accessed, either when wrapped by SslStream // or MidHandshakeSslStream, or when known to be not wrapped let outer: &'static mut Box = unsafe { mem::transmute(outer) }; let stream = match init_fn(outer) { Ok(stream) => stream, Err(e) => { let inner_box: Box = *outer_box; let stream: T = *inner_box.into_any().downcast().unwrap(); return Err((stream, e)); } }; Ok(Self { stream, plain_stream: outer_box, id: ArrayString::from("").unwrap(), client, interests_for_handshake: None, interests_for_shutdown: None, interests_for_read: None, interests_for_write: None, _marker: PhantomData, }) } fn log_prefix(&self) -> &'static str { if self.client { "client-conn" } else { "conn" } } fn ssl_read(&mut self, buf: &mut [u8]) -> Result { self.interests_for_read = None; if let Err(e) = self.ensure_handshake() { match &e { TlsStreamError::Io(e) if e.kind() == io::ErrorKind::WouldBlock => { self.interests_for_read = self.interests_for_handshake; } _ => {} } return Err(e); } let stream = match &mut self.stream { Stream::Ssl(stream) => stream, _ => unreachable!(), }; match stream.ssl_read(buf) { Ok(size) => Ok(size), Err(e) if e.code() == ssl::ErrorCode::ZERO_RETURN => Ok(0), Err(e) => { apply_wants(&e, &mut self.interests_for_read); Err(e.into()) } } } fn ssl_write(&mut self, buf: &[u8]) -> Result { self.interests_for_write = None; if let Err(e) = self.ensure_handshake() { match &e { TlsStreamError::Io(e) if e.kind() == io::ErrorKind::WouldBlock => { self.interests_for_write = self.interests_for_handshake; } _ => {} } return Err(e); } let stream = match &mut self.stream { Stream::Ssl(stream) => stream, _ => unreachable!(), }; match stream.ssl_write(buf) { Ok(size) => Ok(size), Err(e) => { apply_wants(&e, &mut self.interests_for_write); Err(e.into()) } } } } impl Read for TlsStream where T: Read + Write + Any + Send, { fn read(&mut self, buf: &mut [u8]) -> Result { match self.ssl_read(buf) { Ok(size) => Ok(size), Err(e) => Err(e.into_io_error()), } } } impl Write for TlsStream where T: Read + Write + Any + Send, { fn write(&mut self, buf: &[u8]) -> Result { match self.ssl_write(buf) { Ok(size) => Ok(size), Err(e) => Err(e.into_io_error()), } } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } struct TlsOpInner { readiness: event::Readiness, waker: Option<(Waker, mio::Interest)>, } struct TlsOp { inner: RefCell, } impl TlsOp { fn new() -> Self { Self { inner: RefCell::new(TlsOpInner { readiness: None, waker: None, }), } } fn readiness(&self) -> event::Readiness { self.inner.borrow().readiness } pub fn set_readiness(&self, readiness: event::Readiness) { self.inner.borrow_mut().readiness = readiness; } pub fn clear_readiness(&self, readiness: mio::Interest) { let inner = &mut *self.inner.borrow_mut(); if let Some(cur) = inner.readiness.take() { inner.readiness = cur.remove(readiness); } } fn set_waker(&self, waker: &Waker, interest: mio::Interest) { let inner = &mut *self.inner.borrow_mut(); let waker = if let Some((current_waker, _)) = inner.waker.take() { if current_waker.will_wake(waker) { // keep the current waker current_waker } else { // switch to the new waker waker.clone() } } else { // we didn't have a waker yet, so we'll use this one waker.clone() }; inner.waker = Some((waker, interest)); } fn clear_waker(&self) { let inner = &mut *self.inner.borrow_mut(); inner.waker = None; } fn apply_readiness(&self, readiness: mio::Interest) { let inner = &mut *self.inner.borrow_mut(); let (became_readable, became_writable) = 
{ let prev_readiness = inner.readiness; inner.readiness.merge(readiness); ( !prev_readiness.contains_any(mio::Interest::READABLE) && inner.readiness.contains_any(mio::Interest::READABLE), !prev_readiness.contains_any(mio::Interest::WRITABLE) && inner.readiness.contains_any(mio::Interest::WRITABLE), ) }; if became_readable || became_writable { if let Some((_, interest)) = &inner.waker { if (became_readable && interest.is_readable()) || (became_writable && interest.is_writable()) { let (waker, _) = inner.waker.take().unwrap(); waker.wake(); } } } } } pub struct TlsWaker { registration: RefCell>, handshake: TlsOp, shutdown: TlsOp, read: TlsOp, write: TlsOp, } #[allow(clippy::new_without_default)] impl TlsWaker { pub fn new() -> Self { Self { registration: RefCell::new(None), handshake: TlsOp::new(), shutdown: TlsOp::new(), read: TlsOp::new(), write: TlsOp::new(), } } fn registration(&self) -> Ref<'_, Registration> { Ref::map(self.registration.borrow(), |b| b.as_ref().unwrap()) } fn set_registration(&self, registration: Registration) { let readiness = registration.readiness(); registration.clear_readiness(mio::Interest::READABLE | mio::Interest::WRITABLE); for op in [&self.handshake, &self.shutdown, &self.read, &self.write] { op.set_readiness(readiness); } *self.registration.borrow_mut() = Some(registration); } fn take_registration(&self) -> Registration { self.registration.borrow_mut().take().unwrap() } } impl RefWake for TlsWaker { fn wake(&self) { if let Some(readiness) = self.registration().readiness() { self.registration() .clear_readiness(mio::Interest::READABLE | mio::Interest::WRITABLE); for op in [&self.handshake, &self.shutdown, &self.read, &self.write] { op.apply_readiness(readiness); } } } } pub struct AsyncTlsStream<'a> { waker: RefWaker<'a, TlsWaker>, stream: Option>, } impl<'a: 'b, 'b> AsyncTlsStream<'a> { pub fn new(mut s: TlsStream, waker_data: &'a RefWakerData) -> Self { let registration = get_reactor() .register_io( s.get_inner(), mio::Interest::READABLE | mio::Interest::WRITABLE, ) .unwrap(); // assume I/O operations are ready to be attempted registration.set_readiness(Some(mio::Interest::READABLE | mio::Interest::WRITABLE)); Self::new_with_registration(s, waker_data, registration) } pub fn connect( domain: &str, stream: AsyncTcpStream, verify_mode: VerifyMode, waker_data: &'a RefWakerData, config_cache: &TlsConfigCache, ) -> Result { let (registration, stream) = stream.into_evented().into_parts(); let stream = match TlsStream::connect(domain, stream, verify_mode, config_cache) { Ok(stream) => stream, Err((mut stream, e)) => { registration.deregister_io(&mut stream).unwrap(); return Err(e); } }; Ok(Self::new_with_registration( stream, waker_data, registration, )) } pub fn ensure_handshake(&'b mut self) -> EnsureHandshakeFuture<'a, 'b> { EnsureHandshakeFuture { s: self } } pub fn inner(&mut self) -> &mut TlsStream { self.stream.as_mut().unwrap() } pub fn into_inner(mut self) -> TlsStream { let mut stream = self.stream.take().unwrap(); self.waker .registration() .deregister_io(stream.get_inner()) .unwrap(); stream } pub fn into_std(mut self) -> TlsStream { let mut stream = self.stream.take().unwrap(); self.waker .registration() .deregister_io(stream.get_inner()) .unwrap(); stream.change_inner(|stream| unsafe { std::net::TcpStream::from_raw_fd(stream.into_raw_fd()) }) } // assumes stream is in non-blocking mode pub fn from_std( stream: TlsStream, waker_data: &'a RefWakerData, ) -> Self { let stream = stream.change_inner(TcpStream::from_std); Self::new(stream, waker_data) } fn 
new_with_registration( s: TlsStream, waker_data: &'a RefWakerData, registration: Registration, ) -> Self { let waker = RefWaker::new(waker_data); waker.set_registration(registration); waker.registration().set_waker_persistent(true); waker.registration().set_waker( waker.as_std(&mut mem::MaybeUninit::uninit()), mio::Interest::READABLE | mio::Interest::WRITABLE, ); Self { waker, stream: Some(s), } } } impl Drop for AsyncTlsStream<'_> { fn drop(&mut self) { let registration = self.waker.take_registration(); if let Some(stream) = &mut self.stream { registration.deregister_io(stream.get_inner()).unwrap(); } } } impl AsyncRead for AsyncTlsStream<'_> { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8], ) -> Poll> { let f = &mut *self; let registration = f.waker.registration(); let op = &f.waker.read; let stream = f.stream.as_mut().unwrap(); let interests = stream.interests_for_read(); if let Some(interests) = interests { if !op.readiness().contains_any(interests) { op.set_waker(cx.waker(), interests); return Poll::Pending; } } if !registration.pull_from_budget_with_waker(cx.waker()) { return Poll::Pending; } match stream.read(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { let interests = stream.interests_for_read().unwrap(); op.clear_readiness(interests); op.set_waker(cx.waker(), interests); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn cancel(&mut self) { let op = &self.waker.read; op.clear_waker(); } } impl AsyncWrite for AsyncTlsStream<'_> { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let f = &mut *self; let registration = f.waker.registration(); let op = &f.waker.write; let stream = f.stream.as_mut().unwrap(); let interests = stream.interests_for_write(); if let Some(interests) = interests { if !op.readiness().contains_any(interests) { op.set_waker(cx.waker(), interests); return Poll::Pending; } } if !registration.pull_from_budget_with_waker(cx.waker()) { return Poll::Pending; } match stream.write(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { let interests = stream.interests_for_write().unwrap(); op.clear_readiness(interests); op.set_waker(cx.waker(), interests); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll> { let f = &mut *self; let registration = f.waker.registration(); let op = &f.waker.shutdown; let stream = f.stream.as_mut().unwrap(); let interests = stream.interests_for_shutdown(); if let Some(interests) = interests { if !op.readiness().contains_any(interests) { op.set_waker(cx.waker(), interests); return Poll::Pending; } } if !registration.pull_from_budget_with_waker(cx.waker()) { return Poll::Pending; } match stream.shutdown() { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { let interests = stream.interests_for_shutdown().unwrap(); op.clear_readiness(interests); op.set_waker(cx.waker(), interests); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn is_writable(&self) -> bool { let op = &self.waker.write; let stream = self.stream.as_ref().unwrap(); if let Some(interests) = stream.interests_for_write() { op.readiness().contains_any(interests) } else { true } } fn cancel(&mut self) { let write_op = &self.waker.write; let shutdown_op = &self.waker.shutdown; write_op.clear_waker(); shutdown_op.clear_waker(); } } pub struct EnsureHandshakeFuture<'a, 'b> { s: &'b mut AsyncTlsStream<'a>, } impl Future 
for EnsureHandshakeFuture<'_, '_> { type Output = Result<(), TlsStreamError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let registration = f.s.waker.registration(); let op = &f.s.waker.handshake; let stream = f.s.stream.as_mut().unwrap(); let interests = stream.interests_for_handshake(); if let Some(interests) = interests { if !op.readiness().contains_any(interests) { op.set_waker(cx.waker(), interests); return Poll::Pending; } } if !registration.pull_from_budget_with_waker(cx.waker()) { return Poll::Pending; } match stream.ensure_handshake() { Ok(()) => Poll::Ready(Ok(())), Err(TlsStreamError::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => { let interests = stream.interests_for_handshake().unwrap(); op.clear_readiness(interests); op.set_waker(cx.waker(), interests); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } } impl Drop for EnsureHandshakeFuture<'_, '_> { fn drop(&mut self) { let op = &self.s.waker.handshake; op.clear_waker(); } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use crate::core::io::{AsyncReadExt, AsyncWriteExt}; use crate::core::net::AsyncTcpListener; use crate::core::reactor::Reactor; use std::str; #[derive(Debug)] struct ReadWriteA { a: i32, } impl Read for ReadWriteA { fn read(&mut self, _buf: &mut [u8]) -> Result { Err(io::Error::from(io::ErrorKind::WouldBlock)) } } impl Write for ReadWriteA { fn write(&mut self, _buf: &[u8]) -> Result { Err(io::Error::from(io::ErrorKind::WouldBlock)) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } #[derive(Debug)] struct ReadWriteB { b: i32, } impl Read for ReadWriteB { fn read(&mut self, _buf: &mut [u8]) -> Result { Err(io::Error::from(io::ErrorKind::WouldBlock)) } } impl Write for ReadWriteB { fn write(&mut self, _buf: &[u8]) -> Result { Err(io::Error::from(io::ErrorKind::WouldBlock)) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } #[derive(Debug)] struct ReadWriteC { c: i32, } impl Read for ReadWriteC { fn read(&mut self, _buf: &mut [u8]) -> Result { Err(io::Error::from(io::ErrorKind::Other)) } } impl Write for ReadWriteC { fn write(&mut self, _buf: &[u8]) -> Result { Err(io::Error::from(io::ErrorKind::Other)) } fn flush(&mut self) -> Result<(), io::Error> { Err(io::Error::from(io::ErrorKind::Other)) } } #[test] fn test_get_change_inner() { let a = ReadWriteA { a: 1 }; let mut stream = TlsStream::connect("localhost", a, VerifyMode::Full, &TlsConfigCache::new()).unwrap(); assert_eq!(stream.get_inner().a, 1); let mut stream = stream.change_inner(|_| ReadWriteB { b: 2 }); assert_eq!(stream.get_inner().b, 2); } #[test] fn test_connect_error() { let c = ReadWriteC { c: 1 }; let (stream, e) = match TlsStream::connect("localhost", c, VerifyMode::Full, &TlsConfigCache::new()) { Ok(_) => panic!("unexpected success"), Err(ret) => ret, }; assert_eq!(stream.c, 1); assert_eq!(e.into_io_error().unwrap().kind(), io::ErrorKind::Other); } #[test] fn test_async_tlsstream() { let reactor = Reactor::new(3); // 3 registrations let executor = Executor::new(2); // 2 tasks let spawner = executor.spawner(); executor .spawn(async move { let addr = "127.0.0.1:0".parse().unwrap(); let listener = AsyncTcpListener::bind(addr).expect("failed to bind"); let acceptor = TlsAcceptor::new_self_signed(); let addr = listener.local_addr().unwrap(); spawner .spawn(async move { let stream = AsyncTcpStream::connect(&[addr]).await.unwrap(); let tls_waker_data = RefWakerData::new(TlsWaker::new()); let mut stream = AsyncTlsStream::connect( "localhost", stream, 
VerifyMode::None, &tls_waker_data, &TlsConfigCache::new(), ) .unwrap(); stream.ensure_handshake().await.unwrap(); let size = stream.write("hello".as_bytes()).await.unwrap(); assert_eq!(size, 5); stream.close().await.unwrap(); }) .unwrap(); let (stream, _) = listener.accept().await.unwrap(); let stream = acceptor.accept(stream).unwrap(); let tls_waker_data = RefWakerData::new(TlsWaker::new()); let mut stream = AsyncTlsStream::new(stream, &tls_waker_data); let mut resp = [0u8; 1024]; let mut resp = io::Cursor::new(&mut resp[..]); loop { let mut buf = [0; 1024]; let size = stream.read(&mut buf).await.unwrap(); if size == 0 { break; } resp.write(&buf[..size]).unwrap(); } let size = resp.position() as usize; let resp = str::from_utf8(&resp.get_ref()[..size]).unwrap(); assert_eq!(resp, "hello"); stream.close().await.unwrap(); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } } pushpin-1.41.0/src/connmgr/track.rs000066400000000000000000000156021504671364300172040ustar00rootroot00000000000000/* * Copyright (C) 2023 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::channel::AsyncLocalReceiver; use std::cell::Cell; use std::future::Future; use std::ops::Deref; use std::pin::Pin; use std::sync::mpsc; use std::task::{Context, Poll}; #[derive(Default)] pub struct TrackFlag(Cell); impl TrackFlag { pub fn get(&self) -> bool { self.0.get() } pub fn set(&self, v: bool) { self.0.set(v); } } struct TrackInner<'a, T> { value: T, active: &'a TrackFlag, } // wrap a value and a shared flag representing the value's liveness. on init, // the flag is set to true. on drop, the flag is set to false pub struct Track<'a, T> { inner: Option>, } impl<'a, T> Track<'a, T> { pub fn new(value: T, active: &'a TrackFlag) -> Self { active.set(true); Self { inner: Some(TrackInner { value, active }), } } } impl<'a, A, B> Track<'a, (A, B)> { pub fn map_first(mut orig: Self) -> (Track<'a, A>, B) { let ((a, b), active) = { let inner = orig.inner.take().unwrap(); drop(orig); (inner.value, inner.active) }; (Track::new(a, active), b) } } impl Drop for Track<'_, T> { fn drop(&mut self) { if let Some(inner) = &self.inner { inner.active.set(false); } } } impl Deref for Track<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { &self.inner.as_ref().unwrap().value } } #[derive(Debug, PartialEq)] pub enum RecvError { Disconnected, ValueActive, } // wrap an AsyncLocalReceiver and a shared flag representing the liveness of // one received value at a time. each received value is wrapped in Track and // must be dropped before reading the next value pub struct TrackedAsyncLocalReceiver<'a, T> { inner: AsyncLocalReceiver, value_active: &'a TrackFlag, } impl<'a, T> TrackedAsyncLocalReceiver<'a, T> { pub fn new(r: AsyncLocalReceiver, value_active: &'a TrackFlag) -> Self { value_active.set(false); Self { inner: r, value_active, } } // attempt to receive a value from the inner receiver. 
if a previously // received value has not been dropped, this method returns an error pub async fn recv(&self) -> Result, RecvError> { if self.value_active.get() { return Err(RecvError::ValueActive); } let v = match self.inner.recv().await { Ok(v) => v, Err(mpsc::RecvError) => return Err(RecvError::Disconnected), }; Ok(Track::new(v, self.value_active)) } } #[derive(Debug, PartialEq)] pub struct ValueActiveError; pub struct TrackFuture<'a, F> { fut: F, value_active: &'a TrackFlag, } impl Future for TrackFuture<'_, F> where F: Future>, E: From, { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { // SAFETY: pin projection let (fut, value_active) = unsafe { let s = self.get_unchecked_mut(); let fut = Pin::new_unchecked(&mut s.fut); (fut, &s.value_active) }; let result = fut.poll(cx); if value_active.get() { return Poll::Ready(Err(ValueActiveError.into())); } result } } // wrap a future and a shared flag representing the liveness of some value. // if the value is true after polling the inner future, return an error pub fn track_future(fut: F, value_active: &TrackFlag) -> TrackFuture<'_, F> where F: Future, { TrackFuture { fut, value_active } } #[cfg(test)] mod tests { use super::*; use crate::core::channel; use crate::core::executor::Executor; use crate::core::reactor::Reactor; use crate::core::task::yield_task; #[test] fn track_value() { let f = TrackFlag::default(); let v = Track::new(42, &f); assert!(f.get()); assert_eq!(*v, 42); drop(v); assert!(!f.get()); } #[test] fn track_async_local_receiver() { let reactor = Reactor::new(2); let executor = Executor::new(1); let (s, r) = channel::local_channel(2, 1, &reactor.local_registration_memory()); s.try_send(1).unwrap(); s.try_send(2).unwrap(); drop(s); executor .spawn(async move { let f = TrackFlag::default(); let r = TrackedAsyncLocalReceiver::new(AsyncLocalReceiver::new(r), &f); let v = r.recv().await.unwrap(); assert_eq!(*v, 1); assert!(r.recv().await.is_err()); drop(v); let v = r.recv().await.unwrap(); assert_eq!(*v, 2); assert!(r.recv().await.is_err()); // no values left drop(v); assert!(r.recv().await.is_err()); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn track_value_and_future() { let executor = Executor::new(1); executor .spawn(async move { let f = TrackFlag::default(); // awaiting while the flag is active is an error let ret = track_future( async { let v = Track::new(1, &f); // this line will cause the error yield_task().await; // this line never reached drop(v); Ok(()) }, &f, ) .await; assert_eq!(ret, Err(ValueActiveError)); // awaiting while the flag is not active is ok let ret: Result<_, ValueActiveError> = track_future( async { let v = Track::new(1, &f); drop(v); // this is ok yield_task().await; Ok(()) }, &f, ) .await; assert_eq!(ret, Ok(())); }) .unwrap(); executor.run(|_| Ok(())).unwrap(); } } pushpin-1.41.0/src/connmgr/websocket.rs000066400000000000000000002105311504671364300200640ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::buffer::{ trim_for_display, write_vectored_offset, Buffer, LimitBufsMut, RingBuffer, VECTORED_MAX, }; use crate::core::http1::HeaderParamsIterator; use arrayvec::ArrayVec; use log::{log_enabled, trace}; use miniz_oxide::deflate; use miniz_oxide::inflate::stream::{inflate, InflateState}; use miniz_oxide::{DataFormat, MZError, MZFlush, MZStatus}; use std::ascii; use std::cell::{Cell, RefCell}; use std::cmp; use std::fmt; use std::io; use std::io::Write; use std::mem::{self, MaybeUninit}; pub const WS_GUID: &str = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; // 1 byte flags + 9 bytes payload size + 4 bytes mask pub const HEADER_SIZE_MAX: usize = 14; const LOG_CONTENT_MAX: usize = 1_000; const PSIZE_3BYTE: usize = 126; const PSIZE_9BYTE: usize = 65536; pub const OPCODE_CONTINUATION: u8 = 0; pub const OPCODE_TEXT: u8 = 1; pub const OPCODE_BINARY: u8 = 2; pub const OPCODE_CLOSE: u8 = 8; pub const OPCODE_PING: u8 = 9; pub const OPCODE_PONG: u8 = 10; pub const CONTROL_FRAME_PAYLOAD_MAX: usize = 125; const DEFAULT_MAX_WINDOW_BITS: u8 = 15; const DEFLATE_SUFFIX: [u8; 4] = [0x00, 0x00, 0xff, 0xff]; const ENC_NEXT_BUF_SIZE: usize = DEFLATE_SUFFIX.len(); struct Bufs<'a> { data: &'a [&'a [u8]], } impl<'a> Bufs<'a> { fn new(data: &'a [&'a [u8]]) -> Self { Self { data } } } impl fmt::Display for Bufs<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use fmt::Write; let data: Vec = self .data .iter() .flat_map(|v| *v) .copied() .take(LOG_CONTENT_MAX + 1) .collect(); let mut s = String::new(); write!(&mut s, "\"")?; for b in data { write!(&mut s, "{}", ascii::escape_default(b))?; } write!(&mut s, "\"")?; write!(f, "{}", trim_for_display(&s, LOG_CONTENT_MAX)) } } #[derive(Clone, Copy)] pub struct FrameInfo { pub fin: bool, pub rsv1: bool, pub opcode: u8, pub mask: Option<[u8; 4]>, pub payload_offset: usize, pub payload_size: usize, } fn header_size(payload_size: usize, masked: bool) -> usize { let size = if payload_size < PSIZE_3BYTE { 1 + 1 } else if payload_size < PSIZE_9BYTE { 1 + 3 } else { 1 + 9 }; if masked { size + 4 } else { size } } pub fn read_header(buf: &[u8]) -> Result { if buf.len() < 2 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let b1 = buf[1] & 0x7f; #[allow(clippy::comparison_chain)] let (mut hsize, psize) = if b1 < (PSIZE_3BYTE as u8) { (2, b1 as usize) } else if b1 == (PSIZE_3BYTE as u8) { if buf.len() < 2 + 2 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut arr = [0; 2]; arr.copy_from_slice(&buf[2..4]); (4, u16::from_be_bytes(arr) as usize) } else { if buf.len() < 2 + 8 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut arr = [0; 8]; arr.copy_from_slice(&buf[2..10]); (10, u64::from_be_bytes(arr) as usize) }; let mask = if buf[1] & 0x80 != 0 { if buf.len() < hsize + 4 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } let mut mask = [0; 4]; mask.copy_from_slice(&buf[hsize..hsize + 4]); hsize += 4; Some(mask) } else { None }; Ok(FrameInfo { fin: buf[0] & 0x80 != 0, rsv1: buf[0] & 0x40 != 0, opcode: buf[0] & 0x0f, mask, payload_offset: hsize, payload_size: psize, }) } // return payload offset pub fn write_header( fin: bool, rsv1: bool, opcode: u8, payload_size: usize, mask: Option<[u8; 4]>, buf: &mut [u8], ) -> Result { let hsize = header_size(payload_size, mask.is_some()); if buf.len() < hsize { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let mut b0 = 0; if fin { b0 |= 0x80; 
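// per RFC 6455, the first header byte packs FIN into bit 0x80, RSV1
// into bit 0x40 (set on the first frame of a compressed message when
// permessage-deflate is negotiated), and the opcode into the low 4
// bits; the surrounding lines assemble that byte into b0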
} if rsv1 { b0 |= 0x40; } b0 |= opcode & 0x0f; buf[0] = b0; let hsize = if payload_size < PSIZE_3BYTE { buf[1] = payload_size as u8; 2 } else if payload_size < PSIZE_9BYTE { buf[1] = PSIZE_3BYTE as u8; let arr = (payload_size as u16).to_be_bytes(); buf[2..4].copy_from_slice(&arr); 4 } else { buf[1] = 0x7f; let arr = (payload_size as u64).to_be_bytes(); buf[2..10].copy_from_slice(&arr); 10 }; if let Some(mask) = mask { buf[1] |= 0x80; buf[hsize..hsize + 4].copy_from_slice(&mask); Ok(hsize + 4) } else { Ok(hsize) } } pub fn apply_mask(buf: &mut [u8], mask: [u8; 4], offset: usize) { for (i, c) in buf.iter_mut().enumerate() { *c ^= mask[(offset + i) % 4]; } } pub fn apply_mask_vectored(bufs: &mut [&mut [u8]], mask: [u8; 4], offset: usize) { let mut count = 0; for buf in bufs { apply_mask(buf, mask, offset + count); count += buf.len(); } } fn parse_empty(s: &str, dest: &mut bool) -> Result<(), io::Error> { // must not be set yet and value must be empty if *dest || !s.is_empty() { return Err(io::Error::from(io::ErrorKind::InvalidData)); } *dest = true; Ok(()) } // set default to allow the param with no value fn parse_bits(s: &str, dest: &mut Option, default: Option) -> Result<(), io::Error> { // must not be set yet if dest.is_some() { return Err(io::Error::from(io::ErrorKind::InvalidData)); } if s.is_empty() { if let Some(x) = default { *dest = Some(x); return Ok(()); } } // must be a valid u8 let x = match s.parse() { Ok(x) => x, Err(_) => return Err(io::Error::from(io::ErrorKind::InvalidData)), }; // number must be between 8 and 15, inclusive if (8..=15).contains(&x) { *dest = Some(x); return Ok(()); } Err(io::Error::from(io::ErrorKind::InvalidData)) } pub struct PerMessageDeflateConfig { pub client_no_context_takeover: bool, pub server_no_context_takeover: bool, pub client_max_window_bits: u8, pub server_max_window_bits: u8, } impl PerMessageDeflateConfig { pub fn from_params(params: HeaderParamsIterator) -> Result { let mut client_no_context_takeover = false; let mut server_no_context_takeover = false; let mut client_max_window_bits = None; let mut server_max_window_bits = None; for param in params { let (k, v) = param?; match k { "client_no_context_takeover" => parse_empty(v, &mut client_no_context_takeover)?, "server_no_context_takeover" => parse_empty(v, &mut server_no_context_takeover)?, "client_max_window_bits" => parse_bits(v, &mut client_max_window_bits, Some(15))?, "server_max_window_bits" => parse_bits(v, &mut server_max_window_bits, None)?, _ => return Err(io::Error::from(io::ErrorKind::InvalidData)), // undefined param } } Ok(Self { client_no_context_takeover, server_no_context_takeover, client_max_window_bits: client_max_window_bits.unwrap_or(DEFAULT_MAX_WINDOW_BITS), server_max_window_bits: server_max_window_bits.unwrap_or(DEFAULT_MAX_WINDOW_BITS), }) } #[allow(clippy::result_unit_err)] pub fn create_response(&self) -> Result { // we don't support non-default server_max_window_bits if self.server_max_window_bits != DEFAULT_MAX_WINDOW_BITS { return Err(()); } Ok(Self { // ack. makes no difference to us client_no_context_takeover: self.client_no_context_takeover, // ack. we'll agree to whatever the client wants server_no_context_takeover: self.server_no_context_takeover, // ignore. we always support the maximum window size client_max_window_bits: DEFAULT_MAX_WINDOW_BITS, // ignore. 
we require the client to support the maximum window size server_max_window_bits: DEFAULT_MAX_WINDOW_BITS, }) } #[allow(clippy::result_unit_err)] pub fn check_response(&self) -> Result<(), ()> { // we don't support non-default client_max_window_bits if self.client_max_window_bits != DEFAULT_MAX_WINDOW_BITS { return Err(()); } Ok(()) } pub fn serialize(&self, w: &mut W) -> Result<(), io::Error> { if self.client_no_context_takeover { write!(w, "; client_no_context_takeover")?; } if self.server_no_context_takeover { write!(w, "; server_no_context_takeover")?; } if self.client_max_window_bits != DEFAULT_MAX_WINDOW_BITS { write!( w, "; client_max_window_bits={}", self.client_max_window_bits )?; } if self.server_max_window_bits != DEFAULT_MAX_WINDOW_BITS { write!( w, "; server_max_window_bits={}", self.server_max_window_bits )?; } Ok(()) } } impl Default for PerMessageDeflateConfig { fn default() -> Self { Self { client_no_context_takeover: false, server_no_context_takeover: false, client_max_window_bits: DEFAULT_MAX_WINDOW_BITS, server_max_window_bits: DEFAULT_MAX_WINDOW_BITS, } } } trait ArrayVecExt { fn resize(&mut self, new_len: usize, value: T); fn shift_left(&mut self, amount: usize); } impl ArrayVecExt for ArrayVec { fn resize(&mut self, new_len: usize, value: T) { assert!(new_len <= self.capacity()); #[allow(clippy::comparison_chain)] if new_len > self.len() { let old_len = self.len(); unsafe { self.set_len(new_len); } self[old_len..].fill(value); } else if new_len < self.len() { self.truncate(new_len); } } fn shift_left(&mut self, amount: usize) { assert!(amount <= self.len()); self.copy_within(amount.., 0); unsafe { self.set_len(self.len() - amount); } } } pub struct DeflateEncoder { enc: Box, next_buf: ArrayVec, end: bool, } #[allow(clippy::new_without_default)] impl DeflateEncoder { pub fn new() -> Self { let mut enc = Box::::default(); enc.set_format_and_level( DataFormat::Raw, deflate::CompressionLevel::DefaultLevel as u8, ); Self { enc, next_buf: ArrayVec::new(), end: false, } } pub fn reset(&mut self) { self.enc.reset(); } pub fn encode( &mut self, src: &[u8], end: bool, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error> { let (read, mut written, mut end_ack) = self.encode_step(src, end, dest)?; if !src.is_empty() && read == src.len() && end && !end_ack { let (r, w, ea) = self.encode_step(&[], end, &mut dest[written..])?; assert_eq!(r, 0); written += w; end_ack = ea; } Ok((read, written, end_ack)) } pub fn encode_to_ringbuffer + AsMut<[u8]>>( &mut self, src: &[u8], end: bool, dest: &mut RingBuffer, ) -> Result<(usize, bool), io::Error> { let wbuf = dest.write_buf(); let (mut read, written, mut end_ack) = self.encode(src, end, wbuf)?; let write_maxed = written == wbuf.len(); dest.write_commit(written); if !end_ack && write_maxed && dest.remaining_capacity() > 0 { let (r, written, ea) = self.encode(&src[read..], end, dest.write_buf())?; dest.write_commit(written); read += r; end_ack = ea; } Ok((read, end_ack)) } fn encode_step( &mut self, src: &[u8], end: bool, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error> { // once end=true has been processed, the caller must stop providing // data in src and must continue to set end until end is returned if self.end && (!src.is_empty() || !end) { return Err(io::Error::from(io::ErrorKind::Other)); } let next_buf = &mut self.next_buf; // we want to flush exactly once per message. 
to ensure this, we // flush only when there is no more input (to avoid a situation of // input not being accepted at the time of flush) and if we have not // flushed yet for the current message let flush = if src.is_empty() && end && !self.end { self.end = true; MZFlush::Sync } else { MZFlush::None }; let (consumed, written, maybe_more) = if dest.len() > next_buf.len() && dest.len() >= next_buf.remaining_capacity() { // if there's enough room in dest to hold all of next_buf plus at // least one more byte, and there's at least as much room in dest // as in next_buf, then encode directly into dest // move next_buf into dest let offset = next_buf.len(); dest[..offset].copy_from_slice(next_buf.as_ref()); next_buf.clear(); // encode into the remaining space let (result, maybe_more) = { let dest = &mut dest[offset..]; assert!(!dest.is_empty()); let result = deflate::stream::deflate(&mut self.enc, src, dest, flush); match result.status { Ok(MZStatus::Ok) => {} Err(MZError::Buf) => {} _ => return Err(io::Error::from(io::ErrorKind::Other)), } assert!(result.bytes_consumed <= src.len()); assert!(result.bytes_written <= dest.len()); (result, result.bytes_written == dest.len()) }; let dest = &mut dest[..(offset + result.bytes_written)]; // keep back the ending bytes in next_buf assert!(next_buf.is_empty()); let keep = cmp::min(ENC_NEXT_BUF_SIZE, dest.len()); next_buf.write_all(&dest[(dest.len() - keep)..]).unwrap(); let written = dest.len() - keep; (result.bytes_consumed, written, maybe_more) } else { // if next_buf can't fit into dest with room to spare, or if // there's more room in next_buf than in dest, then encode into a // temporary buffer and move the bytes into place afterwards. // note that the temporary buffer will be small // dest.len() is either less than or equal to next_buf.len() // or less than next_buf.remaining_capacity(). 
in either case // this will not exceed next_buf's capacity assert!(dest.len() <= ENC_NEXT_BUF_SIZE); // stating the obvious assert!(next_buf.remaining_capacity() <= ENC_NEXT_BUF_SIZE); let tmp_size = dest.len() + next_buf.remaining_capacity(); // based on above asserts assert!(tmp_size <= ENC_NEXT_BUF_SIZE * 2); let mut tmp: ArrayVec = ArrayVec::new(); tmp.resize(tmp_size, 0); // encode into tmp let (result, maybe_more) = { let result = deflate::stream::deflate(&mut self.enc, src, tmp.as_mut(), flush); match result.status { Ok(MZStatus::Ok) => {} Err(MZError::Buf) => {} _ => return Err(io::Error::from(io::ErrorKind::Other)), } assert!(result.bytes_consumed <= src.len()); assert!(result.bytes_written <= tmp.len()); (result, result.bytes_written == tmp.len()) }; tmp.truncate(result.bytes_written); let mut written = 0; // if the encoded bytes don't fit in next_buf, then we can // move some bytes to dest if tmp.len() > next_buf.remaining_capacity() { let to_write = tmp.len() - next_buf.remaining_capacity(); // move the starting bytes of next_buf to the front of dest let size = cmp::min(to_write, next_buf.len()); dest[..size].copy_from_slice(&next_buf[..size]); next_buf.shift_left(size); written += size; // if dest still has room, move from tmp if written < to_write { assert!(next_buf.is_empty()); let size = to_write - written; assert!(size <= tmp.len()); dest[written..(written + size)].copy_from_slice(&tmp[..size]); tmp.shift_left(size); written += size; } } // append tmp to next_buf next_buf.write_all(tmp.as_ref()).unwrap(); (result.bytes_consumed, written, maybe_more) }; let mut end_ack = false; if self.end && consumed == src.len() && next_buf.len() == DEFLATE_SUFFIX.len() && !maybe_more { if next_buf.as_ref() != DEFLATE_SUFFIX { return Err(io::Error::from(io::ErrorKind::Other)); } self.next_buf.clear(); self.end = false; end_ack = true; } Ok((consumed, written, end_ack)) } } pub struct DeflateDecoder { dec: Box, suffix_pos: Option, } #[allow(clippy::new_without_default)] impl DeflateDecoder { pub fn new() -> Self { Self { dec: InflateState::new_boxed(DataFormat::Raw), suffix_pos: None, } } } pub trait Decoder { fn decode( &mut self, src: &[u8], end: bool, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error>; } impl Decoder for DeflateDecoder { fn decode( &mut self, src: &[u8], end: bool, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error> { let (consumed, mut written) = if self.suffix_pos.is_none() { let result = inflate(&mut self.dec, src, dest, MZFlush::None); match result.status { Ok(MZStatus::Ok) => {} Err(MZError::Buf) => {} _ => return Err(io::Error::from(io::ErrorKind::Other)), } assert!(result.bytes_consumed <= src.len()); assert!(result.bytes_written <= dest.len()); if result.bytes_consumed == src.len() && end { self.suffix_pos = Some(0); } if result.bytes_written == dest.len() { return Ok((result.bytes_consumed, result.bytes_written, false)); } (result.bytes_consumed, result.bytes_written) } else { (0, 0) }; let mut end_ack = false; if let Some(pos) = &mut self.suffix_pos { // if the input is fully consumed when end is set, then the // caller must continue to set end until end is returned if !end { return Err(io::Error::from(io::ErrorKind::Other)); } let dest = &mut dest[written..]; let suffix = DEFLATE_SUFFIX; let suffix_left = &suffix[*pos..]; let result = inflate(&mut self.dec, suffix_left, dest, MZFlush::None); match result.status { Ok(MZStatus::Ok) => {} Err(MZError::Buf) => {} _ => return Err(io::Error::from(io::ErrorKind::Other)), } 
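// per RFC 7692, the sender strips the trailing 0x00 0x00 0xff 0xff
// left by a deflate sync flush from each compressed message, so the
// decoder re-appends those four bytes (DEFLATE_SUFFIX) at end of
// message; suffix_pos tracks how much of the suffix has been fed to
// the inflater so far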
assert!(result.bytes_consumed <= suffix_left.len()); assert!(result.bytes_written <= dest.len()); *pos += result.bytes_consumed; // we are done when the entire input is consumed and there is // space left in the output buffer. if there is no space left in // the output buffer then there might be more to write if *pos == suffix.len() && result.bytes_written < dest.len() { self.suffix_pos = None; end_ack = true; } written += result.bytes_written; } Ok((consumed, written, end_ack)) } } pub fn deflate_codec_state_size() -> usize { let encoder_size = mem::size_of::(); let decoder_size = mem::size_of::(); encoder_size + decoder_size } // call preprocess_fn on any bytes about to be decoded. this can be used // to apply mask processing as needed fn decode_from_buffer( src: &mut T, limit: usize, end: bool, dec: &mut D, dest: &mut [u8], mut preprocess_fn: F, ) -> Result<(usize, bool), io::Error> where T: Buffer + ?Sized, D: Decoder, F: FnMut(&mut [u8], usize), { let buf = src.read_buf_mut(); let limit = cmp::min(limit, buf.len()); let buf = &mut buf[..limit]; preprocess_fn(buf, 0); let (read, mut written, mut end_ack) = dec.decode(buf, end, dest)?; let read_maxed = read == buf.len(); src.read_commit(read); let buf = src.read_buf_mut(); let buf = &mut buf[..(limit - read)]; if !end_ack && read_maxed && !buf.is_empty() { // this will not overlap with previously preprocessed bytes preprocess_fn(buf, read); let (read, w, ea) = dec.decode(buf, end, &mut dest[written..])?; src.read_commit(read); written += w; end_ack = ea; } Ok((written, end_ack)) } fn unmask_and_decode( src: &mut T, limit: usize, end: bool, mask: Option<[u8; 4]>, mask_offset: usize, dec: &mut D, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error> where T: Buffer + ?Sized, D: Decoder, { // if a mask needs to be applied, it needs to be applied to the // received bytes before they are passed to the decoder. however, // we don't know in advance how many bytes the decoder will // accept. in order to preserve the integrity of the input // buffer, and to avoid copying, we apply the mask directly to // the input buffer and then revert it on any bytes that weren't // accepted. in the best case, the decoder will accept all the // bytes with nothing to revert. in the worst case, the decoder // will accept nothing and all the bytes will be reverted let mut masked = 0; let orig_len = src.len(); let (written, output_end) = decode_from_buffer(src, limit, end, dec, dest, |buf, offset| { if let Some(mask) = mask { apply_mask(buf, mask, mask_offset + offset); masked += buf.len(); } })?; let read = orig_len - src.len(); if let Some(mask) = mask { // undo the mask on any unread bytes assert!(masked >= read); masked -= read; let mut bufs_arr = MaybeUninit::<[&mut [u8]; VECTORED_MAX]>::uninit(); let mut bufs = src.read_bufs_mut(&mut bufs_arr).limit(masked); apply_mask_vectored(bufs.as_slice(), mask, mask_offset + read); } Ok((read, written, output_end)) } // mask src, then call f(), then unmask src. 
f() returns (skip_unmask, R) fn with_mask(src: &mut [&mut [u8]], mask: Option<[u8; 4]>, mask_offset: usize, f: F) -> R where F: FnOnce(&mut [&mut [u8]]) -> (usize, R), { if let Some(mask) = mask { apply_mask_vectored(src, mask, mask_offset); } let (skip_unmask, ret) = f(src); if let Some(mask) = mask { apply_mask_vectored( src.skip(skip_unmask).as_slice(), mask, mask_offset + skip_unmask, ); } ret } #[cfg(test)] pub struct Frame<'a> { pub opcode: u8, pub data: &'a [u8], pub fin: bool, } #[derive(PartialEq)] pub enum CompressionMode { Compressed, Uncompressed, } #[derive(Debug, PartialEq, Clone, Copy)] pub enum State { // call: send_frame, recv_frame // next: Connected, PeerClosed, Closing Connected, // call: send_frame // next: PeerClosed, Finished PeerClosed, // call: recv_frame // next: Closing, Finished Closing, // session has completed Finished, } #[derive(Debug)] pub enum Error { Io(io::Error), InvalidControlFrame, UnexpectedOpcode, CompressionError, } impl From for Error { fn from(e: io::Error) -> Self { Self::Io(e) } } struct SendingFrame { opcode: u8, header: ArrayVec, payload_size: usize, sent: usize, } struct SendingMessage { opcode: u8, mask: Option<[u8; 4]>, frame_sent: bool, end_len: Option, enc_output_end: bool, } struct ReceivingMessage { opcode: u8, frame_payload_read: usize, compression_mode: CompressionMode, } struct Sending { frame: RefCell>, message: RefCell>, } struct Receiving { frame: Option, message: Option, } struct DeflateState { enc: DeflateEncoder, dec: DeflateDecoder, allow_takeover: bool, enc_buf: RingBuffer, } pub struct Protocol { state: Cell, sending: Sending, receiving: RefCell, deflate_state: Option>>, } impl + AsMut<[u8]>> Protocol { pub fn new(deflate_config: Option<(bool, RingBuffer)>) -> Self { let deflate_state = deflate_config.map(|(allow_takeover, enc_buf)| { RefCell::new(DeflateState { enc: DeflateEncoder::new(), dec: DeflateDecoder::new(), allow_takeover, enc_buf, }) }); Self { state: Cell::new(State::Connected), sending: Sending { frame: RefCell::new(None), message: RefCell::new(None), }, receiving: RefCell::new(Receiving { frame: None, message: None, }), deflate_state, } } pub fn state(&self) -> State { self.state.get() } pub fn send_frame( &self, writer: &mut W, opcode: u8, src: &mut [&mut [u8]], fin: bool, rsv1: bool, mask: Option<[u8; 4]>, ) -> Result { assert!(self.state.get() == State::Connected || self.state.get() == State::PeerClosed); let sending_frame = &mut *self.sending.frame.borrow_mut(); let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } if sending_frame.is_none() { let mut header = ArrayVec::from([0; HEADER_SIZE_MAX]); let size = write_header(fin, rsv1, opcode, src_len, mask, &mut header)?; header.truncate(size); *sending_frame = Some(SendingFrame { opcode, header, payload_size: src_len, sent: 0, }); } let frame = sending_frame.as_mut().unwrap(); let header = frame.header.as_slice(); let frame_size = header.len() + frame.payload_size; let (header_remaining, payload_sent) = if frame.sent < header.len() { (header.len() - frame.sent, 0) } else { (0, frame.sent - header.len()) }; assert!(payload_sent <= frame.payload_size); let mut src = src.limit(frame.payload_size - payload_sent); let src = src.as_slice(); // to avoid copying, we apply the mask directly to the input // buffer and then revert it on any bytes that weren't written. // in the best case, all bytes will be written with nothing to // revert. 
in the worst case, nothing will be written and all // the bytes will be reverted let size = with_mask(src, mask, payload_sent, |src| { let mut out = ArrayVec::<&[u8], VECTORED_MAX>::new(); if header_remaining > 0 { out.push(&header[frame.sent..]); } for buf in src.iter() { out.push(*buf); } let ret = write_vectored_offset(writer, out.as_slice(), 0); if log_enabled!(log::Level::Trace) { trace!("OUT sock {} -> {:?}", Bufs::new(out.as_slice()), ret); } let size = match &ret { Ok(size) => *size, Err(_) => 0, }; let skip_unmask = size.saturating_sub(header_remaining); (skip_unmask, ret) })?; frame.sent += size; assert!(frame.sent <= frame_size); if frame.sent < header.len() { return Ok(0); } let payload_written = (frame.sent - header.len()) - payload_sent; if frame.sent == frame_size { let opcode = frame.opcode; *sending_frame = None; if opcode == OPCODE_CLOSE { if self.state.get() == State::PeerClosed { self.state.set(State::Finished); } else { self.state.set(State::Closing); } } } Ok(payload_written) } // on success, it's up to the caller to advance the buffer by frame.data.len() #[cfg(test)] pub fn recv_frame<'buf, B: Buffer>( &mut self, rbuf: &'buf mut B, ) -> Option, Error>> { assert!(self.state.get() == State::Connected || self.state.get() == State::Closing); let receiving = &mut *self.receiving.borrow_mut(); if receiving.frame.is_none() { let fi = match read_header(rbuf.read_buf()) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return None, Err(e) => return Some(Err(e.into())), }; rbuf.read_commit(fi.payload_offset); receiving.frame = Some(fi); } let fi = receiving.frame.unwrap(); if rbuf.read_buf().len() < fi.payload_size { return None; } if fi.opcode == OPCODE_CLOSE { if self.state.get() == State::Closing { self.state.set(State::Finished); } else { self.state.set(State::PeerClosed); } } let buf = rbuf.read_buf_mut(); if let Some(mask) = fi.mask { apply_mask(buf, mask, 0); } receiving.frame = None; Some(Ok(Frame { opcode: fi.opcode, data: &buf[..fi.payload_size], fin: fi.fin, })) } pub fn is_sending_message(&self) -> bool { self.sending.message.borrow().is_some() } pub fn send_message_start(&self, opcode: u8, mask: Option<[u8; 4]>) { assert!(self.state.get() == State::Connected || self.state.get() == State::PeerClosed); let sending_message = &mut *self.sending.message.borrow_mut(); assert!(sending_message.is_none()); *sending_message = Some(SendingMessage { opcode, mask, frame_sent: false, end_len: None, enc_output_end: false, }); } // returns (bytes read, done) // note: when compression is used, bytes may be buffered in the encoder // and may not be flushed to the writer until the encoder's buffer is // full or the input ends pub fn send_message_content( &self, writer: &mut W, src: &mut [&mut [u8]], end: bool, ) -> Result<(usize, bool), Error> { assert!(self.state.get() == State::Connected || self.state.get() == State::PeerClosed); let mut sending_message = self.sending.message.borrow_mut(); let msg = sending_message.as_mut().unwrap(); let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } if let Some(end_len) = msg.end_len { // once the caller has passed end=true, it must continue to pass // end=true in all subsequent calls until this method returns // done assert!(end); // once the caller has passed end=true, it must continue to // provide the expected number of src bytes in all subsequent // calls until this method returns done assert_eq!(src_len, end_len); } else if end { // when the caller passes end=true, note the number of src bytes // 
provided msg.end_len = Some(src_len); } let is_control = msg.opcode & 0x08 != 0; // control frames (ping, pong, close) must have a small payload length // and must not be fragmented if is_control && (src_len > CONTROL_FRAME_PAYLOAD_MAX || !end) { return Err(Error::InvalidControlFrame); } let opcode = if msg.frame_sent { OPCODE_CONTINUATION } else { msg.opcode }; let (read, sent_all) = match &self.deflate_state { Some(state) if !is_control => { let state = &mut *state.borrow_mut(); let mut read = 0; if !msg.enc_output_end { if src_len > 0 { for (i, buf) in src.iter().enumerate() { // only set end on the last buf let end = end && (i == src.len() - 1); let dest = &mut state.enc_buf; let (r, oe) = state.enc.encode_to_ringbuffer(buf, end, dest)?; read += r; msg.enc_output_end = oe; if r < buf.len() || oe { break; } } } else { let dest = &mut state.enc_buf; let (_, oe) = state.enc.encode_to_ringbuffer(&[], end, dest)?; msg.enc_output_end = oe; } } // we should never get EOS if there are bytes left to send assert!(!msg.enc_output_end || read == src_len); if let Some(end_len) = &mut msg.end_len { *end_len -= read; } let mut sent_all = false; // only attempt to write if we have no consumed byte count // to report, so that if the write returns an error // (including WouldBlock) we can propagate the error without // data loss if read == 0 && (state.enc_buf.len() > 0 || msg.enc_output_end) { // send_frame adds 1 element to vector let mut bufs_arr = MaybeUninit::<[&mut [u8]; VECTORED_MAX - 1]>::uninit(); let bufs = state.enc_buf.read_bufs_mut(&mut bufs_arr); // set on first frame let rsv1 = opcode != OPCODE_CONTINUATION; let size = self.send_frame(writer, opcode, bufs, msg.enc_output_end, rsv1, msg.mask)?; state.enc_buf.read_commit(size); msg.frame_sent = true; sent_all = msg.enc_output_end && state.enc_buf.len() == 0; } (read, sent_all) } _ => { let read = self.send_frame(writer, opcode, src, end, false, msg.mask)?; if let Some(end_len) = &mut msg.end_len { *end_len -= read; } msg.frame_sent = true; (read, end && read == src_len) } }; if sent_all && self.sending.frame.borrow().is_none() { *sending_message = None; if let Some(state) = &self.deflate_state { let mut state = state.borrow_mut(); if !state.allow_takeover { state.enc.reset(); } } } let done = sending_message.is_none(); Ok((read, done)) } pub fn recv_message_content( &self, rbuf: &mut B, dest: &mut [u8], ) -> Option> { assert!(self.state.get() == State::Connected || self.state.get() == State::Closing); let receiving = &mut *self.receiving.borrow_mut(); if receiving.frame.is_none() { let fi = match read_header(rbuf.read_buf()) { Ok(fi) => fi, Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => return None, Err(e) => return Some(Err(e.into())), }; rbuf.read_commit(fi.payload_offset); receiving.frame = Some(fi); if let Some(msg) = &mut receiving.message { if fi.opcode != OPCODE_CONTINUATION { return Some(Err(Error::UnexpectedOpcode)); } // only the first frame should set this bit if fi.rsv1 { return Some(Err(Error::CompressionError)); } msg.frame_payload_read = 0; } else { if fi.opcode == OPCODE_CONTINUATION { return Some(Err(Error::UnexpectedOpcode)); } if fi.opcode & 0x08 != 0 && (fi.payload_size > CONTROL_FRAME_PAYLOAD_MAX || !fi.fin || fi.rsv1) { return Some(Err(Error::InvalidControlFrame)); } let compression_mode = if fi.rsv1 { CompressionMode::Compressed } else { CompressionMode::Uncompressed }; receiving.message = Some(ReceivingMessage { opcode: fi.opcode, frame_payload_read: 0, compression_mode, }); } } let fi = 
receiving.frame.as_ref().unwrap(); let msg = receiving.message.as_mut().unwrap(); let (written, frame_read_end) = if msg.compression_mode == CompressionMode::Compressed { let state = match &self.deflate_state { Some(state) => state, None => return Some(Err(Error::CompressionError)), }; let mut state = state.borrow_mut(); let left = fi.payload_size - msg.frame_payload_read; let limit = cmp::min(left, rbuf.len()); let end = ((msg.frame_payload_read + limit) == fi.payload_size) && fi.fin; let (read, written, output_end) = match unmask_and_decode( rbuf, limit, end, fi.mask, msg.frame_payload_read, &mut state.dec, dest, ) { Ok(ret) => ret, Err(e) => return Some(Err(e.into())), }; msg.frame_payload_read += read; let frame_read_end = if fi.fin { // finish final frame only when we hit EOS (msg.frame_payload_read == fi.payload_size) && output_end } else { msg.frame_payload_read == fi.payload_size }; if !frame_read_end && written == 0 && rbuf.len() == 0 { // if there's no progress to report and nothing left to read // then we need more input return None; } (written, frame_read_end) } else { let buf = rbuf.read_buf(); // control frames must be available in their entirety if fi.opcode & 0x08 != 0 && buf.len() < fi.payload_size { return None; } let left = fi.payload_size - msg.frame_payload_read; if left > 0 && buf.is_empty() { return None; } let buf = &buf[..cmp::min(left, buf.len())]; let size = cmp::min(buf.len(), dest.len()); let dest = &mut dest[..size]; dest.copy_from_slice(&buf[..size]); rbuf.read_commit(size); if let Some(mask) = fi.mask { apply_mask(dest, mask, msg.frame_payload_read); } msg.frame_payload_read += size; assert!(msg.frame_payload_read <= fi.payload_size); (size, msg.frame_payload_read == fi.payload_size) }; let opcode = msg.opcode; let fin = fi.fin; if frame_read_end { receiving.frame = None; if fin { receiving.message = None; if opcode == OPCODE_CLOSE { if self.state.get() == State::Closing { self.state.set(State::Finished); } else { self.state.set(State::PeerClosed); } } } } Some(Ok((opcode, written, receiving.message.is_none()))) } } pub mod testutil { use super::*; use crate::core::buffer::{TmpBuffer, VecRingBuffer}; use std::rc::Rc; pub struct BenchSendMessageArgs { protocol: Protocol>, dest: Vec, } pub struct BenchSendMessage { use_deflate: bool, content: RefCell>, } impl BenchSendMessage { pub fn new(use_deflate: bool) -> Self { let mut content = Vec::with_capacity(1024); for i in 0..1024 { content.push((i % 256) as u8); } Self { use_deflate, content: RefCell::new(content), } } pub fn init(&self) -> BenchSendMessageArgs { let deflate_config = if self.use_deflate { let tmp = Rc::new(TmpBuffer::new(256)); Some((true, VecRingBuffer::new(256, &tmp))) } else { None }; BenchSendMessageArgs { protocol: Protocol::new(deflate_config), dest: Vec::with_capacity(16_384), } } pub fn run(&self, args: &mut BenchSendMessageArgs) { let p = &mut args.protocol; let src = &mut *self.content.borrow_mut(); let dest = &mut args.dest; p.send_message_start(OPCODE_TEXT, None); let mut src_pos = 0; loop { let (size, done) = p .send_message_content(dest, &mut [&mut src[src_pos..]], true) .unwrap(); src_pos += size; assert!(dest.len() < dest.capacity() || done); if done { break; } } } } pub struct BenchRecvMessageArgs { protocol: Protocol>, rbuf: VecRingBuffer, dest: Vec, } pub struct BenchRecvMessage { use_deflate: bool, tmp: Rc, msg: Vec, } impl BenchRecvMessage { pub fn new(use_deflate: bool) -> Self { let mut content = Vec::with_capacity(1024); for i in 0..1024 { content.push((i % 256) as u8); } 
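// pre-encode a complete 1024-byte text message using the sender side
// of Protocol below (through the permessage-deflate encoder when
// use_deflate is set), so that init()/run() exercise only the
// receive path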
let tmp = Rc::new(TmpBuffer::new(16_384)); let deflate_config = if use_deflate { Some((true, VecRingBuffer::new(16_384, &tmp))) } else { None }; let p = Protocol::new(deflate_config); let mut msg = Vec::new(); p.send_message_start(OPCODE_TEXT, None); let mut src_pos = 0; loop { let (size, done) = p .send_message_content(&mut msg, &mut [&mut content[src_pos..]], true) .unwrap(); src_pos += size; if done { break; } } Self { use_deflate, tmp, msg, } } pub fn init(&self) -> BenchRecvMessageArgs { let deflate_config = if self.use_deflate { Some((true, VecRingBuffer::new(256, &self.tmp))) } else { None }; let mut rbuf = VecRingBuffer::new(16_384, &self.tmp); let size = rbuf.write(&self.msg).unwrap(); assert_eq!(size, self.msg.len()); let dest = vec![0; 16_384]; BenchRecvMessageArgs { protocol: Protocol::new(deflate_config), rbuf, dest, } } pub fn run(&self, args: &mut BenchRecvMessageArgs) { let p = &mut args.protocol; let rbuf = &mut args.rbuf; let dest = &mut args.dest; let mut dest_pos = 0; loop { let (_, size, end) = p .recv_message_content(rbuf, &mut dest[dest_pos..]) .unwrap() .unwrap(); dest_pos += size; assert!(dest_pos < dest.len() || end); if end { break; } } } } } #[cfg(test)] mod tests { use super::testutil::*; use super::*; use crate::core::buffer::{TmpBuffer, VecRingBuffer}; use std::collections::VecDeque; use std::rc::Rc; struct MyWriter { data: Vec, allow: usize, } impl MyWriter { fn new() -> Self { Self { data: Vec::new(), allow: 1024, } } } impl Write for MyWriter { fn write(&mut self, buf: &[u8]) -> Result { if !buf.is_empty() && self.allow == 0 { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } let size = cmp::min(buf.len(), self.allow); let buf = &buf[..size]; self.data.extend_from_slice(buf); self.allow -= size; Ok(buf.len()) } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { let mut total = 0; for buf in bufs { if buf.is_empty() { continue; } if self.allow == 0 { if total == 0 { return Err(io::Error::from(io::ErrorKind::WouldBlock)); } break; } let size = cmp::min(buf.len(), self.allow); let buf = &buf[..size]; self.data.extend_from_slice(buf); self.allow -= size; total += buf.len(); } Ok(total) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } fn make_buf(s: &[u8; N]) -> ArrayVec { ArrayVec::from(*s) } #[test] fn test_header_size() { assert_eq!(header_size(0, false), 2); assert_eq!(header_size(125, false), 2); assert_eq!(header_size(125, true), 6); assert_eq!(header_size(126, false), 4); assert_eq!(header_size(65535, false), 4); assert_eq!(header_size(65535, true), 8); assert_eq!(header_size(65536, false), 10); assert_eq!(header_size(65536, true), HEADER_SIZE_MAX); } #[test] fn test_read_write_header() { let mut buf = [ 0x81, 0x85, 0x01, 0x02, 0x03, 0x04, 0x69, 0x67, 0x6f, 0x68, 0x6e, ]; let r = read_header(&buf); assert!(r.is_ok()); let fi = r.unwrap(); assert_eq!(fi.fin, true); assert_eq!(fi.opcode, OPCODE_TEXT); assert_eq!(fi.mask, Some([0x01, 0x02, 0x03, 0x04])); assert_eq!(fi.payload_offset, 6); assert_eq!(fi.payload_size, 5); let end = fi.payload_offset + fi.payload_size; let payload = &mut buf[fi.payload_offset..end]; apply_mask(payload, (&fi.mask).unwrap(), 0); assert_eq!(payload, b"hello"); let payload = b"hello"; let mut buf2 = Vec::new(); buf2.resize(header_size(payload.len(), true) + payload.len(), 0); let r = write_header( true, false, OPCODE_TEXT, payload.len(), Some([0x01, 0x02, 0x03, 0x04]), &mut buf2, ); assert!(r.is_ok()); let offset = r.unwrap(); assert_eq!(offset, 6); buf2[offset..offset + 
payload.len()].copy_from_slice(payload); assert_eq!(buf2, buf); } #[test] fn test_apply_mask() { let mut buf = [b'a', b'b', b'c', b'd', b'e']; apply_mask(&mut buf, [0x01, 0x02, 0x03, 0x04], 0); assert_eq!(buf, [0x60, 0x60, 0x60, 0x60, 0x64]); } #[test] fn test_deflate_bulk() { { let mut enc = DeflateEncoder::new(); let mut dec = DeflateDecoder::new(); let data = b"Hello"; let mut compressed = [0; 1024]; let (read, written, end) = enc.encode(data, true, &mut compressed).unwrap(); assert_eq!(read, 5); assert_eq!(end, true); let compressed = &compressed[..written]; let expected = [0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00]; assert_eq!(compressed, &expected); let mut uncompressed = [0; 1024]; let (read, written, end) = dec.decode(&compressed, true, &mut uncompressed).unwrap(); assert_eq!(read, compressed.len()); assert_eq!(end, true); let uncompressed = &uncompressed[..written]; assert_eq!(uncompressed, data); } } #[test] fn test_deflate_by_byte() { { let mut enc = DeflateEncoder::new(); let mut dec = DeflateDecoder::new(); let data = b"Hello"; let mut compressed = [0; 1024]; assert_eq!( enc.encode(&[], false, &mut compressed).unwrap(), (0, 0, false) ); let mut read_pos = 0; let mut write_pos = 0; loop { let (read, written, end) = enc .encode( &data[read_pos..], true, &mut compressed[write_pos..(write_pos + 1)], ) .unwrap(); // there must always be progress assert!(read > 0 || written > 0 || end); read_pos += read; write_pos += written; if end { break; } } let compressed = &compressed[..write_pos]; let expected = [0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00]; assert_eq!(compressed, &expected); let mut uncompressed = [0; 1024]; assert_eq!( dec.decode(&[], false, &mut uncompressed).unwrap(), (0, 0, false) ); let mut read_pos = 0; let mut write_pos = 0; loop { let (read, written, end) = dec .decode( &compressed[read_pos..], true, &mut uncompressed[write_pos..(write_pos + 1)], ) .unwrap(); // there must always be progress assert!(read > 0 || written > 0 || end); read_pos += read; write_pos += written; if end { break; } } assert_eq!(read_pos, compressed.len()); let uncompressed = &uncompressed[..write_pos]; assert_eq!(uncompressed, data); } } #[test] fn test_send_frame() { let p = Protocol::<[u8; 0]>::new(None); assert_eq!(p.state(), State::Connected); let mut writer = MyWriter::new(); let size = p .send_frame( &mut writer, OPCODE_TEXT, &mut [&mut make_buf(b"hello")], true, false, None, ) .unwrap(); assert_eq!(size, 5); assert_eq!(writer.data, b"\x81\x05hello"); assert_eq!(p.state(), State::Connected); } #[test] fn test_send_frame_masked() { let p = Protocol::<[u8; 0]>::new(None); let mut buf = make_buf(b"hello"); let mut writer = MyWriter::new(); writer.allow = 0; let e = p .send_frame( &mut writer, OPCODE_TEXT, &mut [&mut buf], true, false, Some([0x01, 0x02, 0x03, 0x04]), ) .unwrap_err(); let e = match e { Error::Io(e) => e, _ => panic!("unexpected error type"), }; assert_eq!(e.kind(), io::ErrorKind::WouldBlock); assert_eq!(&buf, b"hello".as_slice()); assert_eq!(p.sending.frame.borrow().is_some(), true); writer.allow = 3; let size = p .send_frame( &mut writer, OPCODE_TEXT, &mut [&mut buf], true, false, Some([0x01, 0x02, 0x03, 0x04]), ) .unwrap(); let expected = [0x81, 0x85, 0x01]; assert_eq!(size, 0); assert_eq!(writer.data, expected); assert_eq!(&buf, b"hello".as_slice()); assert_eq!(p.sending.frame.borrow().is_some(), true); writer.allow = 4; let size = p .send_frame( &mut writer, OPCODE_TEXT, &mut [&mut buf], true, false, Some([0x01, 0x02, 0x03, 0x04]), ) .unwrap(); let expected = [0x81, 0x85, 
0x01, 0x02, 0x03, 0x04, b'h' ^ 0x01]; assert_eq!(size, 1); assert_eq!(writer.data, expected); assert_eq!(&buf, [b'h' ^ 0x01, b'e', b'l', b'l', b'o'].as_slice()); assert_eq!(p.sending.frame.borrow().is_some(), true); writer.allow = 1024; let size = p .send_frame( &mut writer, OPCODE_TEXT, &mut [&mut buf[1..]], true, false, Some([0x01, 0x02, 0x03, 0x04]), ) .unwrap(); let expected = [ 0x81, 0x85, 0x01, 0x02, 0x03, 0x04, b'h' ^ 0x01, b'e' ^ 0x02, b'l' ^ 0x03, b'l' ^ 0x04, b'o' ^ 0x01, ]; assert_eq!(size, 4); assert_eq!(writer.data, expected); assert_eq!( &buf, [ b'h' ^ 0x01, b'e' ^ 0x02, b'l' ^ 0x03, b'l' ^ 0x04, b'o' ^ 0x01 ] .as_slice() ); assert_eq!(p.sending.frame.borrow().is_some(), false); } #[test] fn test_send_message() { let p = Protocol::<[u8; 0]>::new(None); assert_eq!(p.state(), State::Connected); let mut writer = MyWriter::new(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content( &mut writer, &mut [&mut make_buf(b"hel"), &mut make_buf(b"lo")], true, ) .unwrap(); assert_eq!(size, 5); assert_eq!(done, true); assert_eq!(writer.data, b"\x81\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b"hello")], false) .unwrap(); assert_eq!(size, 5); assert_eq!(done, false); assert_eq!(writer.data, b"\x01\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); let (size, done) = p.send_message_content(&mut writer, &mut [], true).unwrap(); assert_eq!(size, 0); assert_eq!(done, true); assert_eq!(writer.data, b"\x80\x00"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_PING, None); let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b"hello")], true) .unwrap(); assert_eq!(size, 5); assert_eq!(done, true); assert_eq!(writer.data, b"\x89\x05hello"); assert_eq!(p.state(), State::Connected); writer.data.clear(); p.send_message_start(OPCODE_PING, None); let r = p.send_message_content(&mut writer, &mut [&mut make_buf(b"hello")], false); assert!(r.is_err()); let p = Protocol::<[u8; 0]>::new(None); writer.data.clear(); writer.allow = 3; p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b"hello")], true) .unwrap(); assert_eq!(size, 1); assert_eq!(done, false); assert_eq!(writer.data, b"\x81\x05h"); writer.allow = 4; let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b"ello")], true) .unwrap(); assert_eq!(size, 4); assert_eq!(done, true); assert_eq!(writer.data, b"\x81\x05hello"); } #[test] fn test_recv_frame() { let mut data = b"\x81\x05hello".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let mut p = Protocol::<[u8; 0]>::new(None); assert_eq!(p.state(), State::Connected); let frame = p.recv_frame(&mut rbuf).unwrap().unwrap(); assert_eq!(frame.opcode, OPCODE_TEXT); assert_eq!(frame.data, b"hello"); assert_eq!(frame.fin, true); let size = frame.data.len(); rbuf.read_commit(size); assert_eq!(p.state(), State::Connected); } #[test] fn test_recv_message() { let mut data = b"\x81\x05hello".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let p = Protocol::<[u8; 0]>::new(None); assert_eq!(p.state(), State::Connected); let mut dest = [0; 1024]; let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hello"); assert_eq!(end, true); 
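// the wire bytes b"\x81\x05hello" decode as: 0x81 = FIN bit plus
// opcode 1 (text), 0x05 = mask bit clear with a 5-byte payload,
// followed by the payload itself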
assert_eq!(p.state(), State::Connected); let mut data = b"".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let r = p.recv_message_content(&mut rbuf, &mut dest); assert!(r.is_none()); let mut data = b"\x01\x03hel\x80\x02lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"lo"); assert_eq!(end, true); assert_eq!(p.state(), State::Connected); let mut data = b"\x81\x05hel".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); assert!(p.recv_message_content(&mut rbuf, &mut dest).is_none()); let mut data = b"lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"lo"); assert_eq!(end, true); assert_eq!(p.state(), State::Connected); let mut data = b"\x01\x03hel\x01\x02lo".to_vec(); let mut rbuf = io::Cursor::new(&mut data[..]); let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"hel"); assert_eq!(end, false); let r = p.recv_message_content(&mut rbuf, &mut dest).unwrap(); assert!(r.is_err()); } #[test] fn test_send_recv_compressed() { let tmp = Rc::new(TmpBuffer::new(1024)); let p = Protocol::new(Some((true, VecRingBuffer::new(1024, &tmp)))); let mut writer = MyWriter::new(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content( &mut writer, &mut [&mut make_buf(b"Hel"), &mut make_buf(b"lo")], true, ) .unwrap(); assert_eq!(size, 5); assert_eq!(done, false); assert_eq!(writer.data.is_empty(), true); let (size, done) = p.send_message_content(&mut writer, &mut [], true).unwrap(); assert_eq!(size, 0); assert_eq!(done, true); assert_eq!( writer.data, &[0xc1, 0x07, 0xf2, 0x48, 0xcd, 0xc9, 0xc9, 0x07, 0x00] ); let mut rbuf = io::Cursor::new(writer.data.as_mut()); let p = Protocol::new(Some((true, VecRingBuffer::new(1024, &tmp)))); let mut dest = [0; 1024]; let (opcode, size, end) = p .recv_message_content(&mut rbuf, &mut dest) .unwrap() .unwrap(); let data = &dest[..size]; assert_eq!(opcode, OPCODE_TEXT); assert_eq!(data, b"Hello"); assert_eq!(end, true); } #[test] fn test_send_recv_compressed_fragmented() { let tmp = Rc::new(TmpBuffer::new(1024)); let p = Protocol::new(Some((true, VecRingBuffer::new(1024, &tmp)))); let mut writer = MyWriter::new(); p.send_message_start(OPCODE_TEXT, None); let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b"hello")], false) .unwrap(); assert_eq!(size, 5); assert_eq!(done, false); // flush the encoded data { let state = &mut *p.deflate_state.as_ref().unwrap().borrow_mut(); let (_, output_end) = state .enc .encode_to_ringbuffer(&[], true, &mut state.enc_buf) .unwrap(); assert_eq!(output_end, true); state.enc_buf.write(&DEFLATE_SUFFIX).unwrap(); } // send flushed data as first frame let (size, done) = p.send_message_content(&mut writer, &mut [], 
false).unwrap(); assert_eq!(size, 0); assert_eq!(done, false); assert_eq!(writer.data.is_empty(), false); // send second frame let (size, done) = p .send_message_content(&mut writer, &mut [&mut make_buf(b" world")], true) .unwrap(); assert_eq!(size, 6); assert_eq!(done, false); let (size, done) = p.send_message_content(&mut writer, &mut [], true).unwrap(); assert_eq!(size, 0); assert_eq!(done, true); let p = Protocol::new(Some((true, VecRingBuffer::new(1024, &tmp)))); let mut writer_data = VecDeque::from(writer.data); let mut input = Vec::new(); let mut result: Vec = Vec::new(); // feed one byte at a time loop { let mut rbuf = io::Cursor::new(input.as_mut()); let mut dest = [0; 1024]; let ret = p.recv_message_content(&mut rbuf, &mut dest); let read = rbuf.position() as usize; input = input.split_off(read); if ret.is_none() { input.push(writer_data.pop_front().unwrap()); continue; } let (opcode, size, end) = ret.unwrap().unwrap(); assert_eq!(opcode, OPCODE_TEXT); result.extend(&dest[..size]); if end { break; } } assert_eq!(result, b"hello world"); } struct LimitedDeflateDecoder { dec: DeflateDecoder, limit: usize, } impl LimitedDeflateDecoder { fn new(limit: usize) -> Self { Self { dec: DeflateDecoder::new(), limit, } } fn increase_limit(&mut self, amt: usize) { self.limit += amt } } impl Decoder for LimitedDeflateDecoder { fn decode( &mut self, src: &[u8], end: bool, dest: &mut [u8], ) -> Result<(usize, usize, bool), io::Error> { let limit = cmp::min(self.limit, src.len()); let end = end && limit >= src.len(); let (read, written, output_end) = self.dec.decode(&src[..limit], end, dest)?; self.limit -= read; Ok((read, written, output_end)) } } #[test] fn test_unmask_and_decode() { // "Hello" compressed and masked with [0x01, 0x02, 0x03, 0x04] let mut msg = [ 0xf2 ^ 0x01, 0x48 ^ 0x02, 0xcd ^ 0x03, 0xc9 ^ 0x04, 0xc9 ^ 0x01, 0x07 ^ 0x02, 0x00 ^ 0x03, ]; let mask = [0x01, 0x02, 0x03, 0x04]; let mut rbuf = io::Cursor::new(&mut msg[..]); let mut dec = LimitedDeflateDecoder::new(5); let mut dest = [0; 1024]; let mut written = 0; let (read, w, output_end) = unmask_and_decode(&mut rbuf, 1024, true, Some(mask), 0, &mut dec, &mut dest).unwrap(); written += w; assert_eq!(read, 5); assert_eq!(output_end, false); dec.increase_limit(1024); let (read, w, output_end) = unmask_and_decode( &mut rbuf, 1024, true, Some(mask), read, &mut dec, &mut dest[written..], ) .unwrap(); written += w; assert_eq!(read, 2); assert_eq!(output_end, true); assert_eq!(&dest[..written], b"Hello"); } #[test] fn bench_send_message() { let t = BenchSendMessage::new(false); t.run(&mut t.init()); } #[test] fn bench_send_message_with_deflate() { let t = BenchSendMessage::new(true); t.run(&mut t.init()); } #[test] fn bench_recv_message() { let t = BenchRecvMessage::new(false); t.run(&mut t.init()); } #[test] fn bench_recv_message_with_deflate() { let t = BenchRecvMessage::new(true); t.run(&mut t.init()); } } pushpin-1.41.0/src/connmgr/zhttppacket.rs000066400000000000000000001716201504671364300204440ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::arena; use crate::core::tnetstring; use arrayvec::ArrayVec; use std::cell::RefCell; use std::io; use std::mem; use std::str; use thiserror::Error; pub const IDS_MAX: usize = 128; const HEADERS_MAX: usize = 64; const EMPTY_BYTES: &[u8] = b""; const EMPTY_ID: Id = Id { id: EMPTY_BYTES, seq: None, }; pub const EMPTY_HEADER: Header = Header { name: "", value: EMPTY_BYTES, }; const EMPTY_HEADERS: [Header; 0] = [EMPTY_HEADER; 0]; #[derive(Debug, Error, PartialEq)] pub enum ParseError { #[error("unrecognized data format")] Unrecognized, #[error(transparent)] TnetParse(#[from] tnetstring::ParseError), #[error("{0} must be of type {1}")] WrongType(&'static str, tnetstring::FrameType), #[error("{0} must be of type map or string")] NotMapOrString(&'static str), #[error("{0} must be a utf-8 string")] NotUtf8(&'static str), #[error("{0} must not be negative")] NegativeInt(&'static str), #[error("too many ids")] TooManyIds, #[error("too many headers")] TooManyHeaders, #[error("header item must have size 2")] InvalidHeader, #[error("no id")] NoId, } trait ErrorContext { fn field(self, field: &'static str) -> Result; } impl ErrorContext for Result { fn field(self, field: &'static str) -> Result { match self { Ok(v) => Ok(v), Err(tnetstring::ParseError::WrongType(_, expected)) => { Err(ParseError::WrongType(field, expected)) } Err(e) => Err(e.into()), } } } impl ErrorContext for Result { fn field(self, field: &'static str) -> Result { match self { Ok(v) => Ok(v), Err(_) => Err(ParseError::NotUtf8(field)), } } } #[derive(Clone, Copy)] pub struct Id<'a> { pub id: &'a [u8], pub seq: Option, } pub struct Header<'a> { pub name: &'a str, pub value: &'a [u8], } #[derive(Debug, PartialEq)] pub enum ContentType { Text, Binary, } type IdsScratch<'a> = ArrayVec, IDS_MAX>; type HeadersScratch<'a> = ArrayVec, HEADERS_MAX>; pub struct ParseScratch<'a> { ids: IdsScratch<'a>, headers: HeadersScratch<'a>, } #[allow(clippy::new_without_default)] impl ParseScratch<'_> { pub fn new() -> Self { Self { ids: ArrayVec::new(), headers: ArrayVec::new(), } } } trait Serialize<'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error>; } trait Parse<'buf: 'scratch, 'scratch> { type Parsed; fn parse( root: tnetstring::MapIterator<'buf>, scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result; } struct CommonData<'buf, 'ids> { from: &'buf [u8], ids: &'ids [Id<'buf>], multi: bool, ptype_str: &'buf str, } impl<'buf: 'ids, 'ids> CommonData<'buf, 'ids> { fn serialize(&self, w: &mut tnetstring::Writer<'buf, '_>) -> Result<(), io::Error> { if !self.from.is_empty() { w.write_string(b"from")?; w.write_string(self.from)?; } #[allow(clippy::comparison_chain)] if self.ids.len() == 1 { w.write_string(b"id")?; w.write_string(self.ids[0].id)?; if let Some(seq) = self.ids[0].seq { w.write_string(b"seq")?; w.write_int(seq as isize)?; } } else if self.ids.len() > 1 { w.write_string(b"id")?; w.start_array()?; for id in self.ids.iter() { w.start_map()?; w.write_string(b"id")?; w.write_string(id.id)?; if let Some(seq) = id.seq { w.write_string(b"seq")?; w.write_int(seq as isize)?; } w.end_map()?; } w.end_array()?; } if self.multi { w.write_string(b"ext")?; w.start_map()?; w.write_string(b"multi")?; w.write_bool(true)?; w.end_map()?; } if !self.ptype_str.is_empty() { w.write_string(b"type")?; w.write_string(self.ptype_str.as_bytes())?; } Ok(()) } fn parse( root: tnetstring::MapIterator<'buf>, scratch: &'ids 
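// Wire-format note for the parse below: the "id" field is either a single
// string (optionally paired with a sibling "seq" integer) or an array of
// {id, seq} maps for multi-id packets; both forms are collected into the
// caller's IdsScratch, capped at IDS_MAX entries (TooManyIds beyond that).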
mut IdsScratch<'buf>, ) -> Result { let mut from = EMPTY_BYTES; let mut multi = false; let mut ptype_str = ""; for e in root { let e = e?; match e.key { "from" => { from = tnetstring::parse_string(e.data).field("from")?; } "id" => { match e.ftype { tnetstring::FrameType::Array => { for idm in tnetstring::parse_array(e.data)? { let idm = idm?; if scratch.remaining_capacity() == 0 { return Err(ParseError::TooManyIds); } let mut id = EMPTY_BYTES; let mut seq = None; for m in tnetstring::parse_map(idm.data)? { let m = m?; match m.key { "id" => { let s = tnetstring::parse_string(m.data).field("id")?; id = s; } "seq" => { let x = tnetstring::parse_int(m.data).field("seq")?; if x < 0 { return Err(ParseError::NegativeInt("seq")); } seq = Some(x as u32); } _ => {} // skip unknown fields } } scratch.push(Id { id, seq }); } } tnetstring::FrameType::String => { let s = tnetstring::parse_string(e.data)?; if scratch.is_empty() { scratch.push(EMPTY_ID); } scratch[0].id = s; } _ => { return Err(ParseError::NotMapOrString("id")); } } } "seq" => { let x = tnetstring::parse_int(e.data).field("seq")?; if x < 0 { return Err(ParseError::NegativeInt("seq")); } if scratch.is_empty() { scratch.push(EMPTY_ID); } scratch[0].seq = Some(x as u32); } "type" => { let s = tnetstring::parse_string(e.data).field("type")?; let s = str::from_utf8(s).field("type")?; ptype_str = s; } "ext" => { let ext = tnetstring::parse_map(e.data).field("ext")?; for m in ext { let m = m?; if m.key == "multi" { let b = tnetstring::parse_bool(m.data).field("multi")?; multi = b; } } } _ => {} // skip unknown fields } } Ok(Self { from, ids: scratch.as_slice(), multi, ptype_str, }) } } pub struct RequestData<'buf, 'headers> { pub credits: u32, pub more: bool, pub stream: bool, pub router_resp: bool, pub max_size: u32, pub timeout: u32, pub method: &'buf str, pub uri: &'buf str, pub headers: &'headers [Header<'buf>], pub content_type: Option, // websocket pub body: &'buf [u8], pub peer_address: &'buf str, pub peer_port: u16, pub connect_host: &'buf str, pub connect_port: u16, pub ignore_policies: bool, pub trust_connect_host: bool, pub ignore_tls_errors: bool, pub follow_redirects: bool, } #[allow(clippy::new_without_default)] impl RequestData<'_, '_> { pub fn new() -> Self { Self { credits: 0, more: false, stream: false, router_resp: false, max_size: 0, timeout: 0, method: "", uri: "", headers: &EMPTY_HEADERS, body: EMPTY_BYTES, content_type: None, peer_address: "", peer_port: 0, connect_host: "", connect_port: 0, ignore_policies: false, trust_connect_host: false, ignore_tls_errors: false, follow_redirects: false, } } } impl<'a> Serialize<'a> for RequestData<'a, 'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { if !self.method.is_empty() { w.write_string(b"method")?; w.write_string(self.method.as_bytes())?; } if !self.uri.is_empty() { w.write_string(b"uri")?; w.write_string(self.uri.as_bytes())?; } if !self.headers.is_empty() { w.write_string(b"headers")?; w.start_array()?; for h in self.headers.iter() { w.start_array()?; w.write_string(h.name.as_bytes())?; w.write_string(h.value)?; w.end_array()?; } w.end_array()?; } if let Some(ctype) = &self.content_type { w.write_string(b"content-type")?; let s: &[u8] = match ctype { ContentType::Text => b"text", ContentType::Binary => b"binary", }; w.write_string(s)?; } if !self.body.is_empty() { w.write_string(b"body")?; w.write_string(self.body)?; } if self.credits > 0 { w.write_string(b"credits")?; w.write_int(self.credits as isize)?; } if self.more { 
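// For reference, Request::serialize() wraps these fields in a 'T'-prefixed
// tnetstring map; a minimal data packet comes out roughly as (elided form of
// the vector in test_req_serialize below):
//
//     T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,...4:more,4:true!}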
w.write_string(b"more")?; w.write_bool(true)?; } if self.stream { w.write_string(b"stream")?; w.write_bool(true)?; } if self.router_resp { w.write_string(b"router-resp")?; w.write_bool(true)?; } if self.max_size > 0 { w.write_string(b"max-size")?; w.write_int(self.max_size as isize)?; } if self.timeout > 0 { w.write_string(b"timeout")?; w.write_int(self.timeout as isize)?; } if !self.peer_address.is_empty() { w.write_string(b"peer-address")?; w.write_string(self.peer_address.as_bytes())?; w.write_string(b"peer-port")?; w.write_int(self.peer_port as isize)?; } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for RequestData<'buf, 'scratch> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let mut credits = 0; let mut more = false; let mut stream = false; let mut router_resp = false; let mut max_size = 0; let mut timeout = 0; let mut method = ""; let mut uri = ""; let mut content_type = None; let mut body = EMPTY_BYTES; let mut peer_address = ""; let mut peer_port = 0; let mut connect_host = ""; let mut connect_port = 0; let mut ignore_policies = false; let mut trust_connect_host = false; let mut ignore_tls_errors = false; let mut follow_redirects = false; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } "more" => { let b = tnetstring::parse_bool(e.data).field("more")?; more = b; } "stream" => { let b = tnetstring::parse_bool(e.data).field("stream")?; stream = b; } "router-resp" => { let b = tnetstring::parse_bool(e.data).field("router-resp")?; router_resp = b; } "max-size" => { let x = tnetstring::parse_int(e.data).field("max-size")?; if x < 0 { return Err(ParseError::NegativeInt("max-size")); } max_size = x as u32; } "timeout" => { let x = tnetstring::parse_int(e.data).field("timeout")?; if x < 0 { return Err(ParseError::NegativeInt("timeout")); } timeout = x as u32; } "method" => { let s = tnetstring::parse_string(e.data).field("method")?; let s = str::from_utf8(s).field("method")?; method = s; } "uri" => { let s = tnetstring::parse_string(e.data).field("uri")?; let s = str::from_utf8(s).field("uri")?; uri = s; } "headers" => { let headers = tnetstring::parse_array(e.data).field("headers")?; for ha in headers { let ha = ha?; if scratch.remaining_capacity() == 0 { return Err(ParseError::TooManyHeaders); } let mut hi = tnetstring::parse_array(ha.data).field("header item")?; let name = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let name = tnetstring::parse_string(name.data).field("header name")?; let name = str::from_utf8(name).field("header name")?; let value = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let value = tnetstring::parse_string(value.data).field("header value")?; scratch.push(Header { name, value }); } } "content-type" => { let s = tnetstring::parse_string(e.data).field("content-type")?; content_type = Some(match s { b"binary" => ContentType::Binary, _ => ContentType::Text, }); } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } "peer-address" => { let s = tnetstring::parse_string(e.data).field("peer-address")?; let s = str::from_utf8(s).field("peer-address")?; peer_address = s; } "peer-port" => { let x = 
tnetstring::parse_int(e.data).field("peer-port")?; if x < 0 { return Err(ParseError::NegativeInt("peer-port")); } peer_port = x as u16; } "connect-host" => { let s = tnetstring::parse_string(e.data).field("connect-host")?; let s = str::from_utf8(s).field("connect-host")?; connect_host = s; } "connect-port" => { let x = tnetstring::parse_int(e.data).field("connect-port")?; if x < 0 { return Err(ParseError::NegativeInt("connect-port")); } connect_port = x as u16; } "ignore-policies" => { let b = tnetstring::parse_bool(e.data).field("ignore-policies")?; ignore_policies = b; } "trust-connect-host" => { let b = tnetstring::parse_bool(e.data).field("trust-connect-host")?; trust_connect_host = b; } "ignore-tls-errors" => { let b = tnetstring::parse_bool(e.data).field("ignore-tls-errors")?; ignore_tls_errors = b; } "follow-redirects" => { let b = tnetstring::parse_bool(e.data).field("follow-redirects")?; follow_redirects = b; } _ => {} // skip unknown fields } } Ok(Self { credits, more, stream, router_resp, max_size, timeout, method, uri, headers: scratch.as_slice(), content_type, body, peer_address, peer_port, connect_host, connect_port, ignore_policies, trust_connect_host, ignore_tls_errors, follow_redirects, }) } } pub struct ResponseData<'buf, 'headers> { pub credits: u32, pub more: bool, pub code: u16, pub reason: &'buf str, pub headers: &'headers [Header<'buf>], pub content_type: Option, // websocket pub body: &'buf [u8], } #[allow(clippy::new_without_default)] impl ResponseData<'_, '_> { pub fn new() -> Self { Self { credits: 0, more: false, code: 0, reason: "", headers: &EMPTY_HEADERS, content_type: None, body: EMPTY_BYTES, } } } impl<'a> Serialize<'a> for ResponseData<'a, 'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { if self.code > 0 { w.write_string(b"code")?; w.write_int(self.code as isize)?; } if !self.reason.is_empty() { w.write_string(b"reason")?; w.write_string(self.reason.as_bytes())?; } if !self.headers.is_empty() { w.write_string(b"headers")?; w.start_array()?; for h in self.headers.iter() { w.start_array()?; w.write_string(h.name.as_bytes())?; w.write_string(h.value)?; w.end_array()?; } w.end_array()?; } if let Some(ctype) = &self.content_type { w.write_string(b"content-type")?; let s: &[u8] = match ctype { ContentType::Text => b"text", ContentType::Binary => b"binary", }; w.write_string(s)?; } if !self.body.is_empty() { w.write_string(b"body")?; w.write_string(self.body)?; } if self.credits > 0 { w.write_string(b"credits")?; w.write_int(self.credits as isize)?; } if self.more { w.write_string(b"more")?; w.write_bool(true)?; } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for ResponseData<'buf, 'scratch> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let mut credits = 0; let mut more = false; let mut code = 0; let mut reason = ""; let mut content_type = None; let mut body = EMPTY_BYTES; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } "more" => { let b = tnetstring::parse_bool(e.data).field("more")?; more = b; } "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = x as u16; } "reason" => { let s = tnetstring::parse_string(e.data).field("reason")?; let s = str::from_utf8(s).field("reason")?; reason = s; } "headers" => 
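// Headers are parsed from a tnetstring array of two-element [name, value]
// arrays, e.g. "7:headers,34:30:12:Content-Type,10:text/plain,]]" in the
// tests below; each entry is pushed into the shared scratch space, limited
// to HEADERS_MAX items (TooManyHeaders beyond that).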
{ let headers = tnetstring::parse_array(e.data).field("headers")?; for ha in headers { let ha = ha?; if scratch.remaining_capacity() == 0 { return Err(ParseError::TooManyHeaders); } let mut hi = tnetstring::parse_array(ha.data).field("header item")?; let name = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let name = tnetstring::parse_string(name.data).field("header name")?; let name = str::from_utf8(name).field("header name")?; let value = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let value = tnetstring::parse_string(value.data).field("header value")?; scratch.push(Header { name, value }); } } "content-type" => { let s = tnetstring::parse_string(e.data).field("content-type")?; content_type = Some(match s { b"binary" => ContentType::Binary, _ => ContentType::Text, }); } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } Ok(Self { credits, more, code, reason, headers: scratch.as_slice(), content_type, body, }) } } pub struct RequestErrorData<'a> { pub condition: &'a str, } impl<'a> Serialize<'a> for RequestErrorData<'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { w.write_string(b"condition")?; w.write_string(self.condition.as_bytes())?; Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for RequestErrorData<'buf> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, _scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let mut condition = ""; for e in root { let e = e?; if e.key == "condition" { let s = tnetstring::parse_string(e.data).field("condition")?; let s = str::from_utf8(s).field("condition")?; condition = s; } } Ok(Self { condition }) } } pub struct RejectedInfo<'buf, 'headers> { pub code: u16, pub reason: &'buf str, pub headers: &'headers [Header<'buf>], pub body: &'buf [u8], } pub struct ResponseErrorData<'buf, 'headers> { pub condition: &'buf str, pub rejected_info: Option>, // rejected (websocket) } impl<'a> Serialize<'a> for ResponseErrorData<'a, 'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { w.write_string(b"condition")?; w.write_string(self.condition.as_bytes())?; if let Some(ri) = &self.rejected_info { w.write_string(b"code")?; w.write_int(ri.code as isize)?; w.write_string(b"reason")?; w.write_string(ri.reason.as_bytes())?; if !ri.headers.is_empty() { w.write_string(b"headers")?; w.start_array()?; for h in ri.headers.iter() { w.start_array()?; w.write_string(h.name.as_bytes())?; w.write_string(h.value)?; w.end_array()?; } w.end_array()?; } w.write_string(b"body")?; w.write_string(ri.body)?; } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for ResponseErrorData<'buf, 'scratch> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let mut condition = ""; let mut code = 0; let mut reason = ""; let mut body = EMPTY_BYTES; for e in root { let e = e?; match e.key { "condition" => { let s = tnetstring::parse_string(e.data).field("condition")?; let s = str::from_utf8(s).field("condition")?; condition = s; } "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = x as u16; } "reason" => { let s = tnetstring::parse_string(e.data).field("reason")?; let s = 
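// When the condition is "rejected" (a rejected websocket handshake), the
// accompanying code/reason/headers/body carry the HTTP response details and
// are surfaced through rejected_info; for any other condition those fields
// are not exposed.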
str::from_utf8(s).field("reason")?; reason = s; } "headers" => { let headers = tnetstring::parse_array(e.data).field("headers")?; for ha in headers { let ha = ha?; if scratch.remaining_capacity() == 0 { return Err(ParseError::TooManyHeaders); } let mut hi = tnetstring::parse_array(ha.data).field("header item")?; let name = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let name = tnetstring::parse_string(name.data).field("header name")?; let name = str::from_utf8(name).field("header name")?; let value = match hi.next() { Some(Ok(name)) => name, Some(Err(e)) => { return Err(e.into()); } None => { return Err(ParseError::InvalidHeader); } }; let value = tnetstring::parse_string(value.data).field("header value")?; scratch.push(Header { name, value }); } } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } let rejected_info = if condition == "rejected" { Some(RejectedInfo { code, reason, headers: scratch.as_slice(), body, }) } else { None }; Ok(Self { condition, rejected_info, }) } } pub struct CreditData { pub credits: u32, } impl Serialize<'_> for CreditData { fn serialize(&self, w: &mut tnetstring::Writer) -> Result<(), io::Error> { w.write_string(b"credits")?; w.write_int(self.credits as isize)?; Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for CreditData { type Parsed = Self; fn parse( root: tnetstring::MapIterator, _scratch: &mut HeadersScratch, ) -> Result { let mut credits = 0; for e in root { let e = e?; if e.key == "credits" { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } } Ok(Self { credits }) } } pub struct CloseData<'a> { // code, reason pub status: Option<(u16, &'a str)>, } impl<'a> Serialize<'a> for CloseData<'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { if let Some(status) = self.status { w.write_string(b"code")?; w.write_int(status.0 as isize)?; if !status.1.is_empty() { w.write_string(b"body")?; w.write_string(status.1.as_bytes())?; } } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for CloseData<'buf> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, _scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let mut code = None; let mut reason = ""; for e in root { let e = e?; match e.key { "code" => { let x = tnetstring::parse_int(e.data).field("code")?; if x < 0 { return Err(ParseError::NegativeInt("code")); } code = Some(x as u16); } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; let s = str::from_utf8(s).field("condition")?; reason = s; } _ => {} // skip unknown fields } } if let Some(code) = code { Ok(Self { status: Some((code, reason)), }) } else { Ok(Self { status: None }) } } } fn parse_ping_or_pong(root: tnetstring::MapIterator<'_>) -> Result<(u32, &[u8]), ParseError> { let mut credits = 0; let mut body = EMPTY_BYTES; for e in root { let e = e?; match e.key { "credits" => { let x = tnetstring::parse_int(e.data).field("credits")?; if x < 0 { return Err(ParseError::NegativeInt("credits")); } credits = x as u32; } "body" => { let s = tnetstring::parse_string(e.data).field("body")?; body = s; } _ => {} // skip unknown fields } } Ok((credits, body)) } pub struct PingData<'a> { pub credits: u32, pub body: &'a [u8], } impl<'a> Serialize<'a> for PingData<'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, 
'_>) -> Result<(), io::Error> { if !self.body.is_empty() { w.write_string(b"body")?; w.write_string(self.body)?; } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for PingData<'buf> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, _scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let (credits, body) = parse_ping_or_pong(root)?; Ok(Self { credits, body }) } } pub struct PongData<'a> { pub credits: u32, pub body: &'a [u8], } impl<'a> Serialize<'a> for PongData<'a> { fn serialize(&self, w: &mut tnetstring::Writer<'a, '_>) -> Result<(), io::Error> { if !self.body.is_empty() { w.write_string(b"body")?; w.write_string(self.body)?; } Ok(()) } } impl<'buf: 'scratch, 'scratch> Parse<'buf, 'scratch> for PongData<'buf> { type Parsed = Self; fn parse( root: tnetstring::MapIterator<'buf>, _scratch: &'scratch mut HeadersScratch<'buf>, ) -> Result { let (credits, body) = parse_ping_or_pong(root)?; Ok(Self { credits, body }) } } pub enum RequestPacket<'buf, 'headers> { Unknown, Data(RequestData<'buf, 'headers>), Error(RequestErrorData<'buf>), Credit(CreditData), KeepAlive, Cancel, HandoffStart, HandoffProceed, Close(CloseData<'buf>), Ping(PingData<'buf>), Pong(PongData<'buf>), } pub enum ResponsePacket<'buf, 'headers> { Unknown, Data(ResponseData<'buf, 'headers>), Error(ResponseErrorData<'buf, 'headers>), Credit(CreditData), KeepAlive, Cancel, HandoffStart, HandoffProceed, Close(CloseData<'buf>), Ping(PingData<'buf>), Pong(PongData<'buf>), } pub fn parse_ids<'buf, 'scratch>( src: &'buf [u8], scratch: &'scratch mut ParseScratch<'buf>, ) -> Result<(&'buf [u8], &'scratch [Id<'buf>]), ParseError> { if src.is_empty() || src[0] != b'T' { return Err(ParseError::Unrecognized); } let root = tnetstring::parse_map(&src[1..]).field("root")?; let mut from = EMPTY_BYTES; for e in root { let e = e?; match e.key { "from" => { from = tnetstring::parse_string(e.data).field("from")?; } "id" => match e.ftype { tnetstring::FrameType::Array => { for idm in tnetstring::parse_array(e.data)? { let idm = idm?; if scratch.ids.remaining_capacity() == 0 { return Err(ParseError::TooManyIds); } let mut id = EMPTY_BYTES; for m in tnetstring::parse_map(idm.data)? 
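// parse_ids() extracts only the sender address and the id list from a
// 'T'-prefixed packet, without parsing the remaining fields. Minimal usage
// sketch (packet_bytes is assumed input):
//
//     let mut scratch = ParseScratch::new();
//     let (from, ids) = parse_ids(packet_bytes, &mut scratch)?;
//     for id in ids {
//         // route based on id.id
//     }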
{ let m = m?; if m.key == "id" { let s = tnetstring::parse_string(m.data).field("id")?; id = s; } } scratch.ids.push(Id { id, seq: None }); } } tnetstring::FrameType::String => { let s = tnetstring::parse_string(e.data)?; scratch.ids.push(Id { id: s, seq: None }); } _ => { return Err(ParseError::NotMapOrString("id")); } }, _ => {} // skip other fields } } Ok((from, scratch.ids.as_slice())) } pub trait PacketParse<'buf: 'scratch, 'scratch> { type Parsed; fn parse( src: &'buf [u8], scratch: &'scratch mut ParseScratch<'buf>, ) -> Result; } pub struct Request<'buf, 'ids, 'headers> { pub from: &'buf [u8], pub ids: &'ids [Id<'buf>], pub multi: bool, pub ptype: RequestPacket<'buf, 'headers>, pub ptype_str: &'buf str, } impl<'buf, 'ids, 'headers> Request<'buf, 'ids, 'headers> { pub fn new_data( from: &'buf [u8], ids: &'ids [Id<'buf>], data: RequestData<'buf, 'headers>, ) -> Self { Self::new(from, ids, RequestPacket::Data(data)) } pub fn new_error(from: &'buf [u8], ids: &'ids [Id<'buf>], condition: &'buf str) -> Self { Self::new( from, ids, RequestPacket::Error(RequestErrorData { condition }), ) } pub fn new_credit(from: &'buf [u8], ids: &'ids [Id<'buf>], credits: u32) -> Self { Self::new(from, ids, RequestPacket::Credit(CreditData { credits })) } pub fn new_keep_alive(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, RequestPacket::KeepAlive) } pub fn new_cancel(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, RequestPacket::Cancel) } pub fn new_handoff_start(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, RequestPacket::HandoffStart) } pub fn new_handoff_proceed(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, RequestPacket::HandoffProceed) } pub fn new_close( from: &'buf [u8], ids: &'ids [Id<'buf>], status: Option<(u16, &'buf str)>, ) -> Self { Self::new(from, ids, RequestPacket::Close(CloseData { status })) } pub fn new_ping(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self::new( from, ids, RequestPacket::Ping(PingData { credits: 0, body }), ) } pub fn new_pong(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self::new( from, ids, RequestPacket::Pong(PongData { credits: 0, body }), ) } pub fn serialize(&self, dest: &mut [u8]) -> Result { if dest.is_empty() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } dest[0] = b'T'; let mut cursor = io::Cursor::new(&mut dest[1..]); let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; let common = CommonData { from: self.from, ids: self.ids, multi: self.multi, ptype_str: match &self.ptype { RequestPacket::Data(_) => "", RequestPacket::Error(_) => "error", RequestPacket::Credit(_) => "credit", RequestPacket::KeepAlive => "keep-alive", RequestPacket::Cancel => "cancel", RequestPacket::HandoffStart => "handoff-start", RequestPacket::HandoffProceed => "handoff-proceed", RequestPacket::Close(_) => "close", RequestPacket::Ping(_) => "ping", RequestPacket::Pong(_) => "pong", RequestPacket::Unknown => panic!("invalid packet type"), }, }; common.serialize(&mut w)?; match &self.ptype { RequestPacket::Data(data) => data.serialize(&mut w)?, RequestPacket::Error(data) => data.serialize(&mut w)?, RequestPacket::Credit(data) => data.serialize(&mut w)?, RequestPacket::Close(data) => data.serialize(&mut w)?, RequestPacket::Ping(data) => data.serialize(&mut w)?, RequestPacket::Pong(data) => data.serialize(&mut w)?, _ => {} } w.end_map()?; w.flush()?; Ok((cursor.position() as usize) + 1) } fn new(from: &'buf [u8], ids: 
&'ids [Id<'buf>], ptype: RequestPacket<'buf, 'headers>) -> Self { Self { from, ids, multi: false, ptype, ptype_str: "", } } } impl<'buf: 'scratch, 'scratch> PacketParse<'buf, 'scratch> for Request<'buf, 'scratch, 'scratch> { type Parsed = Self; fn parse( src: &'buf [u8], scratch: &'scratch mut ParseScratch<'buf>, ) -> Result { if src.is_empty() || src[0] != b'T' { return Err(ParseError::Unrecognized); } let root = tnetstring::parse_map(&src[1..]).field("root")?; let CommonData { from, ids, multi, ptype_str, } = CommonData::parse(root, &mut scratch.ids)?; let ptype = match ptype_str { // data "" => RequestPacket::Data(RequestData::parse(root, &mut scratch.headers)?), "error" => RequestPacket::Error(RequestErrorData::parse(root, &mut scratch.headers)?), "credit" => RequestPacket::Credit(CreditData::parse(root, &mut scratch.headers)?), "keep-alive" => RequestPacket::KeepAlive, "cancel" => RequestPacket::Cancel, "handoff-start" => RequestPacket::HandoffStart, "handoff-proceed" => RequestPacket::HandoffProceed, "close" => RequestPacket::Close(CloseData::parse(root, &mut scratch.headers)?), "ping" => RequestPacket::Ping(PingData::parse(root, &mut scratch.headers)?), "pong" => RequestPacket::Pong(PongData::parse(root, &mut scratch.headers)?), _ => RequestPacket::Unknown, }; Ok(Self { from, ids, multi, ptype, ptype_str, }) } } pub struct Response<'buf, 'ids, 'headers> { pub from: &'buf [u8], pub ids: &'ids [Id<'buf>], pub multi: bool, pub ptype: ResponsePacket<'buf, 'headers>, pub ptype_str: &'buf str, } impl<'buf, 'ids, 'headers> Response<'buf, 'ids, 'headers> { pub fn new_data( from: &'buf [u8], ids: &'ids [Id<'buf>], data: ResponseData<'buf, 'headers>, ) -> Self { Self::new(from, ids, ResponsePacket::Data(data)) } pub fn new_error( from: &'buf [u8], ids: &'ids [Id<'buf>], edata: ResponseErrorData<'buf, 'headers>, ) -> Self { Self::new(from, ids, ResponsePacket::Error(edata)) } pub fn new_credit(from: &'buf [u8], ids: &'ids [Id<'buf>], credits: u32) -> Self { Self::new(from, ids, ResponsePacket::Credit(CreditData { credits })) } pub fn new_keep_alive(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, ResponsePacket::KeepAlive) } pub fn new_cancel(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, ResponsePacket::Cancel) } pub fn new_handoff_proceed(from: &'buf [u8], ids: &'ids [Id<'buf>]) -> Self { Self::new(from, ids, ResponsePacket::HandoffProceed) } pub fn new_close( from: &'buf [u8], ids: &'ids [Id<'buf>], status: Option<(u16, &'buf str)>, ) -> Self { Self::new(from, ids, ResponsePacket::Close(CloseData { status })) } pub fn new_ping(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self::new( from, ids, ResponsePacket::Ping(PingData { credits: 0, body }), ) } pub fn new_pong(from: &'buf [u8], ids: &'ids [Id<'buf>], body: &'buf [u8]) -> Self { Self::new( from, ids, ResponsePacket::Pong(PongData { credits: 0, body }), ) } pub fn serialize(&self, dest: &mut [u8]) -> Result { if dest.is_empty() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } dest[0] = b'T'; let mut cursor = io::Cursor::new(&mut dest[1..]); let mut w = tnetstring::Writer::new(&mut cursor); w.start_map()?; let common = CommonData { from: self.from, ids: self.ids, multi: self.multi, ptype_str: match &self.ptype { ResponsePacket::Data(_) => "", ResponsePacket::Error(_) => "error", ResponsePacket::Credit(_) => "credit", ResponsePacket::KeepAlive => "keep-alive", ResponsePacket::Cancel => "cancel", ResponsePacket::HandoffStart => "handoff-start", 
ResponsePacket::HandoffProceed => "handoff-proceed", ResponsePacket::Close(_) => "close", ResponsePacket::Ping(_) => "ping", ResponsePacket::Pong(_) => "pong", ResponsePacket::Unknown => panic!("invalid packet type"), }, }; common.serialize(&mut w)?; match &self.ptype { ResponsePacket::Data(data) => data.serialize(&mut w)?, ResponsePacket::Error(data) => data.serialize(&mut w)?, ResponsePacket::Credit(data) => data.serialize(&mut w)?, ResponsePacket::Close(data) => data.serialize(&mut w)?, ResponsePacket::Ping(data) => data.serialize(&mut w)?, ResponsePacket::Pong(data) => data.serialize(&mut w)?, _ => {} } w.end_map()?; w.flush()?; Ok((cursor.position() as usize) + 1) } fn new(from: &'buf [u8], ids: &'ids [Id<'buf>], ptype: ResponsePacket<'buf, 'headers>) -> Self { Self { from, ids, multi: false, ptype, ptype_str: "", } } } impl<'buf: 'scratch, 'scratch> PacketParse<'buf, 'scratch> for Response<'buf, 'scratch, 'scratch> { type Parsed = Self; fn parse( src: &'buf [u8], scratch: &'scratch mut ParseScratch<'buf>, ) -> Result { if src.is_empty() || src[0] != b'T' { return Err(ParseError::Unrecognized); } let root = tnetstring::parse_map(&src[1..]).field("root")?; let CommonData { from, ids, multi, ptype_str, } = CommonData::parse(root, &mut scratch.ids)?; let ptype = match ptype_str { // data "" => ResponsePacket::Data(ResponseData::parse(root, &mut scratch.headers)?), "error" => ResponsePacket::Error(ResponseErrorData::parse(root, &mut scratch.headers)?), "credit" => ResponsePacket::Credit(CreditData::parse(root, &mut scratch.headers)?), "keep-alive" => ResponsePacket::KeepAlive, "cancel" => ResponsePacket::Cancel, "handoff-start" => ResponsePacket::HandoffStart, "handoff-proceed" => ResponsePacket::HandoffProceed, "close" => ResponsePacket::Close(CloseData::parse(root, &mut scratch.headers)?), "ping" => ResponsePacket::Ping(PingData::parse(root, &mut scratch.headers)?), "pong" => ResponsePacket::Pong(PongData::parse(root, &mut scratch.headers)?), _ => ResponsePacket::Unknown, }; Ok(Self { from, ids, multi, ptype, ptype_str, }) } } pub struct OwnedPacket { inner: T, _scratch: arena::Rc>>, _src: arena::Arc, } impl OwnedPacket where T: PacketParse<'static, 'static, Parsed = T>, { pub fn parse( src: arena::Arc, offset: usize, scratch: arena::Rc>>, ) -> Result { let src_ref: &[u8] = &src.get()[offset..]; // SAFETY: Self will take ownership of src, and the bytes referred to // by src_ref are on the heap, and src will not be modified or // dropped until Self is dropped, so the bytes referred to by src_ref // will remain valid for the lifetime of Self let src_ref: &'static [u8] = unsafe { mem::transmute(src_ref) }; // SAFETY: Self will take ownership of scratch, and the location // referred to by scratch_mut is in an arena, and scratch will not // be dropped until Self is dropped, so the location referred to by // scratch_mut will remain valid for the lifetime of Self // // further, it is safe for T::parse() to write references to src_ref // into scratch_mut, because src_ref and scratch_mut have the same // lifetime let scratch_mut: &'static mut ParseScratch<'static> = unsafe { scratch.get().as_ptr().as_mut().unwrap() }; let inner = T::parse(src_ref, scratch_mut)?; Ok(Self { inner, _scratch: scratch, _src: src, }) } } pub type OwnedRequest = OwnedPacket>; impl OwnedRequest { // the lifetimes are needed #[allow(clippy::needless_lifetimes)] pub fn get<'a>(&'a self) -> &'a Request<'a, 'a, 'a> { let req: &Request = &self.inner; // SAFETY: here we reduce the inner lifetimes from 'static to that of // 
the owning object, which is fine unsafe { mem::transmute(req) } } } pub type OwnedResponse = OwnedPacket>; impl OwnedResponse { // the lifetimes are needed #[allow(clippy::needless_lifetimes)] pub fn get<'a>(&'a self) -> &'a Response<'a, 'a, 'a> { let resp: &Response = &self.inner; // SAFETY: here we reduce the inner lifetimes from 'static to that of // the owning object, which is fine unsafe { mem::transmute(resp) } } } #[cfg(test)] mod tests { use super::*; use std::rc::Rc; use std::sync::Arc; #[test] fn test_req_serialize() { struct Test { name: &'static str, req: Request<'static, 'static, 'static>, expected: &'static str, } // data, error, credit, keepalive, cancel, handoffstart/proceed, close, ping, pong let tests = [ Test { name: "data", req: Request { from: b"client", ids: &[Id { id: b"1", seq: Some(0), }], multi: false, ptype: RequestPacket::Data(RequestData { credits: 0, more: true, stream: false, router_resp: false, max_size: 0, timeout: 0, method: "POST", uri: "http://example.com/path", headers: &[Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello", peer_address: "", peer_port: 0, connect_host: "", connect_port: 0, ignore_policies: false, trust_connect_host: false, ignore_tls_errors: false, follow_redirects: false, }), ptype_str: "", }, expected: concat!( "T161:4:from,6:client,2:id,1:1,3:seq,1:0#6:method,4:POST,3:uri", ",23:http://example.com/path,7:headers,34:30:12:Content-Type,1", "0:text/plain,]]4:body,5:hello,4:more,4:true!}", ), }, Test { name: "error", req: Request { from: b"client", ids: &[Id { id: b"1", seq: Some(0), }], multi: false, ptype: RequestPacket::Error(RequestErrorData { condition: "bad-request", }), ptype_str: "", }, expected: concat!( "T77:4:from,6:client,2:id,1:1,3:seq,1:0#4:type,5:error,9:condi", "tion,11:bad-request,}", ), }, ]; for test in tests.iter() { let mut data = [0; 1024]; let size = test.req.serialize(&mut data).unwrap(); assert_eq!( str::from_utf8(&data[..size]).unwrap(), test.expected, "test={}", test.name ); } } #[test] fn test_resp_serialize() { struct Test { name: &'static str, resp: Response<'static, 'static, 'static>, expected: &'static str, } // data, error, credit, keepalive, cancel, handoffstart/proceed, close, ping, pong let tests = [ Test { name: "data", resp: Response { from: b"server", ids: &[Id { id: b"1", seq: Some(0), }], multi: false, ptype: ResponsePacket::Data(ResponseData { credits: 0, more: true, code: 200, reason: "OK", headers: &[Header { name: "Content-Type", value: b"text/plain", }], content_type: None, body: b"hello", }), ptype_str: "", }, expected: concat!( "T139:4:from,6:server,2:id,1:1,3:seq,1:0#4:code,3:200#6:reason", ",2:OK,7:headers,34:30:12:Content-Type,10:text/plain,]]4:body,", "5:hello,4:more,4:true!}", ), }, Test { name: "error", resp: Response { from: b"server", ids: &[Id { id: b"1", seq: Some(0), }], multi: false, ptype: ResponsePacket::Error(ResponseErrorData { condition: "bad-request", rejected_info: None, }), ptype_str: "", }, expected: concat!( "T77:4:from,6:server,2:id,1:1,3:seq,1:0#4:type,5:error,9:condi", "tion,11:bad-request,}", ), }, ]; for test in tests.iter() { let mut data = [0; 1024]; let size = test.resp.serialize(&mut data).unwrap(); assert_eq!( str::from_utf8(&data[..size]).unwrap(), test.expected, "test={}", test.name ); } } #[test] fn test_req_parse() { let data = concat!( "T198:4:more,4:true!7:headers,34:30:12:Content-Type,10:text/pl", "ain,]]12:content-type,6:binary,4:from,6:client,2:id,1:1,6:met", 
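// Reading this test vector: tnetstring elements are "<len>:<payload><type>",
// where ',' marks a string, '#' an integer, '!' a boolean, ']' a list and
// '}' a map. So "4:more,4:true!" sets "more" to true, and
// "7:credits,3:100#" sets "credits" to 100.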
"hod,4:POST,3:uri,19:https://example.com,7:credits,3:100#3:seq", ",1:0#4:body,5:hello,}" ) .as_bytes(); let mut scratch = ParseScratch::new(); let req = Request::parse(&data, &mut scratch).unwrap(); assert_eq!(req.from, b"client"); assert_eq!(req.ids.len(), 1); assert_eq!(req.ids[0].id, b"1"); assert_eq!(req.ids[0].seq, Some(0)); let rdata = match req.ptype { RequestPacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.credits, 100); assert_eq!(rdata.more, true); assert_eq!(rdata.method, "POST"); assert_eq!(rdata.uri, "https://example.com"); assert_eq!(rdata.headers.len(), 1); assert_eq!(rdata.headers[0].name, "Content-Type"); assert_eq!(rdata.headers[0].value, b"text/plain"); assert_eq!(rdata.body, b"hello"); let ctype = rdata.content_type.unwrap(); assert_eq!(ctype, ContentType::Binary); } #[test] fn test_resp_parse() { let data = concat!( "T208:4:more,4:true!7:headers,34:30:12:Content-Type,10:text/pl", "ain,]]12:content-type,6:binary,4:from,6:server,2:id,1:1,6:rea", "son,2:OK,7:credits,3:100#9:user-data,12:3:foo,3:bar,}3:seq,1:", "0#4:code,3:200#4:body,5:hello,}" ) .as_bytes(); let mut scratch = ParseScratch::new(); let resp = Response::parse(&data, &mut scratch).unwrap(); assert_eq!(resp.from, b"server"); assert_eq!(resp.ids.len(), 1); assert_eq!(resp.ids[0].id, b"1"); assert_eq!(resp.ids[0].seq, Some(0)); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.credits, 100); assert_eq!(rdata.more, true); assert_eq!(rdata.code, 200); assert_eq!(rdata.reason, "OK"); assert_eq!(rdata.headers.len(), 1); assert_eq!(rdata.headers[0].name, "Content-Type"); assert_eq!(rdata.headers[0].value, b"text/plain"); assert_eq!(rdata.body, b"hello"); let ctype = rdata.content_type.unwrap(); assert_eq!(ctype, ContentType::Binary); } #[test] fn test_owned_req_parse() { let data = concat!( "T198:4:more,4:true!7:headers,34:30:12:Content-Type,10:text/pl", "ain,]]12:content-type,6:binary,4:from,6:client,2:id,1:1,6:met", "hod,4:POST,3:uri,19:https://example.com,7:credits,3:100#3:seq", ",1:0#4:body,5:hello,}" ) .as_bytes(); let msg_memory = Arc::new(arena::ArcMemory::new(1)); let scratch_memory = Rc::new(arena::RcMemory::new(1)); let msg = arena::Arc::new(zmq::Message::from(data), &msg_memory).unwrap(); let scratch = arena::Rc::new(RefCell::new(ParseScratch::new()), &scratch_memory).unwrap(); let req = OwnedRequest::parse(msg, 0, scratch).unwrap(); let req = req.get(); assert_eq!(req.from, b"client"); assert_eq!(req.ids.len(), 1); assert_eq!(req.ids[0].id, b"1"); assert_eq!(req.ids[0].seq, Some(0)); } #[test] fn test_owned_resp_parse() { let data = concat!( "addr T208:4:more,4:true!7:headers,34:30:12:Content-Type,10:te", "xt/plain,]]12:content-type,6:binary,4:from,6:server,2:id,1:1,", "6:reason,2:OK,7:credits,3:100#9:user-data,12:3:foo,3:bar,}3:s", "eq,1:0#4:code,3:200#4:body,5:hello,}" ) .as_bytes(); let msg_memory = Arc::new(arena::ArcMemory::new(1)); let scratch_memory = Rc::new(arena::RcMemory::new(1)); let msg = arena::Arc::new(zmq::Message::from(data), &msg_memory).unwrap(); let scratch = arena::Rc::new(RefCell::new(ParseScratch::new()), &scratch_memory).unwrap(); let resp = OwnedResponse::parse(msg, 5, scratch).unwrap(); let resp = resp.get(); assert_eq!(resp.from, b"server"); assert_eq!(resp.ids.len(), 1); assert_eq!(resp.ids[0].id, b"1"); assert_eq!(resp.ids[0].seq, Some(0)); } } pushpin-1.41.0/src/connmgr/zhttpsocket.rs000066400000000000000000003402361504671364300204660ustar00rootroot00000000000000/* * 
Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::connmgr::zhttppacket::{parse_ids, Id, ParseScratch}; use crate::core::arena; use crate::core::buffer::trim_for_display; use crate::core::channel::{ self, AsyncReceiver, AsyncSender, RecvFuture, WaitWritableFuture, REGISTRATIONS_PER_CHANNEL, }; use crate::core::event; use crate::core::executor::Executor; use crate::core::list; use crate::core::reactor::Reactor; use crate::core::select::{select_10, select_option, select_slice, Select10}; use crate::core::tnetstring; use crate::core::zmq::{ AsyncZmqSocket, MultipartHeader, SpecInfo, ZmqSendFuture, ZmqSendToFuture, ZmqSocket, REGISTRATIONS_PER_ZMQSOCKET, }; use arrayvec::{ArrayString, ArrayVec}; use log::{debug, error, log_enabled, trace, warn}; use slab::Slab; use std::cell::{Cell, RefCell}; use std::collections::HashMap; use std::convert::TryFrom; use std::fmt; use std::future::Future; use std::io; use std::marker; use std::pin::pin; use std::pin::Pin; use std::str; use std::str::FromStr; use std::sync::{mpsc, Arc, Mutex}; use std::task::{Context, Poll}; use std::thread; use std::time::Duration; pub const FROM_MAX: usize = 64; pub const REQ_ID_MAX: usize = 64; const HANDLES_MAX: usize = 1_024; const STREAM_OUT_STREAM_DELAY: Duration = Duration::from_millis(50); const LOG_METADATA_MAX: usize = 1_000; const LOG_CONTENT_MAX: usize = 1_000; const EXECUTOR_TASKS_MAX: usize = 1; struct Packet<'a> { map_frame: tnetstring::Frame<'a>, content_field: Option<&'a str>, } impl Packet<'_> { fn fmt_metadata(&self, f: &mut dyn io::Write) -> Result<(), io::Error> { let it = tnetstring::MapIterator::new(self.map_frame.data); write!(f, "{{ ")?; let mut first = true; for mi in it { let mi = match mi { Ok(mi) => mi, Err(_) => return Ok(()), }; if let Some(field) = self.content_field { if mi.key == field { continue; } } // can't fail let (frame, _) = tnetstring::parse_frame(mi.data).unwrap(); if !first { write!(f, ", ")?; } first = false; write!(f, "\"{}\": {}", mi.key, frame)?; } write!(f, " }}") } fn fmt_content(&self, f: &mut dyn io::Write) -> Result, io::Error> { let field = match self.content_field { Some(field) => field, None => return Ok(None), }; let it = tnetstring::MapIterator::new(self.map_frame.data); let mut ptype = &b""[..]; let mut condition = None; let mut content = None; for mi in it { let mi = match mi { Ok(mi) => mi, Err(_) => return Ok(None), }; match mi.key { "type" => { ptype = match tnetstring::parse_string(mi.data) { Ok(s) => s, Err(_) => return Ok(None), }; } "condition" => { condition = match tnetstring::parse_string(mi.data) { Ok(s) => Some(s), Err(_) => return Ok(None), }; } _ => {} } // can't fail let (frame, _) = tnetstring::parse_frame(mi.data).unwrap(); if mi.key == field { content = Some(frame); } } // only take content from data (ptype empty), close, or rejection packets if ptype.is_empty() || ptype == b"close" || (ptype == b"error" && condition == Some(b"rejected")) { if let Some(frame) = content { write!(f, "{}", 
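// Packet's Display impl (below) prints the metadata map and, for data/close/
// rejection packets, appends the content length and a trimmed body, producing
// log lines shaped roughly like (illustrative, exact rendering depends on the
// tnetstring frame formatting):
//
//     { "from": "client", "id": "1", "seq": 0 } 5 hello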
frame)?; return Ok(Some(frame.data.len())); } else { return Ok(Some(0)); } } Ok(None) } } impl fmt::Display for Packet<'_> { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { let mut meta = Vec::new(); if self.fmt_metadata(&mut meta).is_err() { return Err(fmt::Error); } // formatted output is guaranteed to be utf8 let meta = String::from_utf8(meta).unwrap(); let meta = trim_for_display(&meta, LOG_METADATA_MAX); if self.content_field.is_some() { let mut content = Vec::new(); let clen = match self.fmt_content(&mut content) { Ok(clen) => clen, Err(_) => return Err(fmt::Error), }; if let Some(clen) = clen { // formatted output is guaranteed to be utf8 let content = String::from_utf8(content).unwrap(); let content = trim_for_display(&content, LOG_CONTENT_MAX); return write!(f, "{} {} {}", meta, clen, content); } } write!(f, "{}", meta) } } fn packet_to_string(data: &[u8]) -> String { if data.is_empty() { return String::from(""); } if data[0] == b'T' { let (frame, _) = match tnetstring::parse_frame(&data[1..]) { Ok(frame) => frame, Err(e) => return format!("", e), }; if frame.ftype != tnetstring::FrameType::Map { return String::from(""); } let p = Packet { map_frame: frame, content_field: Some("body"), }; p.to_string() } else { // maybe it's addr-prefixed let mut pos = None; for (index, b) in data.iter().enumerate() { if *b == b' ' { pos = Some(index); break; } } if pos.is_none() { return String::from(""); } let pos = pos.unwrap(); let addr = match str::from_utf8(&data[..pos]) { Ok(addr) => addr, Err(e) => return format!("", e), }; let payload = &data[(pos + 1)..]; if payload.is_empty() { return String::from(""); } if payload[0] != b'T' { return String::from(""); } let payload = &data[(pos + 2)..]; let (frame, _) = match tnetstring::parse_frame(payload) { Ok(frame) => frame, Err(e) => return format!("", e), }; if frame.ftype != tnetstring::FrameType::Map { return String::from(""); } let p = Packet { map_frame: frame, content_field: Some("body"), }; format!("{} {}", addr, p) } } pub type SessionKey = (ArrayVec, ArrayVec); struct SessionItem { key: SessionKey, handle_index: usize, } enum SessionAddError { Full, Exists, } struct SessionDataInner { items: Slab, items_by_key: HashMap, } #[derive(Clone)] struct SessionData { inner: Arc>, } impl SessionData { fn new(capacity: usize) -> Self { Self { inner: Arc::new(Mutex::new(SessionDataInner { items: Slab::with_capacity(capacity), items_by_key: HashMap::with_capacity(capacity), })), } } fn add(&self, key: SessionKey, handle_index: usize) -> Result { let inner = &mut *self.inner.lock().unwrap(); if inner.items.len() == inner.items.capacity() { return Err(SessionAddError::Full); } if inner.items_by_key.contains_key(&key) { return Err(SessionAddError::Exists); } let item_key = inner.items.insert(SessionItem { key: key.clone(), handle_index, }); inner.items_by_key.insert(key, item_key); Ok(Session { data: self.clone(), item_key, }) } // returns handle_index fn get(&self, key: &SessionKey) -> Option { let inner = &*self.inner.lock().unwrap(); if let Some(item_key) = inner.items_by_key.get(key) { return Some(inner.items[*item_key].handle_index); } None } fn remove(&self, item_key: usize) { let inner = &mut *self.inner.lock().unwrap(); let item = &inner.items[item_key]; inner.items_by_key.remove(&item.key); inner.items.remove(item_key); } } pub struct Session { data: SessionData, item_key: usize, } impl Drop for Session { fn drop(&mut self) { self.data.remove(self.item_key); } } struct SessionTable { data: SessionData, } impl SessionTable { fn 
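// SessionTable keys sessions by (from, id) and stores the index of the owning
// handle; the Session guard returned by add() removes the entry when dropped.
// Rough usage sketch:
//
//     let table = SessionTable::new(100);
//     let key: SessionKey = (from.clone(), id.clone());
//     if let Ok(session) = table.add(key.clone(), handle_index) {
//         assert_eq!(table.get(&key), Some(handle_index));
//         drop(session); // entry is removed here
//     }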
new(capacity: usize) -> Self { Self { data: SessionData::new(capacity), } } fn add(&self, key: SessionKey, handle_index: usize) -> Result { self.data.add(key, handle_index) } fn get(&self, key: &SessionKey) -> Option { self.data.get(key) } } struct ClientReqSockets { sock: AsyncZmqSocket, } struct ClientStreamSockets { out: AsyncZmqSocket, out_stream: AsyncZmqSocket, in_: AsyncZmqSocket, } struct ReqPipeEnd { sender: channel::Sender>, receiver: channel::Receiver, } struct StreamPipeEnd { sender: channel::Sender<(arena::Arc, bool)>, receiver_any: channel::Receiver, receiver_addr: channel::Receiver<(ArrayVec, zmq::Message)>, } struct AsyncReqPipeEnd { sender: AsyncSender>, receiver: AsyncReceiver, } struct AsyncStreamPipeEnd { sender: AsyncSender<(arena::Arc, bool)>, receiver_any: AsyncReceiver, receiver_addr: AsyncReceiver<(ArrayVec, zmq::Message)>, } enum ControlRequest { Stop, SetClientReq(Vec), SetClientStream(Vec, Vec, Vec), AddClientReqHandle(ReqPipeEnd, ArrayString<8>), AddClientStreamHandle(StreamPipeEnd, ArrayString<8>), } struct ServerStreamSockets { in_: AsyncZmqSocket, in_stream: AsyncZmqSocket, out: AsyncZmqSocket, specs_applied: bool, } struct ServerReqPipeEnd { sender: channel::Sender<(MultipartHeader, arena::Arc)>, receiver: channel::Receiver<(MultipartHeader, zmq::Message)>, } struct ServerStreamPipeEnd { sender_any: channel::Sender<(arena::Arc, Session)>, sender_direct: channel::Sender>, receiver: channel::Receiver<(Option>, zmq::Message)>, } struct AsyncServerReqPipeEnd { sender: AsyncSender<(MultipartHeader, arena::Arc)>, receiver: AsyncReceiver<(MultipartHeader, zmq::Message)>, } struct AsyncServerStreamPipeEnd { sender_any: AsyncSender<(arena::Arc, Session)>, sender_direct: AsyncSender>, receiver: AsyncReceiver<(Option>, zmq::Message)>, } enum ServerControlRequest { Stop, SetServerReq(Vec), SetServerStream(Vec, Vec, Vec), AddServerReqHandle(ServerReqPipeEnd), AddServerStreamHandle(ServerStreamPipeEnd), } type ControlResponse = Result<(), String>; struct ReqPipe { pe: AsyncReqPipeEnd, filter: ArrayString<8>, valid: Cell, } struct StreamPipe { pe: AsyncStreamPipeEnd, filter: ArrayString<8>, valid: Cell, } struct ServerReqPipe { pe: AsyncServerReqPipeEnd, valid: Cell, } struct ServerStreamPipe { pe: AsyncServerStreamPipeEnd, valid: Cell, } struct RecvWrapperFuture<'a, T> { fut: RecvFuture<'a, T>, nkey: usize, } impl Future for RecvWrapperFuture<'_, T> { type Output = (usize, Result); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let r = &mut *self; match Pin::new(&mut r.fut).poll(cx) { Poll::Ready(result) => match result { Ok(value) => Poll::Ready((r.nkey, Ok(value))), Err(mpsc::RecvError) => Poll::Ready((r.nkey, Err(mpsc::RecvError))), }, Poll::Pending => Poll::Pending, } } } struct RecvScratch { tasks: arena::ReusableVec, slice_scratch: Vec, _marker: marker::PhantomData, } impl RecvScratch { fn new(capacity: usize) -> Self { Self { tasks: arena::ReusableVec::new::>(capacity), slice_scratch: Vec::with_capacity(capacity), _marker: marker::PhantomData, } } fn get<'a>( &mut self, ) -> ( arena::ReusableVecHandle<'_, RecvWrapperFuture<'a, T>>, &mut Vec, ) { (self.tasks.get_as_new(), &mut self.slice_scratch) } } struct CheckSendScratch { tasks: arena::ReusableVec, slice_scratch: Vec, _marker: marker::PhantomData, } impl CheckSendScratch { fn new(capacity: usize) -> Self { Self { tasks: arena::ReusableVec::new::>(capacity), slice_scratch: Vec::with_capacity(capacity), _marker: marker::PhantomData, } } fn get<'a>( &mut self, ) -> ( arena::ReusableVecHandle<'_, 
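// RecvScratch/CheckSendScratch pre-allocate room for one future per handle
// (arena::ReusableVec) plus a Vec used by select_slice(), so the recv() and
// check_send() methods below can rebuild their select sets on every call
// without allocating each time.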
WaitWritableFuture<'a, T>>, &mut Vec, ) { (self.tasks.get_as_new(), &mut self.slice_scratch) } } struct ReqHandles { nodes: Slab>, list: list::List, recv_scratch: RefCell>, need_cleanup: Cell, } impl ReqHandles { fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), list: list::List::default(), recv_scratch: RefCell::new(RecvScratch::new(capacity)), need_cleanup: Cell::new(false), } } fn len(&self) -> usize { self.nodes.len() } fn add(&mut self, pe: AsyncReqPipeEnd, filter: ArrayString<8>) { assert!(self.nodes.len() < self.nodes.capacity()); let key = self.nodes.insert(list::Node::new(ReqPipe { pe, filter, valid: Cell::new(true), })); self.list.push_back(&mut self.nodes, key); } #[allow(clippy::await_holding_refcell_ref)] async fn recv(&self) -> zmq::Message { let mut scratch = self.recv_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(RecvWrapperFuture { fut: p.pe.receiver.recv(), nkey, }); } next = n.next; } loop { match select_slice(&mut tasks, slice_scratch).await { (_, (_, Ok(msg))) => return msg, (pos, (nkey, Err(mpsc::RecvError))) => { tasks.remove(pos); let p = &self.nodes[nkey].value; p.valid.set(false); self.need_cleanup.set(true); } } } } async fn send(&self, msg: &arena::Arc, ids: &[Id<'_>]) { let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; let mut do_send = false; for id in ids.iter() { if id.id.starts_with(p.filter.as_bytes()) { do_send = true; break; } } if p.valid.get() && do_send { // blocking send. handle is expected to read as fast as possible // without downstream backpressure match p.pe.sender.send(arena::Arc::clone(msg)).await { Ok(_) => {} Err(_) => { p.valid.set(false); self.need_cleanup.set(true); } } } next = n.next; } } fn need_cleanup(&self) -> bool { self.need_cleanup.get() } fn cleanup(&mut self, f: F) where F: Fn(&ReqPipe), { let mut next = self.list.head; while let Some(nkey) = next { let n = &mut self.nodes[nkey]; let p = &mut n.value; next = n.next; if !p.valid.get() { f(p); self.list.remove(&mut self.nodes, nkey); self.nodes.remove(nkey); } } self.need_cleanup.set(false); } } struct StreamHandles { nodes: Slab>, list: list::List, recv_any_scratch: RefCell>, recv_addr_scratch: RefCell, zmq::Message)>>, need_cleanup: Cell, } impl StreamHandles { fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), list: list::List::default(), recv_any_scratch: RefCell::new(RecvScratch::new(capacity)), recv_addr_scratch: RefCell::new(RecvScratch::new(capacity)), need_cleanup: Cell::new(false), } } fn len(&self) -> usize { self.nodes.len() } fn add(&mut self, pe: AsyncStreamPipeEnd, filter: ArrayString<8>) { assert!(self.nodes.len() < self.nodes.capacity()); let key = self.nodes.insert(list::Node::new(StreamPipe { pe, filter, valid: Cell::new(true), })); self.list.push_back(&mut self.nodes, key); } #[allow(clippy::await_holding_refcell_ref)] async fn recv_any(&self) -> zmq::Message { let mut scratch = self.recv_any_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(RecvWrapperFuture { fut: p.pe.receiver_any.recv(), nkey, }); } next = n.next; } loop { match select_slice(&mut tasks, 
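// A RecvError from a pipe means its handle dropped the sending side: the
// future is removed from the select set, the pipe is flagged invalid, and
// need_cleanup is set so the owner can call cleanup() later to unlink it.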
slice_scratch).await { (_, (_, Ok(msg))) => return msg, (pos, (nkey, Err(mpsc::RecvError))) => { tasks.remove(pos); let p = &self.nodes[nkey].value; p.valid.set(false); self.need_cleanup.set(true); } } } } #[allow(clippy::await_holding_refcell_ref)] async fn recv_addr(&self) -> (ArrayVec, zmq::Message) { let mut scratch = self.recv_addr_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(RecvWrapperFuture { fut: p.pe.receiver_addr.recv(), nkey, }); } next = n.next; } loop { match select_slice(&mut tasks, slice_scratch).await { (_, (_, Ok(ret))) => return ret, (pos, (nkey, Err(mpsc::RecvError))) => { tasks.remove(pos); let p = &self.nodes[nkey].value; p.valid.set(false); self.need_cleanup.set(true); } } } } async fn send(&self, msg: &arena::Arc, ids: &[Id<'_>], from_router: bool) { let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; let mut do_send = false; for id in ids.iter() { if id.id.starts_with(p.filter.as_bytes()) { do_send = true; break; } } if p.valid.get() && do_send { // blocking send. handle is expected to read as fast as possible // without downstream backpressure match p .pe .sender .send((arena::Arc::clone(msg), from_router)) .await { Ok(_) => {} Err(_) => { p.valid.set(false); self.need_cleanup.set(true); } } } next = n.next; } } fn need_cleanup(&self) -> bool { self.need_cleanup.get() } fn cleanup(&mut self, f: F) where F: Fn(&StreamPipe), { let mut next = self.list.head; while let Some(nkey) = next { let n = &mut self.nodes[nkey]; let p = &mut n.value; next = n.next; if !p.valid.get() { f(p); self.list.remove(&mut self.nodes, nkey); self.nodes.remove(nkey); } } self.need_cleanup.set(false); } } struct ReqHandlesSendError(MultipartHeader); struct ServerReqHandles { nodes: Slab>, list: list::List, recv_scratch: RefCell>, check_send_scratch: RefCell)>>, need_cleanup: Cell, send_index: Cell, } impl ServerReqHandles { fn new(capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), list: list::List::default(), recv_scratch: RefCell::new(RecvScratch::new(capacity)), check_send_scratch: RefCell::new(CheckSendScratch::new(capacity)), need_cleanup: Cell::new(false), send_index: Cell::new(0), } } fn len(&self) -> usize { self.nodes.len() } fn add(&mut self, pe: AsyncServerReqPipeEnd) { assert!(self.nodes.len() < self.nodes.capacity()); let key = self.nodes.insert(list::Node::new(ServerReqPipe { pe, valid: Cell::new(true), })); self.list.push_back(&mut self.nodes, key); } #[allow(clippy::await_holding_refcell_ref)] async fn recv(&self) -> (MultipartHeader, zmq::Message) { let mut scratch = self.recv_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(RecvWrapperFuture { fut: p.pe.receiver.recv(), nkey, }); } next = n.next; } loop { match select_slice(&mut tasks, slice_scratch).await { (_, (_, Ok(ret))) => return ret, (pos, (nkey, Err(mpsc::RecvError))) => { tasks.remove(pos); let p = &self.nodes[nkey].value; p.valid.set(false); self.need_cleanup.set(true); } } } } // waits until at least one handle is likely writable #[allow(clippy::await_holding_refcell_ref)] async fn check_send(&self) { let mut any_valid = false; let mut 
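// Fast path: if any valid pipe is already writable, return immediately. If
// no pipe is valid at all, await forever (the caller recreates the future
// once handles change); otherwise fall through and wait on wait_writable()
// across all valid senders via select_slice().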
any_writable = false; for (_, p) in self.list.iter(&self.nodes) { if p.valid.get() { any_valid = true; if p.pe.sender.is_writable() { any_writable = true; break; } } } if any_writable { return; } // if there are no valid pipes then hang forever. caller can // try again by dropping the future and making a new one if !any_valid { std::future::pending::<()>().await; } // there are valid pipes but none are writable. we'll wait let mut scratch = self.check_send_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); for (_, p) in self.list.iter(&self.nodes) { if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(p.pe.sender.wait_writable()); } } select_slice(&mut tasks, slice_scratch).await; } // non-blocking send. caller should use check_send() first fn send( &self, header: MultipartHeader, msg: &arena::Arc, ) -> Result<(), ReqHandlesSendError> { if self.nodes.is_empty() { return Err(ReqHandlesSendError(header)); } let mut skip = self.send_index.get(); self.send_index.set((skip + 1) % self.nodes.len()); // select the nth ready node, else the latest ready node let mut selected = None; for (nkey, p) in self.list.iter(&self.nodes) { if p.valid.get() && p.pe.sender.is_writable() { selected = Some(nkey); } if skip > 0 { skip -= 1; } else if selected.is_some() { break; } } let nkey = match selected { Some(nkey) => nkey, None => return Err(ReqHandlesSendError(header)), }; let n = &self.nodes[nkey]; let p = &n.value; if let Err(e) = p.pe.sender.try_send((header, arena::Arc::clone(msg))) { let header = match e { mpsc::TrySendError::Full((header, _)) => header, mpsc::TrySendError::Disconnected((header, _)) => { p.valid.set(false); self.need_cleanup.set(true); header } }; return Err(ReqHandlesSendError(header)); } Ok(()) } fn need_cleanup(&self) -> bool { self.need_cleanup.get() } fn cleanup(&mut self, f: F) where F: Fn(&ServerReqPipe), { let mut next = self.list.head; while let Some(nkey) = next { let n = &mut self.nodes[nkey]; let p = &mut n.value; next = n.next; if !p.valid.get() { f(p); self.list.remove(&mut self.nodes, nkey); self.nodes.remove(nkey); } } self.need_cleanup.set(false); } } enum StreamHandlesSendError { BadFormat, NoneReady, SessionExists, SessionCapacityFull, } struct ServerStreamHandles { nodes: Slab>, list: list::List, recv_scratch: RefCell>, zmq::Message)>>, check_send_any_scratch: RefCell, Session)>>, send_direct_scratch: RefCell>, need_cleanup: Cell, send_index: Cell, sessions: SessionTable, } impl ServerStreamHandles { fn new(capacity: usize, sessions_capacity: usize) -> Self { Self { nodes: Slab::with_capacity(capacity), list: list::List::default(), recv_scratch: RefCell::new(RecvScratch::new(capacity)), check_send_any_scratch: RefCell::new(CheckSendScratch::new(capacity)), send_direct_scratch: RefCell::new(Vec::with_capacity(capacity)), need_cleanup: Cell::new(false), send_index: Cell::new(0), sessions: SessionTable::new(sessions_capacity), } } fn len(&self) -> usize { self.nodes.len() } fn add(&mut self, pe: AsyncServerStreamPipeEnd) { assert!(self.nodes.len() < self.nodes.capacity()); let key = self.nodes.insert(list::Node::new(ServerStreamPipe { pe, valid: Cell::new(true), })); self.list.push_back(&mut self.nodes, key); } #[allow(clippy::await_holding_refcell_ref)] async fn recv(&self) -> (Option>, zmq::Message) { let mut scratch = self.recv_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); let mut next = self.list.head; while let Some(nkey) = next { let n = &self.nodes[nkey]; let p = &n.value; if p.valid.get() { 
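// NOTE: illustrative sketch only. ServerReqHandles::send() above rotates
// send_index and then picks "the nth ready node, else the latest ready node
// seen", which spreads outgoing packets across handles while still preferring
// a handle that is actually writable right now. The same selection rule over
// a plain slice of readiness flags, with hypothetical names:
fn example_pick_ready(ready: &[bool], mut skip: usize) -> Option<usize> {
    let mut selected = None;
    for (i, &is_ready) in ready.iter().enumerate() {
        if is_ready {
            selected = Some(i);
        }
        if skip > 0 {
            skip -= 1;
        } else if selected.is_some() {
            break;
        }
    }
    selected
}
// e.g. with ready = [false, true, true]: skip = 0 picks index 1, skip = 2
// picks index 2, and if nothing is ready the caller gets None and keeps the
// message for a later attempt (that is what ReqHandlesSendError is for).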
assert!(tasks.len() < tasks.capacity()); tasks.push(RecvWrapperFuture { fut: p.pe.receiver.recv(), nkey, }); } next = n.next; } loop { match select_slice(&mut tasks, slice_scratch).await { (_, (_, Ok(ret))) => return ret, (pos, (nkey, Err(mpsc::RecvError))) => { tasks.remove(pos); let p = &self.nodes[nkey].value; p.valid.set(false); self.need_cleanup.set(true); } } } } // waits until at least one handle is likely writable #[allow(clippy::await_holding_refcell_ref)] async fn check_send_any(&self) { let mut any_valid = false; let mut any_writable = false; for (_, p) in self.list.iter(&self.nodes) { if p.valid.get() { any_valid = true; if p.pe.sender_any.is_writable() { any_writable = true; break; } } } if any_writable { return; } // if there are no valid pipes then hang forever. caller can // try again by dropping the future and making a new one if !any_valid { std::future::pending::<()>().await; } // there are valid pipes but none are writable. we'll wait let mut scratch = self.check_send_any_scratch.borrow_mut(); let (mut tasks, slice_scratch) = scratch.get(); for (_, p) in self.list.iter(&self.nodes) { if p.valid.get() { assert!(tasks.len() < tasks.capacity()); tasks.push(p.pe.sender_any.wait_writable()); } } select_slice(&mut tasks, slice_scratch).await; } // non-blocking send. caller should use check_send_any() first fn send_any( &self, msg: &arena::Arc, from: &[u8], ids: &[Id], ) -> Result<(), StreamHandlesSendError> { if from.len() > FROM_MAX || ids.is_empty() || ids[0].id.len() > REQ_ID_MAX { return Err(StreamHandlesSendError::BadFormat); } if self.nodes.is_empty() { return Err(StreamHandlesSendError::NoneReady); } let mut skip = self.send_index.get(); self.send_index.set((skip + 1) % self.nodes.len()); // select the nth ready node, else the latest ready node let mut selected = None; for (nkey, p) in self.list.iter(&self.nodes) { if p.valid.get() && p.pe.sender_any.is_writable() { selected = Some(nkey); } if skip > 0 { skip -= 1; } else if selected.is_some() { break; } } let nkey = match selected { Some(nkey) => nkey, None => return Err(StreamHandlesSendError::NoneReady), }; let n = &self.nodes[nkey]; let p = &n.value; let from = ArrayVec::try_from(from).unwrap(); let id = ArrayVec::try_from(ids[0].id).unwrap(); let key = (from, id); let session = match self.sessions.add(key, nkey) { Ok(s) => s, Err(SessionAddError::Full) => return Err(StreamHandlesSendError::SessionCapacityFull), Err(SessionAddError::Exists) => return Err(StreamHandlesSendError::SessionExists), }; if let Err(e) = p.pe.sender_any.try_send((arena::Arc::clone(msg), session)) { match e { mpsc::TrySendError::Full(_) => {} mpsc::TrySendError::Disconnected(_) => { p.valid.set(false); self.need_cleanup.set(true); } } return Err(StreamHandlesSendError::NoneReady); } Ok(()) } #[allow(clippy::await_holding_refcell_ref)] async fn send_direct(&self, msg: &arena::Arc, from: &[u8], ids: &[Id<'_>]) { if self.nodes.is_empty() { return; } let from = match ArrayVec::try_from(from) { Ok(v) => v, Err(_) => return, }; let indexes = &mut *self.send_direct_scratch.borrow_mut(); indexes.clear(); for _ in 0..self.nodes.capacity() { indexes.push(false); } for id in ids { let id = match ArrayVec::try_from(id.id) { Ok(v) => v, Err(_) => return, }; let key = (from.clone(), id); if let Some(nkey) = self.sessions.get(&key) { indexes[nkey] = true; } } for (nkey, &do_send) in indexes.iter().enumerate() { let n = match self.nodes.get(nkey) { Some(n) => n, None => continue, }; let p = &n.value; if p.valid.get() && do_send { // blocking send. 
handle is expected to read as fast as possible // without downstream backpressure match p.pe.sender_direct.send(arena::Arc::clone(msg)).await { Ok(_) => {} Err(_) => { p.valid.set(false); self.need_cleanup.set(true); } } } } } fn need_cleanup(&self) -> bool { self.need_cleanup.get() } fn cleanup(&mut self, f: F) where F: Fn(&ServerStreamPipe), { let mut next = self.list.head; while let Some(nkey) = next { let n = &mut self.nodes[nkey]; let p = &mut n.value; next = n.next; if !p.valid.get() { f(p); self.list.remove(&mut self.nodes, nkey); self.nodes.remove(nkey); } } self.need_cleanup.set(false); } } pub struct ClientSocketManager { handle_bound: usize, thread: Option>, control_pipe: Mutex<( channel::Sender, channel::Receiver, )>, } impl ClientSocketManager { // retained_max is the maximum number of received messages that the user // will keep around at any moment. for example, if the user plans to // set up 4 handles on the manager and read 1 message at a time from // each of the handles (i.e. process and drop a message before reading // the next), then the value here should be 4, because there would be // no more than 4 dequeued messages alive at any one time. this number // is needed to help size the internal arena pub fn new( ctx: Arc, instance_id: &str, retained_max: usize, init_hwm: usize, other_hwm: usize, handle_bound: usize, ) -> Self { let (s1, r1) = channel::channel(1); let (s2, r2) = channel::channel(1); let instance_id = String::from(instance_id); let thread = thread::Builder::new() .name("zhttpsocket".to_string()) .spawn(move || { debug!("manager thread start"); // 2 control channels, 3 channels per handle, 4 zmq sockets let channels = 2 + (HANDLES_MAX * 3); let zmqsockets = 4; let registrations_max = (channels * REGISTRATIONS_PER_CHANNEL) + (zmqsockets * REGISTRATIONS_PER_ZMQSOCKET); let reactor = Reactor::new(registrations_max); let executor = Executor::new(EXECUTOR_TASKS_MAX); executor .spawn(Self::run( ctx, s1, r2, instance_id, retained_max, init_hwm, other_hwm, handle_bound, )) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); debug!("manager thread end"); }) .unwrap(); Self { handle_bound, thread: Some(thread), control_pipe: Mutex::new((s2, r1)), } } pub fn set_client_req_specs(&mut self, specs: &[SpecInfo]) -> Result<(), String> { self.control_req(ControlRequest::SetClientReq(specs.to_vec())) } pub fn set_client_stream_specs( &mut self, out_specs: &[SpecInfo], out_stream_specs: &[SpecInfo], in_specs: &[SpecInfo], ) -> Result<(), String> { self.control_req(ControlRequest::SetClientStream( out_specs.to_vec(), out_stream_specs.to_vec(), in_specs.to_vec(), )) } pub fn client_req_handle(&self, id_prefix: &[u8]) -> ClientReqHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let pe = ReqPipeEnd { sender: s1, receiver: r2, }; let prefix = ArrayString::from_str(str::from_utf8(id_prefix).unwrap()).unwrap(); self.control_send(ControlRequest::AddClientReqHandle(pe, prefix)); ClientReqHandle { sender: s2, receiver: r1, } } pub fn client_stream_handle(&self, id_prefix: &[u8]) -> ClientStreamHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let (s3, r3) = channel::channel(self.handle_bound); let pe = StreamPipeEnd { sender: s1, receiver_any: r2, receiver_addr: r3, }; let prefix = ArrayString::from_str(str::from_utf8(id_prefix).unwrap()).unwrap(); self.control_send(ControlRequest::AddClientStreamHandle(pe, prefix)); ClientStreamHandle { 
sender_any: s2, sender_addr: s3, receiver: r1, } } fn control_send(&self, req: ControlRequest) { let pipe = self.control_pipe.lock().unwrap(); // NOTE: this will block if queue is full pipe.0.send(req).unwrap(); } fn control_req(&self, req: ControlRequest) -> Result<(), String> { let pipe = self.control_pipe.lock().unwrap(); // NOTE: this is a blocking exchange pipe.0.send(req).unwrap(); pipe.1.recv().unwrap() } #[allow(clippy::too_many_arguments)] async fn run( ctx: Arc, control_sender: channel::Sender, control_receiver: channel::Receiver, instance_id: String, retained_max: usize, init_hwm: usize, other_hwm: usize, handle_bound: usize, ) { let control_sender = AsyncSender::new(control_sender); let control_receiver = AsyncReceiver::new(control_receiver); // the messages arena needs to fit the max number of potential incoming messages that // still need to be processed. this is the entire channel queue for every handle, plus // the most number of messages the user might retain, plus 1 extra for the next message // we are preparing to send to the handles let arena_size = (HANDLES_MAX * handle_bound) + retained_max + 1; let messages_memory = Arc::new(arena::SyncMemory::new(arena_size)); let client_req = ClientReqSockets { sock: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::DEALER)), }; let client_stream = ClientStreamSockets { out: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::PUSH)), out_stream: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER)), in_: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::SUB)), }; client_req .sock .inner() .inner() .set_sndhwm(init_hwm as i32) .unwrap(); client_req .sock .inner() .inner() .set_rcvhwm(other_hwm as i32) .unwrap(); client_stream .out .inner() .inner() .set_sndhwm(init_hwm as i32) .unwrap(); client_stream .out_stream .inner() .inner() .set_sndhwm(other_hwm as i32) .unwrap(); client_stream .out_stream .inner() .inner() .set_rcvhwm(other_hwm as i32) .unwrap(); client_stream .in_ .inner() .inner() .set_rcvhwm(other_hwm as i32) .unwrap(); client_stream .out_stream .inner() .inner() .set_router_mandatory(true) .unwrap(); // a ROUTER socket may still be writable after returning EAGAIN, which // could mean that a different peer than the one we tried to write to // is writable. 
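// NOTE: illustrative sketch only; the peer name, payload, and 50 ms delay are
// hypothetical. The manager above arms a retry timeout
// (STREAM_OUT_STREAM_DELAY) rather than sleeping, but the failure/retry shape
// of a routed send is the same: with router_mandatory(true), EAGAIN means
// this particular peer's pipe is full even though the socket as a whole
// looked writable.
fn example_routed_send(
    sock: &zmq::Socket,
    peer: &[u8],
    payload: &[u8],
) -> Result<(), zmq::Error> {
    loop {
        match sock.send_multipart([peer, &b""[..], payload], zmq::DONTWAIT) {
            Ok(()) => return Ok(()),
            Err(zmq::Error::EAGAIN) => {
                // desired peer not writable yet; retry after a short delay
                std::thread::sleep(std::time::Duration::from_millis(50));
            }
            // e.g. EHOSTUNREACH when the peer is unknown or has gone away
            Err(e) => return Err(e),
        }
    }
}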
there's no way to know when the desired peer will be // writable, so we'll keep trying again after a delay client_stream .out_stream .set_retry_timeout(Some(STREAM_OUT_STREAM_DELAY)); client_stream .out_stream .inner() .inner() .set_identity(instance_id.as_bytes()) .unwrap(); let sub = format!("{} ", instance_id); client_stream .in_ .inner() .inner() .set_subscribe(sub.as_bytes()) .unwrap(); let mut req_handles = ReqHandles::new(HANDLES_MAX); let mut stream_handles = StreamHandles::new(HANDLES_MAX); let mut req_send: Option = None; let mut stream_out_send: Option = None; let mut stream_out_stream_send: Option = None; loop { let req_handles_recv = if req_send.is_none() { Some(req_handles.recv()) } else { None }; let stream_handles_recv_any = if stream_out_send.is_none() { Some(stream_handles.recv_any()) } else { None }; let stream_handles_recv_addr = if stream_out_stream_send.is_none() { Some(stream_handles.recv_addr()) } else { None }; let result = select_10( control_receiver.recv(), select_option(pin!(req_handles_recv).as_pin_mut()), select_option(req_send.as_mut()), client_req.sock.recv_routed(), select_option(pin!(stream_handles_recv_any).as_pin_mut()), select_option(stream_out_send.as_mut()), select_option(pin!(stream_handles_recv_addr).as_pin_mut()), select_option(stream_out_stream_send.as_mut()), client_stream.out_stream.recv_routed(), client_stream.in_.recv(), ) .await; match result { // control_receiver.recv Select10::R1(result) => match result { Ok(req) => match req { ControlRequest::Stop => break, ControlRequest::SetClientReq(specs) => { debug!("applying req specs: {:?}", specs); let result = Self::apply_req_specs(&client_req, &specs); control_sender .send(result) .await .expect("failed to send control response"); } ControlRequest::SetClientStream(out_specs, out_stream_specs, in_specs) => { debug!( "applying stream specs: {:?} {:?} {:?}", out_specs, out_stream_specs, in_specs ); let result = Self::apply_stream_specs( &client_stream, &out_specs, &out_stream_specs, &in_specs, ); control_sender .send(result) .await .expect("failed to send control response"); } ControlRequest::AddClientReqHandle(pe, filter) => { debug!("adding req handle: filter=[{}]", filter); if req_handles.len() + stream_handles.len() < HANDLES_MAX { req_handles.add( AsyncReqPipeEnd { sender: AsyncSender::new(pe.sender), receiver: AsyncReceiver::new(pe.receiver), }, filter, ); } else { error!("cannot add more than {} handles", HANDLES_MAX); } } ControlRequest::AddClientStreamHandle(pe, filter) => { debug!("adding stream handle: filter=[{}]", filter); if req_handles.len() + stream_handles.len() < HANDLES_MAX { stream_handles.add( AsyncStreamPipeEnd { sender: AsyncSender::new(pe.sender), receiver_any: AsyncReceiver::new(pe.receiver_any), receiver_addr: AsyncReceiver::new(pe.receiver_addr), }, filter, ); } else { error!("cannot add more than {} handles", HANDLES_MAX); } } }, Err(e) => error!("control recv: {}", e), }, // req_handles_recv Select10::R2(msg) => { if log_enabled!(log::Level::Trace) { trace!("OUT req {}", packet_to_string(&msg)); } req_send = Some(client_req.sock.send_to(MultipartHeader::new(), msg)); } // req_send Select10::R3(result) => { if let Err(e) = result { error!("req zmq send: {}", e); } req_send = None; } // client_req.sock.recv_routed Select10::R4(result) => match result { Ok((_, msg)) => { if log_enabled!(log::Level::Trace) { trace!("IN req {}", packet_to_string(&msg)); } Self::handle_req_message(msg, &messages_memory, &req_handles).await; } Err(e) => error!("req zmq recv: {}", e), }, // 
stream_handles_recv_any Select10::R5(msg) => { if log_enabled!(log::Level::Trace) { trace!("OUT stream {}", packet_to_string(&msg)); } stream_out_send = Some(client_stream.out.send(msg)); } // stream_out_send Select10::R6(result) => { if let Err(e) = result { error!("stream zmq send: {}", e); } stream_out_send = None; } // stream_handles_recv_addr Select10::R7((addr, msg)) => { let h = vec![zmq::Message::from(addr.as_slice())]; if log_enabled!(log::Level::Trace) { trace!( "OUT stream to={} {}", String::from_utf8_lossy(addr.as_slice()), packet_to_string(&msg) ); } stream_out_stream_send = Some(client_stream.out_stream.send_to(h, msg)); } // stream_out_stream_send Select10::R8(result) => { match result { Ok(()) => {} Err(zmq::Error::EHOSTUNREACH) => { // this can happen if a known peer goes away debug!("stream zmq send to host unreachable"); } Err(e) => error!("stream zmq send to: {}", e), } stream_out_stream_send = None; } // client_stream.out_stream.recv_routed Select10::R9(result) => match result { Ok((_, msg)) => { if log_enabled!(log::Level::Trace) { trace!("IN stream (router) {}", packet_to_string(&msg)); } Self::handle_stream_message(msg, &messages_memory, None, &stream_handles) .await; } Err(e) => error!("stream (router) zmq recv: {}", e), }, // client_stream.in_.recv Select10::R10(result) => match result { Ok(msg) => { if log_enabled!(log::Level::Trace) { trace!("IN stream {}", packet_to_string(&msg)); } Self::handle_stream_message( msg, &messages_memory, Some(&instance_id), &stream_handles, ) .await; } Err(e) => error!("stream zmq recv: {}", e), }, } if req_handles.need_cleanup() { req_handles.cleanup(|p| { debug!("req handle disconnected: filter=[{}]", p.filter); }); } if stream_handles.need_cleanup() { stream_handles.cleanup(|p| { debug!("stream handle disconnected: filter=[{}]", p.filter); }); } } } fn apply_req_specs(client_req: &ClientReqSockets, specs: &[SpecInfo]) -> Result<(), String> { if let Err(e) = client_req.sock.inner().apply_specs(specs) { return Err(e.to_string()); } Ok(()) } fn apply_stream_specs( client_stream: &ClientStreamSockets, out_specs: &[SpecInfo], out_stream_specs: &[SpecInfo], in_specs: &[SpecInfo], ) -> Result<(), String> { if let Err(e) = client_stream.out.inner().apply_specs(out_specs) { return Err(e.to_string()); } if let Err(e) = client_stream .out_stream .inner() .apply_specs(out_stream_specs) { return Err(e.to_string()); } if let Err(e) = client_stream.in_.inner().apply_specs(in_specs) { return Err(e.to_string()); } Ok(()) } async fn handle_req_message( msg: zmq::Message, messages_memory: &Arc>, handles: &ReqHandles, ) { let msg = arena::Arc::new(msg, messages_memory).unwrap(); let mut scratch = ParseScratch::new(); let (_, ids) = match parse_ids(msg.get(), &mut scratch) { Ok(ret) => ret, Err(e) => { warn!("unable to determine packet id(s): {}", e); return; } }; handles.send(&msg, ids).await; } async fn handle_stream_message( msg: zmq::Message, messages_memory: &Arc>, expect_addr: Option<&str>, handles: &StreamHandles, ) { let msg = arena::Arc::new(msg, messages_memory).unwrap(); let buf = msg.get(); let buf = if let Some(expect_addr) = expect_addr { let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = match pos { Some(pos) => pos, None => { warn!("unable to determine packet address"); return; } }; let addr = &buf[..pos]; if addr != expect_addr.as_bytes() { warn!("packet not for us"); return; } &buf[pos + 1..] 
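// NOTE: illustrative sketch only, with hypothetical names. Packets read from
// the SUB socket above carry an "<instance-id> " prefix (that exact string,
// trailing space included, is what the socket subscribes to), and
// handle_stream_message() strips it before parsing ids. The same split as a
// standalone helper:
fn example_strip_addr<'a>(buf: &'a [u8], expect_addr: &str) -> Option<&'a [u8]> {
    let pos = buf.iter().position(|&b| b == b' ')?;
    if &buf[..pos] != expect_addr.as_bytes() {
        // addressed to some other instance; not for us
        return None;
    }
    Some(&buf[pos + 1..])
}
// e.g. example_strip_addr(b"test T26:...", "test") yields Some(b"T26:..."),
// while a packet published for another instance id yields None.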
} else { buf }; let mut scratch = ParseScratch::new(); let (_, ids) = match parse_ids(buf, &mut scratch) { Ok(ret) => ret, Err(e) => { warn!("unable to determine packet id(s): {}", e); return; } }; handles.send(&msg, ids, expect_addr.is_none()).await; } } impl Drop for ClientSocketManager { fn drop(&mut self) { self.control_send(ControlRequest::Stop); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } enum ZmqFuture<'a> { Send(ZmqSendFuture<'a>), SendTo(ZmqSendToFuture<'a>), } impl Future for ZmqFuture<'_> { type Output = Result<(), zmq::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { match &mut *self { Self::Send(fut) => Pin::new(fut).poll(cx), Self::SendTo(fut) => Pin::new(fut).poll(cx), } } } pub struct ServerSocketManager { handle_bound: usize, thread: Option>, control_pipe: Mutex<( channel::Sender, channel::Receiver, )>, } impl ServerSocketManager { // retained_max is the maximum number of received messages that the user // will keep around at any moment. for example, if the user plans to // set up 4 handles on the manager and read 1 message at a time from // each of the handles (i.e. process and drop a message before reading // the next), then the value here should be 4, because there would be // no more than 4 dequeued messages alive at any one time. this number // is needed to help size the internal arena pub fn new( ctx: Arc, instance_id: &str, retained_max: usize, init_hwm: usize, other_hwm: usize, handle_bound: usize, stream_maxconn: usize, ) -> Self { let (s1, r1) = channel::channel(1); let (s2, r2) = channel::channel(1); let instance_id = String::from(instance_id); let thread = thread::Builder::new() .name("zhttpsocket".to_string()) .spawn(move || { debug!("server manager thread start"); // 2 control channels, 3 channels per handle, 4 zmq sockets let channels = 2 + (HANDLES_MAX * 3); let zmqsockets = 4; let registrations_max = (channels * REGISTRATIONS_PER_CHANNEL) + (zmqsockets * REGISTRATIONS_PER_ZMQSOCKET); let reactor = Reactor::new(registrations_max); let executor = Executor::new(EXECUTOR_TASKS_MAX); executor .spawn(Self::run( ctx, s1, r2, instance_id, retained_max, init_hwm, other_hwm, handle_bound, stream_maxconn, )) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); debug!("server manager thread end"); }) .unwrap(); Self { handle_bound, thread: Some(thread), control_pipe: Mutex::new((s2, r1)), } } pub fn set_server_req_specs(&mut self, specs: &[SpecInfo]) -> Result<(), String> { self.control_req(ServerControlRequest::SetServerReq(specs.to_vec())) } pub fn set_server_stream_specs( &self, in_specs: &[SpecInfo], in_stream_specs: &[SpecInfo], out_specs: &[SpecInfo], ) -> Result<(), String> { self.control_req(ServerControlRequest::SetServerStream( in_specs.to_vec(), in_stream_specs.to_vec(), out_specs.to_vec(), )) } pub fn server_req_handle(&self) -> ServerReqHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let pe = ServerReqPipeEnd { sender: s1, receiver: r2, }; self.control_send(ServerControlRequest::AddServerReqHandle(pe)); ServerReqHandle { sender: s2, receiver: r1, } } pub fn server_stream_handle(&self) -> ServerStreamHandle { let (s1, r1) = channel::channel(self.handle_bound); let (s2, r2) = channel::channel(self.handle_bound); let (s3, r3) = channel::channel(self.handle_bound); let pe = ServerStreamPipeEnd { sender_any: s1, sender_direct: s2, receiver: r3, }; self.control_send(ServerControlRequest::AddServerStreamHandle(pe)); ServerStreamHandle { 
sender: s3, receiver_any: r1, receiver_direct: r2, } } fn control_send(&self, req: ServerControlRequest) { let pipe = self.control_pipe.lock().unwrap(); // NOTE: this will block if queue is full pipe.0.send(req).unwrap(); } fn control_req(&self, req: ServerControlRequest) -> Result<(), String> { let pipe = self.control_pipe.lock().unwrap(); // NOTE: this is a blocking exchange pipe.0.send(req).unwrap(); pipe.1.recv().unwrap() } #[allow(clippy::too_many_arguments)] async fn run( ctx: Arc, control_sender: channel::Sender, control_receiver: channel::Receiver, instance_id: String, retained_max: usize, init_hwm: usize, other_hwm: usize, handle_bound: usize, stream_maxconn: usize, ) { let control_sender = AsyncSender::new(control_sender); let control_receiver = AsyncReceiver::new(control_receiver); // the messages arena needs to fit the max number of potential incoming messages that // still need to be processed. this is the entire channel queue for every handle, plus // the most number of messages the user might retain, plus 1 extra for the next message // we are preparing to send to the handles, x2 since there are two sending channels // per stream handle let arena_size = ((HANDLES_MAX * handle_bound) + retained_max + 1) * 2; let messages_memory = Arc::new(arena::SyncMemory::new(arena_size)); // sessions are created at the time of attempting to send to a handle, so we need enough // sessions to max out the workers, and max out all the handle channels, and have one // left to use when attempting to send let sessions_max = stream_maxconn + (HANDLES_MAX * handle_bound) + 1; let req_sock = AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER)); let mut stream_socks = ServerStreamSockets { in_: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::PULL)), in_stream: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::ROUTER)), out: AsyncZmqSocket::new(ZmqSocket::new(&ctx, zmq::PUB)), specs_applied: false, }; req_sock .inner() .inner() .set_sndhwm(init_hwm as i32) .unwrap(); req_sock .inner() .inner() .set_rcvhwm(other_hwm as i32) .unwrap(); stream_socks .in_ .inner() .inner() .set_rcvhwm(init_hwm as i32) .unwrap(); stream_socks .in_stream .inner() .inner() .set_sndhwm(other_hwm as i32) .unwrap(); stream_socks .in_stream .inner() .inner() .set_rcvhwm(other_hwm as i32) .unwrap(); stream_socks .out .inner() .inner() .set_sndhwm(other_hwm as i32) .unwrap(); stream_socks .in_stream .inner() .inner() .set_router_mandatory(true) .unwrap(); // a ROUTER socket may still be writable after returning EAGAIN, which // could mean that a different peer than the one we tried to write to // is writable. 
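// NOTE: worked example with made-up numbers; HANDLES_MAX, handle_bound,
// retained_max, and stream_maxconn are the real parameters used above, but
// the values below are only for illustration. The sizing comments boil down
// to simple arithmetic:
fn example_sizes(
    handles_max: usize,
    handle_bound: usize,
    retained_max: usize,
    stream_maxconn: usize,
) -> (usize, usize, usize) {
    // client-side arena: one slot per queued message per handle, plus the
    // messages the user may still be holding, plus the one being prepared
    let client_arena = (handles_max * handle_bound) + retained_max + 1;
    // server-side arena: same, times two, because each stream handle has two
    // sending channels (any + direct)
    let server_arena = client_arena * 2;
    // sessions: every worker connection (stream_maxconn), plus every queued
    // handle slot, plus one spare used while attempting a send
    let sessions_max = stream_maxconn + (handles_max * handle_bound) + 1;
    (client_arena, server_arena, sessions_max)
}
// e.g. example_sizes(4, 8, 4, 100) == (37, 74, 133)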
there's no way to know when the desired peer will be // writable, so we'll keep trying again after a delay stream_socks .in_stream .set_retry_timeout(Some(STREAM_OUT_STREAM_DELAY)); stream_socks .in_stream .inner() .inner() .set_identity(instance_id.as_bytes()) .unwrap(); let mut req_handles = ServerReqHandles::new(HANDLES_MAX); let mut stream_handles = ServerStreamHandles::new(HANDLES_MAX, sessions_max); let mut req_send: Option = None; let mut stream_out_send: Option = None; let mut req_in_msg = None; let mut stream_in_msg = None; loop { let req_recv_routed = if req_in_msg.is_none() { Some(req_sock.recv_routed()) } else { None }; let req_handles_recv = if req_send.is_none() { Some(req_handles.recv()) } else { None }; let req_handles_check_send = if req_in_msg.is_some() { Some(req_handles.check_send()) } else { None }; let stream_in_recv = if stream_in_msg.is_none() { Some(stream_socks.in_.recv()) } else { None }; let stream_handles_recv = if stream_out_send.is_none() { Some(stream_handles.recv()) } else { None }; let stream_handles_check_send_any = if stream_in_msg.is_some() { Some(stream_handles.check_send_any()) } else { None }; let result = select_10( control_receiver.recv(), select_option(pin!(req_recv_routed).as_pin_mut()), select_option(pin!(req_handles_recv).as_pin_mut()), select_option(req_send.as_mut()), select_option(pin!(req_handles_check_send).as_pin_mut()), select_option(pin!(stream_in_recv).as_pin_mut()), stream_socks.in_stream.recv_routed(), select_option(pin!(stream_handles_recv).as_pin_mut()), select_option(stream_out_send.as_mut()), select_option(pin!(stream_handles_check_send_any).as_pin_mut()), ) .await; match result { // control_receiver.recv Select10::R1(result) => match result { Ok(req) => match req { ServerControlRequest::Stop => break, ServerControlRequest::SetServerReq(specs) => { debug!("applying server req specs: {:?}", specs); let result = Self::apply_req_specs(&req_sock, &specs); control_sender .send(result) .await .expect("failed to send control response"); } ServerControlRequest::SetServerStream( in_specs, in_stream_specs, out_specs, ) => { debug!( "applying server stream specs: {:?} {:?} {:?}", in_specs, in_stream_specs, out_specs ); stream_socks.specs_applied = true; let result = Self::apply_stream_specs( &stream_socks, &in_specs, &in_stream_specs, &out_specs, ); control_sender .send(result) .await .expect("failed to send control response"); } ServerControlRequest::AddServerReqHandle(pe) => { debug!("adding server req handle"); if req_handles.len() + stream_handles.len() < HANDLES_MAX { req_handles.add(AsyncServerReqPipeEnd { sender: AsyncSender::new(pe.sender), receiver: AsyncReceiver::new(pe.receiver), }); } else { error!("cannot add more than {} handles", HANDLES_MAX); } } ServerControlRequest::AddServerStreamHandle(pe) => { debug!("adding server stream handle"); if !stream_socks.specs_applied { if req_handles.len() + stream_handles.len() < HANDLES_MAX { stream_handles.add(AsyncServerStreamPipeEnd { sender_any: AsyncSender::new(pe.sender_any), sender_direct: AsyncSender::new(pe.sender_direct), receiver: AsyncReceiver::new(pe.receiver), }); } else { error!("cannot add more than {} handles", HANDLES_MAX); } } else { error!("cannot add handle after specs have been applied"); } } }, Err(e) => error!("control recv: {}", e), }, // req_recv_routed Select10::R2(result) => match result { Ok((header, msg)) => { if log_enabled!(log::Level::Trace) { trace!("IN server req {}", packet_to_string(&msg)); } let msg = arena::Arc::new(msg, &messages_memory).unwrap(); 
req_in_msg = Some((header, msg)); } Err(e) => error!("server req zmq recv: {}", e), }, // req_handles_recv Select10::R3((header, msg)) => { if log_enabled!(log::Level::Trace) { trace!("OUT server req {}", packet_to_string(&msg)); } req_send = Some(req_sock.send_to(header, msg)); } // req_send Select10::R4(result) => { if let Err(e) = result { error!("server req zmq send: {}", e); } req_send = None; } // req_handles_check_send Select10::R5(()) => Self::handle_req_message(&mut req_in_msg, &req_handles), // stream_in_recv Select10::R6(result) => match result { Ok(msg) => { if log_enabled!(log::Level::Trace) { trace!("IN server stream {}", packet_to_string(&msg)); } let msg = arena::Arc::new(msg, &messages_memory).unwrap(); stream_in_msg = Some(msg); } Err(e) => error!("server stream zmq recv: {}", e), }, // stream_socks.in_stream.recv_routed Select10::R7(result) => match result { Ok((_, msg)) => { if log_enabled!(log::Level::Trace) { trace!("IN server stream next {}", packet_to_string(&msg)); } Self::handle_stream_message_direct(msg, &messages_memory, &stream_handles) .await; } Err(e) => error!("server stream next zmq recv: {}", e), }, // stream_handles_recv Select10::R8((addr, msg)) => { if let Some(addr) = &addr { let h = vec![zmq::Message::from(addr.as_ref())]; if log_enabled!(log::Level::Trace) { trace!( "OUT server stream to={} {}", String::from_utf8_lossy(addr), packet_to_string(&msg) ); } stream_out_send = Some(ZmqFuture::SendTo(stream_socks.in_stream.send_to(h, msg))); } else { if log_enabled!(log::Level::Trace) { trace!("OUT server stream {}", packet_to_string(&msg)); } stream_out_send = Some(ZmqFuture::Send(stream_socks.out.send(msg))); } } // stream_out_send Select10::R9(result) => { match result { Ok(()) => {} Err(zmq::Error::EHOSTUNREACH) => { // this can happen if a known peer goes away debug!("server stream zmq send to host unreachable"); } Err(e) => error!("server stream zmq send: {}", e), } stream_out_send = None; } // stream_handles_check_send_any Select10::R10(()) => { Self::handle_stream_message_any(&mut stream_in_msg, &stream_handles); } } if req_handles.need_cleanup() { req_handles.cleanup(|_| debug!("server req handle disconnected")); } if stream_handles.need_cleanup() { stream_handles.cleanup(|_| debug!("server stream handle disconnected")); } } } fn apply_req_specs(sock: &AsyncZmqSocket, specs: &[SpecInfo]) -> Result<(), String> { if let Err(e) = sock.inner().apply_specs(specs) { return Err(e.to_string()); } Ok(()) } fn apply_stream_specs( socks: &ServerStreamSockets, in_specs: &[SpecInfo], in_stream_specs: &[SpecInfo], out_specs: &[SpecInfo], ) -> Result<(), String> { if let Err(e) = socks.in_.inner().apply_specs(in_specs) { return Err(e.to_string()); } if let Err(e) = socks.in_stream.inner().apply_specs(in_stream_specs) { return Err(e.to_string()); } if let Err(e) = socks.out.inner().apply_specs(out_specs) { return Err(e.to_string()); } Ok(()) } fn handle_req_message( next_msg: &mut Option<(MultipartHeader, arena::Arc)>, handles: &ServerReqHandles, ) { let (header, msg) = next_msg.take().unwrap(); if let Err(ReqHandlesSendError(header)) = handles.send(header, &msg) { *next_msg = Some((header, msg)); } } fn handle_stream_message_any( next_msg: &mut Option>, handles: &ServerStreamHandles, ) { let msg = next_msg.take().unwrap(); let ret = { let mut scratch = ParseScratch::new(); let (from, ids) = match parse_ids(msg.get(), &mut scratch) { Ok(ret) => ret, Err(e) => { warn!("unable to determine packet id(s): {}", e); return; } }; handles.send_any(&msg, from, ids) }; 
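// NOTE: illustrative sketch only, using a plain std sync_channel in place of
// this crate's channels. handle_req_message() and handle_stream_message_any()
// above take the pending packet out of its Option slot, try a non-blocking
// send, and put it back if no handle was ready; the main loop only polls
// check_send()/check_send_any() while such a slot is occupied, so nothing
// more is read from zmq until the stalled packet has been delivered.
fn example_try_forward(slot: &mut Option<String>, tx: &std::sync::mpsc::SyncSender<String>) {
    let msg = match slot.take() {
        Some(msg) => msg,
        None => return,
    };
    match tx.try_send(msg) {
        Ok(()) => {}
        // receiver not ready: keep the message and retry on the next
        // writability notification
        Err(std::sync::mpsc::TrySendError::Full(msg)) => *slot = Some(msg),
        // receiver gone: drop the message (the real code marks the pipe
        // invalid and cleans it up later)
        Err(std::sync::mpsc::TrySendError::Disconnected(_)) => {}
    }
}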
match ret { Ok(()) => {} Err(StreamHandlesSendError::BadFormat) => warn!("stream send_any: bad format"), Err(StreamHandlesSendError::NoneReady) => *next_msg = Some(msg), Err(StreamHandlesSendError::SessionExists) => { warn!("stream send_any: session id in use") } Err(StreamHandlesSendError::SessionCapacityFull) => { error!("stream send_any: session capacity full") } } } async fn handle_stream_message_direct( msg: zmq::Message, messages_memory: &Arc>, handles: &ServerStreamHandles, ) { let msg = arena::Arc::new(msg, messages_memory).unwrap(); let mut scratch = ParseScratch::new(); let (from, ids) = match parse_ids(msg.get(), &mut scratch) { Ok(ret) => ret, Err(e) => { warn!("unable to determine packet id(s): {}", e); return; } }; handles.send_direct(&msg, from, ids).await; } } impl Drop for ServerSocketManager { fn drop(&mut self) { self.control_send(ServerControlRequest::Stop); let thread = self.thread.take().unwrap(); thread.join().unwrap(); } } #[derive(Debug)] pub enum SendError { Full(zmq::Message), Io(io::Error), } pub struct ClientReqHandle { sender: channel::Sender, receiver: channel::Receiver>, } impl ClientReqHandle { pub fn get_read_registration(&self) -> &event::Registration { self.receiver.get_read_registration() } pub fn get_write_registration(&self) -> &event::Registration { self.sender.get_write_registration() } pub fn recv(&self) -> Result, io::Error> { match self.receiver.try_recv() { Ok(msg) => Ok(msg), Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), Err(mpsc::TryRecvError::Disconnected) => { Err(io::Error::from(io::ErrorKind::BrokenPipe)) } } } pub fn send(&self, msg: zmq::Message) -> Result<(), SendError> { match self.sender.try_send(msg) { Ok(_) => Ok(()), Err(mpsc::TrySendError::Full(msg)) => Err(SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => { Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe))) } } } } pub struct AsyncClientReqHandle { sender: AsyncSender, receiver: AsyncReceiver>, } impl AsyncClientReqHandle { pub fn new(h: ClientReqHandle) -> Self { Self { sender: AsyncSender::new(h.sender), receiver: AsyncReceiver::new(h.receiver), } } pub async fn recv(&self) -> Result, io::Error> { match self.receiver.recv().await { Ok(msg) => Ok(msg), Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } pub async fn send(&self, msg: zmq::Message) -> Result<(), io::Error> { match self.sender.send(msg).await { Ok(_) => Ok(()), Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } } pub struct ClientStreamHandle { sender_any: channel::Sender, sender_addr: channel::Sender<(ArrayVec, zmq::Message)>, receiver: channel::Receiver<(arena::Arc, bool)>, } impl ClientStreamHandle { pub fn get_read_registration(&self) -> &event::Registration { self.receiver.get_read_registration() } pub fn get_write_any_registration(&self) -> &event::Registration { self.sender_any.get_write_registration() } pub fn get_write_addr_registration(&self) -> &event::Registration { self.sender_addr.get_write_registration() } pub fn recv(&self) -> Result<(arena::Arc, bool), io::Error> { match self.receiver.try_recv() { Ok(ret) => Ok(ret), Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), Err(mpsc::TryRecvError::Disconnected) => { Err(io::Error::from(io::ErrorKind::BrokenPipe)) } } } pub fn send_to_any(&self, msg: zmq::Message) -> Result<(), SendError> { match self.sender_any.try_send(msg) { Ok(_) => Ok(()), Err(mpsc::TrySendError::Full(msg)) => 
Err(SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => { Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe))) } } } pub fn send_to_addr(&self, addr: &[u8], msg: zmq::Message) -> Result<(), SendError> { let addr = match ArrayVec::try_from(addr) { Ok(a) => a, Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))), }; match self.sender_addr.try_send((addr, msg)) { Ok(_) => Ok(()), Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => { Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe))) } } } } pub struct AsyncClientStreamHandle { sender_any: AsyncSender, sender_addr: AsyncSender<(ArrayVec, zmq::Message)>, receiver: AsyncReceiver<(arena::Arc, bool)>, } impl AsyncClientStreamHandle { pub fn new(h: ClientStreamHandle) -> Self { Self { sender_any: AsyncSender::new(h.sender_any), sender_addr: AsyncSender::new(h.sender_addr), receiver: AsyncReceiver::new(h.receiver), } } pub async fn recv(&self) -> Result<(arena::Arc, bool), io::Error> { self.receiver .recv() .await .map_err(|_| io::Error::from(io::ErrorKind::BrokenPipe)) } pub async fn send_to_any(&self, msg: zmq::Message) -> Result<(), io::Error> { match self.sender_any.send(msg).await { Ok(_) => Ok(()), Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } pub async fn send_to_addr( &self, addr: ArrayVec, msg: zmq::Message, ) -> Result<(), io::Error> { match self.sender_addr.send((addr, msg)).await { Ok(_) => Ok(()), Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } } pub struct ServerReqHandle { sender: channel::Sender<(MultipartHeader, zmq::Message)>, receiver: channel::Receiver<(MultipartHeader, arena::Arc)>, } impl ServerReqHandle { pub fn get_read_registration(&self) -> &event::Registration { self.receiver.get_read_registration() } pub fn get_write_registration(&self) -> &event::Registration { self.sender.get_write_registration() } pub fn recv(&self) -> Result<(MultipartHeader, arena::Arc), io::Error> { match self.receiver.try_recv() { Ok(ret) => Ok(ret), Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), Err(mpsc::TryRecvError::Disconnected) => { Err(io::Error::from(io::ErrorKind::BrokenPipe)) } } } pub fn send(&self, header: MultipartHeader, msg: zmq::Message) -> Result<(), SendError> { match self.sender.try_send((header, msg)) { Ok(_) => Ok(()), Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => { Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe))) } } } } pub struct AsyncServerReqHandle { sender: AsyncSender<(MultipartHeader, zmq::Message)>, receiver: AsyncReceiver<(MultipartHeader, arena::Arc)>, } impl AsyncServerReqHandle { pub fn new(h: ServerReqHandle) -> Self { Self { sender: AsyncSender::new(h.sender), receiver: AsyncReceiver::new(h.receiver), } } pub async fn recv(&self) -> Result<(MultipartHeader, arena::Arc), io::Error> { match self.receiver.recv().await { Ok(msg) => Ok(msg), Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } pub async fn send(&self, header: MultipartHeader, msg: zmq::Message) -> Result<(), io::Error> { match self.sender.send((header, msg)).await { Ok(_) => Ok(()), Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } } pub struct ServerStreamHandle { sender: channel::Sender<(Option>, zmq::Message)>, receiver_any: channel::Receiver<(arena::Arc, Session)>, receiver_direct: 
channel::Receiver>, } impl ServerStreamHandle { pub fn get_read_any_registration(&self) -> &event::Registration { self.receiver_any.get_read_registration() } pub fn get_read_direct_registration(&self) -> &event::Registration { self.receiver_direct.get_read_registration() } pub fn get_write_registration(&self) -> &event::Registration { self.sender.get_write_registration() } pub fn recv_from_any(&self) -> Result<(arena::Arc, Session), io::Error> { match self.receiver_any.try_recv() { Ok(ret) => Ok(ret), Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), Err(mpsc::TryRecvError::Disconnected) => { Err(io::Error::from(io::ErrorKind::BrokenPipe)) } } } pub fn recv_directed(&self) -> Result, io::Error> { match self.receiver_direct.try_recv() { Ok(msg) => Ok(msg), Err(mpsc::TryRecvError::Empty) => Err(io::Error::from(io::ErrorKind::WouldBlock)), Err(mpsc::TryRecvError::Disconnected) => { Err(io::Error::from(io::ErrorKind::BrokenPipe)) } } } pub fn send(&self, addr: Option<&[u8]>, msg: zmq::Message) -> Result<(), SendError> { let addr = match addr { Some(a) => match ArrayVec::try_from(a) { Ok(a) => Some(a), Err(_) => return Err(SendError::Io(io::Error::from(io::ErrorKind::InvalidInput))), }, None => None, }; match self.sender.try_send((addr, msg)) { Ok(_) => Ok(()), Err(mpsc::TrySendError::Full((_, msg))) => Err(SendError::Full(msg)), Err(mpsc::TrySendError::Disconnected(_)) => { Err(SendError::Io(io::Error::from(io::ErrorKind::BrokenPipe))) } } } } pub struct AsyncServerStreamHandle { sender: AsyncSender<(Option>, zmq::Message)>, receiver_any: AsyncReceiver<(arena::Arc, Session)>, receiver_direct: AsyncReceiver>, } impl AsyncServerStreamHandle { pub fn new(h: ServerStreamHandle) -> Self { Self { sender: AsyncSender::new(h.sender), receiver_any: AsyncReceiver::new(h.receiver_any), receiver_direct: AsyncReceiver::new(h.receiver_direct), } } pub async fn recv_from_any(&self) -> Result<(arena::Arc, Session), io::Error> { match self.receiver_any.recv().await { Ok(ret) => Ok(ret), Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } pub async fn recv_directed(&self) -> Result, io::Error> { match self.receiver_direct.recv().await { Ok(msg) => Ok(msg), Err(mpsc::RecvError) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } pub async fn send( &self, addr: Option>, msg: zmq::Message, ) -> Result<(), io::Error> { match self.sender.send((addr, msg)).await { Ok(_) => Ok(()), Err(mpsc::SendError(_)) => Err(io::Error::from(io::ErrorKind::BrokenPipe)), } } } #[cfg(test)] mod tests { use super::*; use crate::connmgr::zhttppacket::{ PacketParse, Request, RequestData, RequestPacket, Response, ResponsePacket, }; use crate::core::event; use test_log::test; fn wait_readable(poller: &mut event::Poller, token: mio::Token) { loop { poller.poll(None).unwrap(); for event in poller.iter_events() { if event.token() == token && event.is_readable() { return; } } } } fn wait_writable(poller: &mut event::Poller, token: mio::Token) { loop { poller.poll(None).unwrap(); for event in poller.iter_events() { if event.token() == token && event.is_writable() { return; } } } } #[test] fn test_client_send_flow() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = ClientSocketManager::new(Arc::clone(&zmq_context), "test", 1, 1, 1, 1); zsockman .set_client_stream_specs( &vec![SpecInfo { spec: String::from("inproc://flow-test-out"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://flow-test-out-stream"), bind: true, ipc_file_mode: 0, 
}], &vec![SpecInfo { spec: String::from("inproc://flow-test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); // connect an out-stream receiver. the other sockets we'll leave alone let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); in_stream_sock .set_identity("test-handler".as_bytes()) .unwrap(); in_stream_sock.set_rcvhwm(1).unwrap(); in_stream_sock .connect("inproc://flow-test-out-stream") .unwrap(); let h = zsockman.client_stream_handle(b"a-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h.get_write_addr_registration(), mio::Token(1), mio::Interest::WRITABLE, ) .unwrap(); // write four times, which will all succeed eventually. after this // we'll have filled the handle, the manager's temporary variable, // and the HWMs of both the sending and receiving zmq sockets for i in 1..=4 { loop { match h.send_to_addr( "test-handler".as_bytes(), zmq::Message::from(format!("{}", i).into_bytes()), ) { Ok(()) => break, Err(SendError::Full(_)) => wait_writable(&mut poller, mio::Token(1)), Err(SendError::Io(e)) => panic!("{:?}", e), } } } // once we were able to write a fourth time, this means the manager // has started processing the third message. let's wait a short bit // for the manager to attempt to send the third message to the zmq // socket and fail with EAGAIN thread::sleep(Duration::from_millis(10)); // fifth write will fail. there's no room let e = h .send_to_addr( "test-handler".as_bytes(), zmq::Message::from("5".as_bytes()), ) .unwrap_err(); let msg = match e { SendError::Full(msg) => msg, _ => panic!("unexpected error"), }; assert_eq!(str::from_utf8(&msg).unwrap(), "5"); // blocking read from the zmq socket so another message can flow let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"1"); // fifth write will now succeed, eventually loop { match h.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("5".as_bytes()), ) { Ok(()) => break, Err(SendError::Full(_)) => wait_writable(&mut poller, mio::Token(1)), Err(SendError::Io(e)) => panic!("{:?}", e), } } // read the rest of the messages for i in 2..=5 { let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], format!("{}", i).as_bytes()); } } #[test] fn test_client_req() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = ClientSocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100, 100); zsockman .set_client_req_specs(&vec![SpecInfo { spec: String::from("inproc://test-req"), bind: true, ipc_file_mode: 0, }]) .unwrap(); let h1 = zsockman.client_req_handle(b"a-"); let h2 = zsockman.client_req_handle(b"b-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); let rep_sock = zmq_context.socket(zmq::REP).unwrap(); rep_sock.connect("inproc://test-req").unwrap(); h1.send(zmq::Message::from("hello a".as_bytes())).unwrap(); let parts = rep_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello a"); rep_sock .send("T26:2:id,3:a-1,4:body,5:world,}".as_bytes(), 0) .unwrap(); let msg; loop { match h1.recv() { Ok(m) => { msg = m; break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: 
{}", e), }; } let msg = msg.get(); let mut scratch = ParseScratch::new(); let resp = Response::parse(&msg, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); h2.send(zmq::Message::from("hello b".as_bytes())).unwrap(); let parts = rep_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello b"); rep_sock .send("T26:2:id,3:b-1,4:body,5:world,}".as_bytes(), 0) .unwrap(); let msg; loop { match h2.recv() { Ok(m) => { msg = m; break; } Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), }; } let msg = msg.get(); let mut scratch = ParseScratch::new(); let resp = Response::parse(&msg, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); drop(h1); drop(h2); drop(zsockman); } #[test] fn test_client_stream() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = ClientSocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100, 100); zsockman .set_client_stream_specs( &vec![SpecInfo { spec: String::from("inproc://test-out"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-out-stream"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-in"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let h1 = zsockman.client_stream_handle(b"a-"); let h2 = zsockman.client_stream_handle(b"b-"); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); let in_sock = zmq_context.socket(zmq::PULL).unwrap(); in_sock.connect("inproc://test-out").unwrap(); let in_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); in_stream_sock .set_identity("test-handler".as_bytes()) .unwrap(); in_stream_sock.connect("inproc://test-out-stream").unwrap(); let out_sock = zmq_context.socket(zmq::XPUB).unwrap(); out_sock.connect("inproc://test-in").unwrap(); // ensure zsockman is subscribed let msg = out_sock.recv_msg(0).unwrap(); assert_eq!(&msg[..], b"\x01test "); h1.send_to_any(zmq::Message::from("hello a".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello a"); out_sock .send( "test T49:4:from,12:test-handler,2:id,3:a-1,4:body,5:world,}".as_bytes(), 0, ) .unwrap(); let (msg, from_router) = loop { match h1.recv() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), }; }; assert!(!from_router); let msg = msg.get(); let buf = &msg; let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = pos.unwrap(); let buf = &buf[pos + 1..]; let mut scratch = ParseScratch::new(); let resp = Response::parse(buf, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); // send via router in_stream_sock .send_multipart( [ "test".as_bytes(), &[], "T52:4:from,12:test-handler,2:id,3:a-2,4:body,8:world a2,}".as_bytes(), ], 0, ) .unwrap(); let (msg, from_router) = loop { match 
h1.recv() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), }; }; assert!(from_router); let msg = msg.get(); let buf = &msg; let mut scratch = ParseScratch::new(); let resp = Response::parse(buf, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world a2"); h2.send_to_any(zmq::Message::from("hello b".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"hello b"); out_sock .send( "test T49:4:from,12:test-handler,2:id,3:b-1,4:body,5:world,}".as_bytes(), 0, ) .unwrap(); let (msg, from_router) = loop { match h2.recv() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), }; }; assert!(!from_router); let msg = msg.get(); let buf = &msg; let mut pos = None; for (i, b) in buf.iter().enumerate() { if *b == b' ' { pos = Some(i); break; } } let pos = pos.unwrap(); let buf = &buf[pos + 1..]; let mut scratch = ParseScratch::new(); let resp = Response::parse(buf, &mut scratch).unwrap(); let rdata = match resp.ptype { ResponsePacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"world"); h1.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("hello a".as_bytes()), ) .unwrap(); let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"hello a"); h2.send_to_addr( "test-handler".as_bytes(), zmq::Message::from("hello b".as_bytes()), ) .unwrap(); let parts = in_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"hello b"); drop(h1); drop(h2); drop(zsockman); } #[test] fn test_server_req() { let zmq_context = Arc::new(zmq::Context::new()); let mut zsockman = ServerSocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100, 100, 0); let h1 = zsockman.server_req_handle(); let h2 = zsockman.server_req_handle(); zsockman .set_server_req_specs(&vec![SpecInfo { spec: String::from("inproc://test-server-req"), bind: true, ipc_file_mode: 0, }]) .unwrap(); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); let req_sock = zmq_context.socket(zmq::REQ).unwrap(); req_sock.connect("inproc://test-server-req").unwrap(); req_sock.send("hello a".as_bytes(), 0).unwrap(); let (header, msg) = loop { match h1.recv() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let msg: &[u8] = msg.as_ref(); assert_eq!(msg, b"hello a"); h1.send(header, zmq::Message::from("world a".as_bytes())) .unwrap(); let parts = req_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"world a"); req_sock.send("hello b".as_bytes(), 0).unwrap(); let (header, msg) = loop { match h2.recv() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let msg: &[u8] = 
msg.as_ref(); assert_eq!(msg, b"hello b"); h2.send(header, zmq::Message::from("world b".as_bytes())) .unwrap(); let parts = req_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"world b"); drop(h1); drop(h2); drop(zsockman); } #[test] fn test_server_stream() { let zmq_context = Arc::new(zmq::Context::new()); let zsockman = ServerSocketManager::new(Arc::clone(&zmq_context), "test", 1, 100, 100, 100, 2); let h1 = zsockman.server_stream_handle(); let h2 = zsockman.server_stream_handle(); zsockman .set_server_stream_specs( &vec![SpecInfo { spec: String::from("inproc://test-server-in"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-server-in-stream"), bind: true, ipc_file_mode: 0, }], &vec![SpecInfo { spec: String::from("inproc://test-server-out"), bind: true, ipc_file_mode: 0, }], ) .unwrap(); let mut poller = event::Poller::new(1024).unwrap(); poller .register_custom( h1.get_read_any_registration(), mio::Token(1), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h1.get_read_direct_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_any_registration(), mio::Token(3), mio::Interest::READABLE, ) .unwrap(); poller .register_custom( h2.get_read_direct_registration(), mio::Token(4), mio::Interest::READABLE, ) .unwrap(); let out_sock = zmq_context.socket(zmq::PUSH).unwrap(); out_sock.connect("inproc://test-server-in").unwrap(); let out_stream_sock = zmq_context.socket(zmq::ROUTER).unwrap(); out_stream_sock.set_identity(b"test-handler").unwrap(); out_stream_sock .connect("inproc://test-server-in-stream") .unwrap(); let in_sock = zmq_context.socket(zmq::SUB).unwrap(); in_sock.connect("inproc://test-server-out").unwrap(); in_sock.set_subscribe(b"test-handler ").unwrap(); // ensure we are subscribed thread::sleep(Duration::from_millis(100)); let req = { let mut rdata = RequestData::new(); rdata.body = b"hello"; let mut dest = [0; 1024]; let size = Request::new_data( b"test-handler", &[Id { id: b"a-1", seq: None, }], rdata, ) .serialize(&mut dest) .unwrap(); dest[..size].to_vec() }; out_sock.send(req, 0).unwrap(); let (msg, sess_a) = loop { match h1.recv_from_any() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(1)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let mut scratch = ParseScratch::new(); let req = Request::parse(msg, &mut scratch).unwrap(); let rdata = match req.ptype { RequestPacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"hello"); h1.send(None, zmq::Message::from("test-handler world a".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"test-handler world a"); // send via router h1.send( Some(b"test-handler"), zmq::Message::from("world a2".as_bytes()), ) .unwrap(); let parts = out_stream_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 3); assert!(parts[1].is_empty()); assert_eq!(parts[2], b"world a2"); let req = { let mut rdata = RequestData::new(); rdata.body = b"hello"; let mut dest = [0; 1024]; let size = Request::new_data( b"test-handler", &[Id { id: b"b-1", seq: None, }], rdata, ) .serialize(&mut dest) .unwrap(); dest[..size].to_vec() }; out_sock.send(req, 0).unwrap(); let (msg, sess_b) = loop { match h2.recv_from_any() { Ok(ret) => break ret, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, 
mio::Token(3)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let mut scratch = ParseScratch::new(); let req = Request::parse(msg, &mut scratch).unwrap(); let rdata = match req.ptype { RequestPacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"hello"); h2.send(None, zmq::Message::from("test-handler world b".as_bytes())) .unwrap(); let parts = in_sock.recv_multipart(0).unwrap(); assert_eq!(parts.len(), 1); assert_eq!(parts[0], b"test-handler world b"); let req = { let mut rdata = RequestData::new(); rdata.body = b"hello a"; let mut dest = [0; 1024]; let size = Request::new_data( b"test-handler", &[Id { id: b"a-1", seq: None, }], rdata, ) .serialize(&mut dest) .unwrap(); dest[..size].to_vec() }; out_stream_sock .send_multipart(["test".as_bytes(), &[], &req], 0) .unwrap(); let msg = loop { match h1.recv_directed() { Ok(m) => break m, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(2)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let mut scratch = ParseScratch::new(); let req = Request::parse(msg, &mut scratch).unwrap(); let rdata = match req.ptype { RequestPacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"hello a"); let req = { let mut rdata = RequestData::new(); rdata.body = b"hello b"; let mut dest = [0; 1024]; let size = Request::new_data( b"test-handler", &[Id { id: b"b-1", seq: None, }], rdata, ) .serialize(&mut dest) .unwrap(); dest[..size].to_vec() }; out_stream_sock .send_multipart(["test".as_bytes(), &[], &req], 0) .unwrap(); let msg = loop { match h2.recv_directed() { Ok(m) => break m, Err(e) if e.kind() == io::ErrorKind::WouldBlock => { wait_readable(&mut poller, mio::Token(4)); continue; } Err(e) => panic!("recv: {}", e), } }; let msg = msg.get(); let mut scratch = ParseScratch::new(); let req = Request::parse(msg, &mut scratch).unwrap(); let rdata = match req.ptype { RequestPacket::Data(data) => data, _ => panic!("expected data packet"), }; assert_eq!(rdata.body, b"hello b"); drop(sess_a); drop(sess_b); drop(h1); drop(h2); drop(zsockman); } } pushpin-1.41.0/src/core/000077500000000000000000000000001504671364300150135ustar00rootroot00000000000000pushpin-1.41.0/src/core/arena.rs000066400000000000000000000411761504671364300164600ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use slab::Slab; use std::cell::{RefCell, RefMut}; use std::mem; use std::ops::{Deref, DerefMut}; use std::sync::{Mutex, MutexGuard}; pub struct EntryGuard<'a, T> { entries: RefMut<'a, Slab>, entry: &'a mut T, key: usize, } impl EntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl Deref for EntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl DerefMut for EntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a sharable slab for use within a single thread. 
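// unlike the mutex-protected SyncMemory defined further below, it relies on
// single-threaded access and needs no locking.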
// operations are protected by a RefCell. when an element is retrieved for // reading or modification, it is wrapped in a EntryGuard which keeps the // entire slab borrowed until the caller is done working with the element pub struct Memory { entries: RefCell>, } impl Memory { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: RefCell::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.borrow(); entries.len() } fn insert(&self, e: T) -> Result { let mut entries = self.entries.borrow_mut(); // out of capacity. by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option> { let mut entries = self.entries.borrow_mut(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in EntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the EntryGuard // and we can extend the lifetime of the reference beyond the // RefMut let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(EntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this. the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.borrow(); entries.get(0).unwrap() as *const T } } pub struct SyncEntryGuard<'a, T> { entries: MutexGuard<'a, Slab>, entry: &'a mut T, key: usize, } impl SyncEntryGuard<'_, T> { fn remove(mut self) { self.entries.remove(self.key); } } impl Deref for SyncEntryGuard<'_, T> { type Target = T; fn deref(&self) -> &Self::Target { self.entry } } impl DerefMut for SyncEntryGuard<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.entry } } // this is essentially a thread-safe slab. operations are protected by a // mutex. when an element is retrieved for reading or modification, it is // wrapped in a EntryGuard which keeps the entire slab locked until the // caller is done working with the element pub struct SyncMemory { entries: Mutex>, } impl SyncMemory { pub fn new(capacity: usize) -> Self { // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new(s), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.len() } fn insert(&self, e: T) -> Result { let mut entries = self.entries.lock().unwrap(); // out of capacity. 
by preventing inserts beyond the capacity, we // ensure the underlying memory won't get moved due to a realloc if entries.len() == entries.capacity() { return Err(()); } Ok(entries.insert(e)) } fn get<'a>(&'a self, key: usize) -> Option> { let mut entries = self.entries.lock().unwrap(); let entry = entries.get_mut(key)?; // slab element addresses are guaranteed to be stable once created, // and the only place we remove the element is in SyncEntryGuard's // remove method which consumes itself, therefore it is safe to // assume the element will live at least as long as the SyncEntryGuard // and we can extend the lifetime of the reference beyond the // MutexGuard let entry = unsafe { mem::transmute::<&mut T, &'a mut T>(entry) }; Some(SyncEntryGuard { entries, entry, key, }) } // for tests, as a way to confirm the memory isn't moving. be careful // with this. the very first element inserted will be at index 0, but // if the slab has been used and cleared, then the next element // inserted may not be at index 0 and calling this method afterward // will panic #[cfg(test)] fn entry0_ptr(&self) -> *const T { let entries = self.entries.lock().unwrap(); entries.get(0).unwrap() as *const T } } pub struct ReusableValue { reusable: std::sync::Arc>, value: *mut T, key: usize, } impl ReusableValue { // vec element addresses are guaranteed to be stable once created, // and elements are only removed when the Reusable is dropped, and // the Arc'd Reusable is guaranteed to live as long as // ReusableValue, therefore it is safe to assume the element will // live at least as long as the ReusableValue fn get(&self) -> &T { unsafe { &*self.value } } fn get_mut(&mut self) -> &mut T { unsafe { &mut *self.value } } } impl Drop for ReusableValue { fn drop(&mut self) { let mut entries = self.reusable.entries.lock().unwrap(); entries.0.remove(self.key); } } impl Deref for ReusableValue { type Target = T; fn deref(&self) -> &Self::Target { self.get() } } impl DerefMut for ReusableValue { fn deref_mut(&mut self) -> &mut Self::Target { self.get_mut() } } // like Memory, but for preinitializing each value and reusing pub struct Reusable { entries: Mutex<(Slab<()>, Vec)>, } impl Reusable { pub fn new(capacity: usize, init_fn: F) -> Self where F: Fn() -> T, { let mut values = Vec::with_capacity(capacity); for _ in 0..capacity { values.push(init_fn()); } // allocate the slab with fixed capacity let s = Slab::with_capacity(capacity); Self { entries: Mutex::new((s, values)), } } #[cfg(test)] pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.0.len() } #[allow(clippy::result_unit_err)] pub fn reserve(self: &std::sync::Arc) -> Result, ()> { let mut entries = self.entries.lock().unwrap(); // out of capacity. 
the number of buffers is fixed if entries.0.len() == entries.0.capacity() { return Err(()); } let key = entries.0.insert(()); let value = &mut entries.1[key] as *mut T; Ok(ReusableValue { reusable: self.clone(), value, key, }) } } pub struct RcEntry { value: T, refs: usize, } pub type RcMemory = Memory>; pub struct Rc { memory: std::rc::Rc>, key: usize, } impl Rc { #[allow(clippy::result_unit_err)] pub fn new(v: T, memory: &std::rc::Rc>) -> Result { let key = memory.insert(RcEntry { value: v, refs: 1 })?; Ok(Self { memory: std::rc::Rc::clone(memory), key, }) } #[allow(clippy::should_implement_trait)] pub fn clone(rc: &Rc) -> Self { let mut e = rc.memory.get(rc.key).unwrap(); e.refs += 1; Self { memory: rc.memory.clone(), key: rc.key, } } pub fn get<'a>(&'a self) -> &'a T { let e = self.memory.get(self.key).unwrap(); // get a reference to the inner value let value = &e.value; // entry addresses are guaranteed to be stable once created, and the // entry managed by this Rc won't be dropped until this Rc drops, // therefore it is safe to assume the entry managed by this Rc will // live at least as long as this Rc, and we can extend the lifetime // of the reference beyond the EntryGuard unsafe { mem::transmute::<&T, &'a T>(value) } } } impl Drop for Rc { fn drop(&mut self) { let mut e = self.memory.get(self.key).unwrap(); if e.refs == 1 { e.remove(); return; } e.refs -= 1; } } pub type ArcMemory = SyncMemory>; pub struct Arc { memory: std::sync::Arc>, key: usize, } impl Arc { #[allow(clippy::result_unit_err)] pub fn new(v: T, memory: &std::sync::Arc>) -> Result { let key = memory.insert(RcEntry { value: v, refs: 1 })?; Ok(Self { memory: memory.clone(), key, }) } #[allow(clippy::should_implement_trait)] pub fn clone(rc: &Arc) -> Self { let mut e = rc.memory.get(rc.key).unwrap(); e.refs += 1; Self { memory: rc.memory.clone(), key: rc.key, } } pub fn get<'a>(&'a self) -> &'a T { let e = self.memory.get(self.key).unwrap(); // get a reference to the inner value let value = &e.value; // entry addresses are guaranteed to be stable once created, and the // entry managed by this Arc won't be dropped until this Arc drops, // therefore it is safe to assume the entry managed by this Arc will // live at least as long as this Arc, and we can extend the lifetime // of the reference beyond the SyncEntryGuard unsafe { mem::transmute::<&T, &'a T>(value) } } } impl Drop for Arc { fn drop(&mut self) { let mut e = self.memory.get(self.key).unwrap(); if e.refs == 1 { e.remove(); return; } e.refs -= 1; } } // adapted from https://github.com/rust-lang/rfcs/pull/2802 pub fn recycle_vec(mut v: Vec) -> Vec { assert_eq!(core::mem::size_of::(), core::mem::size_of::()); assert_eq!(core::mem::align_of::(), core::mem::align_of::()); v.clear(); let ptr = v.as_mut_ptr(); let capacity = v.capacity(); mem::forget(v); let ptr = ptr as *mut U; unsafe { Vec::from_raw_parts(ptr, 0, capacity) } } // ReusableVec inspired by recycle_vec pub struct ReusableVecHandle<'a, T> { vec: &'a mut Vec, } impl ReusableVecHandle<'_, T> { pub fn get_ref(&self) -> &Vec { self.vec } pub fn get_mut(&mut self) -> &mut Vec { self.vec } } impl Drop for ReusableVecHandle<'_, T> { fn drop(&mut self) { self.vec.clear(); } } impl Deref for ReusableVecHandle<'_, T> { type Target = Vec; fn deref(&self) -> &Self::Target { self.get_ref() } } impl DerefMut for ReusableVecHandle<'_, T> { fn deref_mut(&mut self) -> &mut Self::Target { self.get_mut() } } pub struct ReusableVec { vec: Vec, size: usize, align: usize, } impl ReusableVec { pub fn new(capacity: usize) -> 
Self { let size = mem::size_of::(); let align = mem::align_of::(); let vec: Vec = Vec::with_capacity(capacity); // safety: we must cast to Vec before using, where U has the same // size and alignment as T let vec: Vec = unsafe { mem::transmute(vec) }; Self { vec, size, align } } pub fn get_as_new(&mut self) -> ReusableVecHandle<'_, U> { let size = mem::size_of::(); let align = mem::align_of::(); // if these don't match, panic. it's up the user to ensure the type // is acceptable assert_eq!(self.size, size); assert_eq!(self.align, align); let vec: &mut Vec = &mut self.vec; // safety: U has the expected size and alignment let vec: &mut Vec = unsafe { mem::transmute(vec) }; // the vec starts empty, and is always cleared when the handle drops. // get_as_new() borrows self mutably, so it's not possible to create // a handle when one already exists assert!(vec.is_empty()); ReusableVecHandle { vec } } } #[cfg(test)] mod tests { use super::*; #[test] fn test_reusable() { let reusable = std::sync::Arc::new(Reusable::new(2, || vec![0; 128])); assert_eq!(reusable.len(), 0); let mut buf1 = reusable.reserve().unwrap(); assert_eq!(reusable.len(), 1); let mut buf2 = reusable.reserve().unwrap(); assert_eq!(reusable.len(), 2); // no room assert!(reusable.reserve().is_err()); buf1[..5].copy_from_slice(b"hello"); buf2[..5].copy_from_slice(b"world"); assert_eq!(&buf1[..5], b"hello"); assert_eq!(&buf2[..5], b"world"); mem::drop(buf1); assert_eq!(reusable.len(), 1); mem::drop(buf2); assert_eq!(reusable.len(), 0); } #[test] fn test_rc() { let memory = std::rc::Rc::new(RcMemory::new(2)); assert_eq!(memory.len(), 0); let e0a = Rc::new(123 as i32, &memory).unwrap(); assert_eq!(memory.len(), 1); let p = memory.entry0_ptr(); let e0b = Rc::clone(&e0a); assert_eq!(memory.len(), 1); assert_eq!(memory.entry0_ptr(), p); let e1a = Rc::new(456 as i32, &memory).unwrap(); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); // no room assert!(Rc::new(789 as i32, &memory).is_err()); assert_eq!(*e0a.get(), 123); assert_eq!(*e0b.get(), 123); assert_eq!(*e1a.get(), 456); mem::drop(e0b); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); mem::drop(e0a); assert_eq!(memory.len(), 1); mem::drop(e1a); assert_eq!(memory.len(), 0); } #[test] fn test_arc() { let memory = std::sync::Arc::new(ArcMemory::new(2)); assert_eq!(memory.len(), 0); let e0a = Arc::new(123 as i32, &memory).unwrap(); assert_eq!(memory.len(), 1); let p = memory.entry0_ptr(); let e0b = Arc::clone(&e0a); assert_eq!(memory.len(), 1); assert_eq!(memory.entry0_ptr(), p); let e1a = Arc::new(456 as i32, &memory).unwrap(); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); // no room assert!(Arc::new(789 as i32, &memory).is_err()); assert_eq!(*e0a.get(), 123); assert_eq!(*e0b.get(), 123); assert_eq!(*e1a.get(), 456); mem::drop(e0b); assert_eq!(memory.len(), 2); assert_eq!(memory.entry0_ptr(), p); mem::drop(e0a); assert_eq!(memory.len(), 1); mem::drop(e1a); assert_eq!(memory.len(), 0); } #[test] fn test_reusable_vec() { let mut vec_mem = ReusableVec::new::(100); let mut vec = vec_mem.get_as_new::(); assert_eq!(vec.capacity(), 100); assert_eq!(vec.len(), 0); vec.push(1); assert_eq!(vec.len(), 1); mem::drop(vec); let vec = vec_mem.get_as_new::(); assert_eq!(vec.capacity(), 100); assert_eq!(vec.len(), 0); } } pushpin-1.41.0/src/core/buffer.rs000066400000000000000000000767251504671364300166530ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use std::cell::RefCell; use std::cmp; use std::io; use std::io::Write; use std::mem::{self, MaybeUninit}; use std::rc::Rc; use std::slice; #[cfg(test)] use std::io::Read; pub const VECTORED_MAX: usize = 8; pub fn trim_for_display(s: &str, max: usize) -> String { // NOTE: O(n) let char_len = s.chars().count(); if char_len > max && max >= 7 { let dist = max / 2; let mut left_end = 0; let mut right_start = 0; // NOTE: O(n) for (i, (pos, _)) in s.char_indices().enumerate() { // dist guaranteed to be < char_len if i == dist { left_end = pos; } // (char_len - dist + 3) guaranteed to be < char_len if i == char_len - dist + 3 { right_start = pos; } } let left = &s[..left_end]; let right = &s[right_start..]; format!("{}...{}", left, right) } else { s.to_owned() } } fn init_array<'a, T, const N: usize>(arr: &'a mut MaybeUninit<[T; N]>, src: &mut [T]) -> &'a mut [T] where T: Default, { // SAFETY: T and MaybeUninit have the same layout let arr: &mut [MaybeUninit; N] = unsafe { mem::transmute(arr) }; let len = cmp::min(arr.len(), src.len()); for (d, s) in arr.iter_mut().zip(src) { d.write(mem::take(s)); } // SAFETY: the slice will contain only initialized elements unsafe { slice::from_raw_parts_mut(arr[0].as_mut_ptr(), len) } } pub trait Buffer { fn len(&self) -> usize; fn remaining_capacity(&self) -> usize; fn read_buf(&self) -> &[u8]; fn read_buf_mut(&mut self) -> &mut [u8]; fn read_commit(&mut self, amount: usize); fn write_buf(&mut self) -> &mut [u8]; fn write_commit(&mut self, amount: usize); fn is_empty(&self) -> bool { self.len() == 0 } fn read_bufs<'data, 'bufs>( &'data self, bufs: &'bufs mut [&'data [u8]], ) -> &'bufs mut [&'data [u8]] { if !bufs.is_empty() { bufs[0] = self.read_buf(); &mut bufs[..1] } else { &mut [] } } fn read_bufs_mut<'data, 'bufs, const N: usize>( &'data mut self, bufs: &'bufs mut MaybeUninit<[&'data mut [u8]; N]>, ) -> &'bufs mut [&'data mut [u8]] { init_array(bufs, &mut [self.read_buf_mut()]) } } // for reading only impl Buffer for io::Cursor<&mut [u8]> { fn len(&self) -> usize { Buffer::read_buf(self).len() } fn remaining_capacity(&self) -> usize { 0 } fn read_buf(&self) -> &[u8] { let pos = self.position() as usize; &self.get_ref()[pos..] } fn read_buf_mut(&mut self) -> &mut [u8] { let pos = self.position() as usize; &mut self.get_mut()[pos..] } fn read_commit(&mut self, amount: usize) { let pos = self.position(); self.set_position(pos + (amount as u64)); } fn write_buf(&mut self) -> &mut [u8] { &mut [] } fn write_commit(&mut self, amount: usize) { assert_eq!(amount, 0); } } pub fn write_vectored_offset( writer: &mut W, bufs: &[&[u8]], offset: usize, ) -> Result { if bufs.is_empty() { return Ok(0); } let mut offset = offset; let mut start = 0; while offset >= bufs[start].len() { // on the last buf? 
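// (illustrative example: with bufs of [b"apple", b"banana"], an offset of 11
// lands exactly at the end of the last buffer and results in Ok(0), while an
// offset of 12 is past the end and is an error)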
if start + 1 >= bufs.len() { // exceeding the last buf is an error if offset > bufs[start].len() { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } return Ok(0); } offset -= bufs[start].len(); start += 1; } let mut arr = [io::IoSlice::new(&b""[..]); VECTORED_MAX]; let mut arr_len = 0; for (index, &buf) in bufs.iter().enumerate().skip(start) { let buf = if index == start { &buf[offset..] } else { buf }; arr[arr_len] = io::IoSlice::new(buf); arr_len += 1; } writer.write_vectored(&arr[..arr_len]) } struct LimitBufsRestore { index: usize, ptr: T, len: usize, } pub struct LimitBufsGuard<'a, 'b> { bufs: &'b mut [&'a [u8]], start: usize, end: usize, restore: Option>, } impl<'a: 'b, 'b> LimitBufsGuard<'a, 'b> { pub fn as_slice(&self) -> &[&'a [u8]] { &self.bufs[self.start..self.end] } } impl<'a: 'b, 'b> Drop for LimitBufsGuard<'a, 'b> { fn drop(&mut self) { if let Some(restore) = self.restore.take() { // SAFETY: ptr and len were collected earlier from the original // memory referred to by the slice at this index and they are // still valid. the only issue with reconstructing the slice is // that we currently have a different slice using the same memory // at this index. however, this is safe because we also replace // the slice at this index and the two slices don't coexist unsafe { self.bufs[restore.index] = slice::from_raw_parts(restore.ptr, restore.len); } } } } pub struct LimitBufsMutGuard<'a, 'b> { bufs: &'b mut [&'a mut [u8]], start: usize, end: usize, restore: Option>, } impl<'a: 'b, 'b> LimitBufsMutGuard<'a, 'b> { pub fn as_slice(&mut self) -> &mut [&'a mut [u8]] { &mut self.bufs[self.start..self.end] } } impl<'a: 'b, 'b> Drop for LimitBufsMutGuard<'a, 'b> { fn drop(&mut self) { if let Some(restore) = self.restore.take() { // SAFETY: ptr and len were collected earlier from the original // memory referred to by the slice at this index and they are // still valid. the only issue with reconstructing the slice is // that we currently have a different slice using the same memory // at this index. however, this is safe because we also replace // the slice at this index and the two slices don't coexist unsafe { self.bufs[restore.index] = slice::from_raw_parts_mut(restore.ptr, restore.len); } } } } pub trait LimitBufs<'a, 'b> { fn limit(&'b mut self, size: usize) -> LimitBufsGuard<'a, 'b>; } impl<'a: 'b, 'b> LimitBufs<'a, 'b> for [&'a [u8]] { fn limit(&'b mut self, size: usize) -> LimitBufsGuard<'a, 'b> { let mut end = self.len(); let mut restore = None; let mut want = size; for (index, item) in self.iter_mut().enumerate() { let buf: &[u8] = item; let buf_len = buf.len(); if buf_len >= want { let len = buf.len(); let ptr = buf.as_ptr(); restore = Some(LimitBufsRestore { index, ptr, len }); // SAFETY: ptr and len were obtained above and are still // valid. 
we just need to be careful about using them again // later on from the restore field unsafe { *item = &slice::from_raw_parts(ptr, len)[..want]; } end = index + 1; break; } want -= buf_len; } LimitBufsGuard { bufs: self, start: 0, end, restore, } } } pub trait LimitBufsMut<'a: 'b, 'b> { fn skip(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b>; fn limit(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b>; } impl<'a: 'b, 'b> LimitBufsMut<'a, 'b> for [&'a mut [u8]] { fn skip(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b> { let mut start = 0; let end = self.len(); let mut restore = None; let mut skip = size; for (index, item) in self.iter_mut().enumerate() { let buf: &mut [u8] = item; let buf_len = buf.len(); if buf_len >= skip { let len = buf.len(); let ptr = buf.as_mut_ptr(); restore = Some(LimitBufsRestore { index, ptr, len }); // SAFETY: ptr and len were obtained above and are still // valid. we just need to be careful about using them again // later on from the restore field unsafe { *item = &mut slice::from_raw_parts_mut(ptr, len)[skip..]; } start = index; break; } skip -= buf_len; } LimitBufsMutGuard { bufs: self, start, end, restore, } } fn limit(&'b mut self, size: usize) -> LimitBufsMutGuard<'a, 'b> { let mut end = self.len(); let mut restore = None; let mut want = size; for (index, item) in self.iter_mut().enumerate() { let buf: &mut [u8] = item; let buf_len = buf.len(); if buf_len >= want { let len = buf.len(); let ptr = buf.as_mut_ptr(); restore = Some(LimitBufsRestore { index, ptr, len }); // SAFETY: ptr and len were obtained above and are still // valid. we just need to be careful about using them again // later on from the restore field unsafe { *item = &mut slice::from_raw_parts_mut(ptr, len)[..want]; } end = index + 1; break; } want -= buf_len; } LimitBufsMutGuard { bufs: self, start: 0, end, restore, } } } pub struct ContiguousBuffer { buf: Vec, start: usize, end: usize, } #[allow(clippy::len_without_is_empty)] impl ContiguousBuffer { pub fn new(size: usize) -> Self { let buf = vec![0; size]; Self { buf, start: 0, end: 0, } } pub fn clear(&mut self) { self.start = 0; self.end = 0; } } impl Buffer for ContiguousBuffer { fn len(&self) -> usize { self.end - self.start } fn remaining_capacity(&self) -> usize { self.buf.len() - self.end } fn read_buf(&self) -> &[u8] { &self.buf[self.start..self.end] } fn read_buf_mut(&mut self) -> &mut [u8] { &mut self.buf[self.start..self.end] } fn read_commit(&mut self, amount: usize) { assert!(self.start + amount <= self.end); self.start += amount; } fn write_buf(&mut self) -> &mut [u8] { let len = self.buf.len(); &mut self.buf[self.end..len] } fn write_commit(&mut self, amount: usize) { assert!(self.end + amount <= self.buf.len()); self.end += amount; } } #[cfg(test)] impl Read for ContiguousBuffer { fn read(&mut self, buf: &mut [u8]) -> Result { // fully qualified to work around future method warning // https://github.com/rust-lang/rust/issues/48919 let src = Buffer::read_buf(self); let size = cmp::min(src.len(), buf.len()); buf[..size].copy_from_slice(&src[..size]); self.read_commit(size); Ok(size) } } impl Write for ContiguousBuffer { fn write(&mut self, buf: &[u8]) -> Result { if !buf.is_empty() && self.remaining_capacity() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let dest = self.write_buf(); let size = cmp::min(dest.len(), buf.len()); dest[..size].copy_from_slice(&buf[..size]); self.write_commit(size); Ok(size) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } pub struct 
TmpBuffer(RefCell>); #[allow(clippy::len_without_is_empty)] impl TmpBuffer { pub fn new(size: usize) -> Self { Self(RefCell::new(vec![0; size])) } pub fn len(&self) -> usize { self.0.borrow().len() } } // holds a Vec but only exposes the portion of it considered to be // readable ("filled"). any remaining bytes may be zeroed or uninitialized // and are not considered to be readable pub struct FilledBuf { data: Vec, filled: usize, } impl FilledBuf { // panics if filled is larger than data.len() pub fn new(data: Vec, filled: usize) -> Self { assert!(filled <= data.len()); Self { data, filled } } pub fn filled(&self) -> &[u8] { &self.data[..self.filled] } pub fn filled_len(&self) -> usize { self.filled } pub fn into_inner(self) -> Vec { self.data } } pub struct RingBuffer { buf: T, start: usize, end: usize, tmp: Rc, } #[allow(clippy::len_without_is_empty)] impl + AsMut<[u8]>> RingBuffer { pub fn capacity(&self) -> usize { self.buf.as_ref().len() } pub fn clear(&mut self) { self.start = 0; self.end = 0; } // return true if the readable bytes have not wrapped pub fn is_readable_contiguous(&self) -> bool { self.end <= self.buf.as_ref().len() } pub fn align(&mut self) -> usize { assert!(self.buf.as_ref().len() <= self.tmp.len()); if self.start == 0 { return 0; } let buf = self.buf.as_mut(); let size = self.end - self.start; if self.end <= buf.len() { // if the buffer hasn't wrapped, simply copy down buf.copy_within(self.start.., 0); } else if size <= self.start { // if the buffer has wrapped, but the wrapped part can be copied // without overlapping, then copy the wrapped part followed by // initial part let left_size = self.end - buf.len(); let right_size = buf.len() - self.start; buf.copy_within(..left_size, right_size); buf.copy_within(self.start..(self.start + right_size), 0); } else { // if the buffer has wrapped and the wrapped part can't be copied // without overlapping, then use a temporary buffer to // facilitate. smaller part is copied to the temp buffer, then // the larger and small parts (in that order) are copied into // their intended locations. 
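// (a worked example with illustrative numbers: for an 8-byte buffer with
// start=3 and end=11, the wrapped part is buf[0..3] and the initial part is
// buf[3..8]; the 3 wrapped bytes go to the temp buffer, buf[3..8] is copied
// down to buf[0..5], and the temp bytes are copied back into buf[5..8],
// leaving all 8 readable bytes contiguous at the front.)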
in the worst case, up to 50% of // the buffer may be copied twice let left_size = self.end - buf.len(); let right_size = buf.len() - self.start; let (lsize, lsrc, ldest, hsize, hsrc, hdest); if left_size < right_size { lsize = left_size; hsize = right_size; lsrc = 0; ldest = hsize; hsrc = self.start; hdest = 0; } else { lsize = right_size; hsize = left_size; lsrc = self.start; ldest = 0; hsrc = 0; hdest = lsize; } let mut tmp = self.tmp.0.borrow_mut(); tmp[..lsize].copy_from_slice(&buf[lsrc..(lsrc + lsize)]); buf.copy_within(hsrc..(hsrc + hsize), hdest); buf[ldest..(ldest + lsize)].copy_from_slice(&tmp[..lsize]); } self.start = 0; self.end = size; size } pub fn get_tmp(&self) -> &Rc { &self.tmp } } #[cfg(test)] impl + AsMut<[u8]>> Read for RingBuffer { fn read(&mut self, buf: &mut [u8]) -> Result { let mut pos = 0; while pos < buf.len() && self.len() > 0 { // fully qualified to work around future method warning // https://github.com/rust-lang/rust/issues/48919 let src = Buffer::read_buf(self); let size = cmp::min(src.len(), buf.len() - pos); buf[pos..(pos + size)].copy_from_slice(&src[..size]); self.read_commit(size); pos += size; } Ok(pos) } } impl + AsMut<[u8]>> Write for RingBuffer { fn write(&mut self, buf: &[u8]) -> Result { if !buf.is_empty() && self.remaining_capacity() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let mut pos = 0; while pos < buf.len() && self.remaining_capacity() > 0 { let dest = self.write_buf(); let size = cmp::min(dest.len(), buf.len() - pos); dest[..size].copy_from_slice(&buf[pos..(pos + size)]); self.write_commit(size); pos += size; } Ok(pos) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } impl + AsMut<[u8]>> Buffer for RingBuffer { fn len(&self) -> usize { self.end - self.start } fn remaining_capacity(&self) -> usize { self.buf.as_ref().len() - (self.end - self.start) } fn read_buf(&self) -> &[u8] { let buf = self.buf.as_ref(); let end = cmp::min(self.end, buf.len()); &buf[self.start..end] } fn read_buf_mut(&mut self) -> &mut [u8] { let buf = self.buf.as_mut(); let end = cmp::min(self.end, buf.len()); &mut buf[self.start..end] } fn read_commit(&mut self, amount: usize) { assert!(self.start + amount <= self.end); let buf = self.buf.as_ref(); self.start += amount; if self.start == self.end { self.start = 0; self.end = 0; } else if self.start >= buf.len() { self.start -= buf.len(); self.end -= buf.len(); } } fn write_buf(&mut self) -> &mut [u8] { let buf = self.buf.as_mut(); let (start, end) = if self.end < buf.len() { (self.end, buf.len()) } else { (self.end - buf.len(), self.start) }; &mut buf[start..end] } fn write_commit(&mut self, amount: usize) { assert!((self.end - self.start) + amount <= self.buf.as_ref().len()); self.end += amount; } fn read_bufs<'data, 'bufs>( &'data self, bufs: &'bufs mut [&'data [u8]], ) -> &'bufs mut [&'data [u8]] { assert!(!bufs.is_empty()); let buf = self.buf.as_ref(); let buf_len = buf.len(); if self.end > buf_len && bufs.len() >= 2 { let (part1, part2) = buf.split_at(self.start); bufs[0] = part2; bufs[1] = &part1[..(self.end - buf_len)]; &mut bufs[..2] } else { bufs[0] = &buf[self.start..self.end]; &mut bufs[..1] } } fn read_bufs_mut<'data, 'bufs, const N: usize>( &'data mut self, bufs: &'bufs mut MaybeUninit<[&'data mut [u8]; N]>, ) -> &'bufs mut [&'data mut [u8]] { let buf = self.buf.as_mut(); let buf_len = buf.len(); if self.end > buf_len { let (part1, part2) = buf.split_at_mut(self.start); init_array(bufs, &mut [part2, &mut part1[..(self.end - buf_len)]]) } else { init_array(bufs, &mut [&mut 
buf[self.start..self.end]]) } } } impl RingBuffer> { pub fn new(size: usize, tmp: &Rc) -> Self { assert!(size <= tmp.len()); let buf = vec![0; size]; Self { buf, start: 0, end: 0, tmp: Rc::clone(tmp), } } // extract inner buffer, aligning it first if necessary, and replace it // with an empty buffer. this should be cheap if the inner buffer is // already aligned. afterwards, the ringbuffer will have a capacity of // zero and will be essentially unusable until set_inner is called with a // non-empty buffer pub fn take_inner(&mut self) -> FilledBuf { self.align(); let data = mem::take(&mut self.buf); let filled = self.end; self.end = 0; FilledBuf::new(data, filled) } // replace the inner buffer. this should be cheap if the original inner // buffer is empty, which is the case if take_inner was called earlier. // panics if the new buffer is larger than the tmp buffer pub fn set_inner(&mut self, buf: FilledBuf) { let filled = buf.filled_len(); let data = buf.into_inner(); assert!(data.len() <= self.tmp.len()); self.buf = data; self.start = 0; self.end = filled; } pub fn swap_inner(&mut self, other: &mut Self) { let buf = self.take_inner(); self.set_inner(other.take_inner()); other.set_inner(buf); } pub fn resize(&mut self, size: usize) { if size == self.buf.len() { return; } self.align(); self.buf.resize(size, 0); self.buf.shrink_to_fit(); self.end = cmp::min(self.end, size); } } impl<'a> RingBuffer<&'a mut [u8]> { pub fn new(buf: &'a mut [u8], tmp: &Rc) -> Self { assert!(buf.len() <= tmp.len()); Self { buf, start: 0, end: 0, tmp: Rc::clone(tmp), } } } pub type VecRingBuffer = RingBuffer>; pub type SliceRingBuffer<'a> = RingBuffer<&'a mut [u8]>; #[cfg(test)] mod tests { use super::*; use std::io::{Read, Write}; #[test] fn test_write_vectored_offset() { struct MyWriter { bufs: Vec, } impl MyWriter { fn new() -> Self { Self { bufs: Vec::new() } } } impl Write for MyWriter { fn write(&mut self, buf: &[u8]) -> Result { self.bufs.push(String::from_utf8(buf.to_vec()).unwrap()); Ok(buf.len()) } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { let mut total = 0; for buf in bufs { total += buf.len(); self.bufs.push(String::from_utf8(buf.to_vec()).unwrap()); } Ok(total) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } // empty let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[], 0); assert_eq!(r.unwrap(), 0); // offset too large let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple"], 6); assert!(r.is_err()); // offset too large let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 12); assert!(r.is_err()); // nothing to write let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple"], 5); assert_eq!(r.unwrap(), 0); let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple"], 0); assert_eq!(r.unwrap(), 5); assert_eq!(w.bufs.len(), 1); assert_eq!(w.bufs[0], "apple"); let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple"], 3); assert_eq!(r.unwrap(), 2); assert_eq!(w.bufs.len(), 1); assert_eq!(w.bufs[0], "le"); let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 3); assert_eq!(r.unwrap(), 8); assert_eq!(w.bufs.len(), 2); assert_eq!(w.bufs[0], "le"); assert_eq!(w.bufs[1], "banana"); let mut w = MyWriter::new(); let r = write_vectored_offset(&mut w, &[b"apple", b"banana"], 5); assert_eq!(r.unwrap(), 6); assert_eq!(w.bufs.len(), 1); assert_eq!(w.bufs[0], "banana"); let mut w = MyWriter::new(); let r = 
write_vectored_offset(&mut w, &[b"apple", b"banana"], 6); assert_eq!(r.unwrap(), 5); assert_eq!(w.bufs.len(), 1); assert_eq!(w.bufs[0], "anana"); } #[test] fn test_buffer() { let mut b = ContiguousBuffer::new(8); assert_eq!(b.len(), 0); assert_eq!(b.remaining_capacity(), 8); let size = b.write(b"hello").unwrap(); assert_eq!(size, 5); assert_eq!(b.len(), 5); assert_eq!(b.remaining_capacity(), 3); let size = b.write(b"world").unwrap(); assert_eq!(size, 3); assert_eq!(b.len(), 8); assert_eq!(b.remaining_capacity(), 0); let mut tmp = [0; 16]; let size = b.read(&mut tmp).unwrap(); assert_eq!(&tmp[..size], b"hellowor"); b.clear(); assert_eq!(b.len(), 0); assert_eq!(b.remaining_capacity(), 8); } #[test] fn test_ringbuffer() { let mut buf = [0u8; 8]; let tmp = Rc::new(TmpBuffer::new(8)); let mut r = VecRingBuffer::new(8, &tmp); assert_eq!(r.len(), 0); assert_eq!(r.remaining_capacity(), 8); r.write(b"12345").unwrap(); assert_eq!(r.len(), 5); assert_eq!(r.remaining_capacity(), 3); r.write(b"678").unwrap(); let mut bufs_arr = [&b""[..]; VECTORED_MAX]; let bufs = r.read_bufs(&mut bufs_arr); assert_eq!(r.len(), 8); assert_eq!(r.remaining_capacity(), 0); assert_eq!(r.read_buf(), b"12345678"); assert_eq!(bufs.len(), 1); assert_eq!(bufs[0], b"12345678"); r.read(&mut buf[..5]).unwrap(); assert_eq!(r.len(), 3); assert_eq!(r.remaining_capacity(), 5); assert_eq!(r.write_buf().len(), 5); r.write(b"9abcd").unwrap(); assert_eq!(r.len(), 8); assert_eq!(r.remaining_capacity(), 0); r.read(&mut buf[5..]).unwrap(); assert_eq!(r.len(), 5); assert_eq!(r.remaining_capacity(), 3); r.read(&mut buf[..5]).unwrap(); assert_eq!(r.len(), 0); assert_eq!(r.remaining_capacity(), 8); assert_eq!(&buf, b"9abcd678"); r.write(b"12345").unwrap(); r.read(&mut buf[..2]).unwrap(); let mut bufs_arr = [&b""[..]; VECTORED_MAX]; let bufs = r.read_bufs(&mut bufs_arr); assert_eq!(r.len(), 3); assert_eq!(r.read_buf(), b"345"); assert_eq!(bufs.len(), 1); assert_eq!(bufs[0], b"345"); assert_eq!(r.remaining_capacity(), 5); assert_eq!(r.write_buf().len(), 3); r.align(); assert_eq!(r.len(), 3); assert_eq!(r.read_buf(), b"345"); assert_eq!(r.remaining_capacity(), 5); assert_eq!(r.write_buf().len(), 5); r.write(b"6789a").unwrap(); r.read(&mut buf[..2]).unwrap(); r.write(b"bc").unwrap(); let mut bufs_arr = [&b""[..]; VECTORED_MAX]; let bufs = r.read_bufs(&mut bufs_arr); assert_eq!(r.len(), 8); assert_eq!(r.read_buf(), b"56789a"); assert_eq!(bufs.len(), 2); assert_eq!(bufs[0], b"56789a"); assert_eq!(bufs[1], b"bc"); assert_eq!(r.remaining_capacity(), 0); r.align(); assert_eq!(r.len(), 8); assert_eq!(r.read_buf(), b"56789abc"); assert_eq!(r.remaining_capacity(), 0); r.read(&mut buf[..6]).unwrap(); r.write(b"def123").unwrap(); let mut bufs_arr = [&b""[..]; VECTORED_MAX]; let bufs = r.read_bufs(&mut bufs_arr); assert_eq!(r.len(), 8); assert_eq!(r.read_buf(), b"bc"); assert_eq!(bufs.len(), 2); assert_eq!(bufs[0], b"bc"); assert_eq!(bufs[1], b"def123"); assert_eq!(r.remaining_capacity(), 0); r.align(); let mut bufs_arr = [&b""[..]; VECTORED_MAX]; let bufs = r.read_bufs(&mut bufs_arr); assert_eq!(r.len(), 8); assert_eq!(r.read_buf(), b"bcdef123"); assert_eq!(bufs.len(), 1); assert_eq!(bufs[0], b"bcdef123"); assert_eq!(r.remaining_capacity(), 0); r.clear(); r.write(b"12345678").unwrap(); r.read(&mut buf[..6]).unwrap(); r.write(b"9abc").unwrap(); assert_eq!(r.len(), 6); assert_eq!(r.read_buf().len(), 2); r.align(); assert_eq!(r.len(), 6); assert_eq!(r.read_buf().len(), 6); } #[test] fn test_slice_ringbuffer() { let mut buf = [0; 8]; let mut backing_buf = [0; 8]; 
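// the ring buffer below borrows backing_buf as its storage; the shared tmp
// buffer must be at least as large as the backing slice, since
// SliceRingBuffer::new asserts buf.len() <= tmp.len()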
let tmp = Rc::new(TmpBuffer::new(8)); let mut r = SliceRingBuffer::new(&mut backing_buf, &tmp); r.write(b"12345678").unwrap(); let size = r.read(&mut buf[..4]).unwrap(); assert_eq!(&buf[..size], b"1234"); r.write(b"90ab").unwrap(); let size = r.read(&mut buf).unwrap(); assert_eq!(&buf[..size], b"567890ab"); } #[test] fn test_limitbufs() { let mut buf1 = [b'1', b'2', b'3', b'4']; let mut buf2 = [b'5', b'6', b'7', b'8']; let mut buf3 = [b'9', b'0', b'a', b'b']; let mut bufs = [buf1.as_slice(), buf2.as_slice(), buf3.as_slice()]; { let limited = bufs.limit(7); let limited = limited.as_slice(); assert_eq!(limited.len(), 2); assert_eq!(&limited[0], b"1234"); assert_eq!(&limited[1], b"567"); } assert_eq!(bufs.len(), 3); assert_eq!(&bufs[0], b"1234"); assert_eq!(&bufs[1], b"5678"); assert_eq!(&bufs[2], b"90ab"); let mut bufs = [ buf1.as_mut_slice(), buf2.as_mut_slice(), buf3.as_mut_slice(), ]; { let mut limited = bufs.limit(7); let limited = limited.as_slice(); assert_eq!(limited.len(), 2); assert_eq!(&limited[0], b"1234"); assert_eq!(&limited[1], b"567"); } { let mut limited = bufs.skip(7); let limited = limited.as_slice(); assert_eq!(limited.len(), 2); assert_eq!(&limited[0], b"8"); assert_eq!(&limited[1], b"90ab"); } assert_eq!(bufs.len(), 3); assert_eq!(&bufs[0], b"1234"); assert_eq!(&bufs[1], b"5678"); assert_eq!(&bufs[2], b"90ab"); } #[test] fn test_resize() { let tmp = Rc::new(TmpBuffer::new(16)); let mut r = VecRingBuffer::new(8, &tmp); assert_eq!(r.capacity(), 8); let size = r.write(b"12345678").unwrap(); assert_eq!(size, 8); let mut buf = [0; 4]; let size = r.read(&mut buf).unwrap(); assert_eq!(size, 4); assert_eq!(&buf[..size], b"1234"); let size = r.write(b"90ab").unwrap(); assert_eq!(size, 4); assert!(r.write(b"cdef").is_err()); r.resize(12); assert_eq!(r.capacity(), 12); let size = r.write(b"cdef").unwrap(); assert_eq!(size, 4); let mut buf = [0; 12]; let size = r.read(&mut buf).unwrap(); assert_eq!(size, 12); assert_eq!(&buf[..size], b"567890abcdef"); let size = r.write(b"1234567890").unwrap(); assert_eq!(size, 10); r.resize(8); assert_eq!(r.capacity(), 8); let mut buf = [0; 12]; let size = r.read(&mut buf).unwrap(); assert_eq!(size, 8); assert_eq!(&buf[..size], b"12345678"); } } pushpin-1.41.0/src/core/bufferlist.cpp000066400000000000000000000063031504671364300176660ustar00rootroot00000000000000/* * Copyright (C) 2013 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include "bufferlist.h" #include BufferList::BufferList() : size_(0), offset_(0) { } void BufferList::findPos(int pos, int *bufferIndex, int *offset) const { assert(pos < size_); int at = 0; int curOffset = offset_; while(true) { const QByteArray &buf = bufs_[at]; if(curOffset + pos < buf.size()) break; ++at; pos -= (buf.size() - curOffset); curOffset = 0; } *bufferIndex = at; *offset = curOffset + pos; } QByteArray BufferList::mid(int pos, int size) const { assert(pos >= 0); if(size_ == 0 || size == 0 || pos >= size_) return QByteArray(); int toRead; if(size > 0) toRead = qMin(size, size_ - pos); else toRead = size_ - pos; assert(!bufs_.isEmpty()); int at; int offset; findPos(pos, &at, &offset); // if we're reading the exact size of the current buffer, cheaply // return it if(offset == 0 && bufs_[at].size() == toRead) return bufs_[at]; QByteArray out; out.resize(toRead); char *outp = out.data(); while(toRead > 0) { const QByteArray &buf = bufs_[at]; int bsize = qMin(buf.size() - offset, toRead); memcpy(outp, buf.data() + offset, bsize); if(offset + bsize >= buf.size()) { ++at; offset = 0; } toRead -= bsize; outp += bsize; } return out; } void BufferList::clear() { bufs_.clear(); size_ = 0; offset_ = 0; } void BufferList::append(const QByteArray &buf) { if(buf.size() < 1) return; bufs_ += buf; size_ += buf.size(); } QByteArray BufferList::take(int size) { if(size_ == 0 || size == 0) return QByteArray(); int toRead; if(size > 0) toRead = qMin(size, size_); else toRead = size_; assert(!bufs_.isEmpty()); // if we're reading the exact size of the first buffer, cheaply // return it if(offset_ == 0 && bufs_.first().size() == toRead) { size_ -= toRead; return bufs_.takeFirst(); } QByteArray out; out.resize(toRead); char *outp = out.data(); while(toRead > 0) { const QByteArray &buf = bufs_.first(); int bsize = qMin(buf.size() - offset_, toRead); memcpy(outp, buf.data() + offset_, bsize); if(offset_ + bsize >= buf.size()) { bufs_.removeFirst(); offset_ = 0; } else offset_ += bsize; toRead -= bsize; size_ -= bsize; outp += bsize; } return out; } QByteArray BufferList::toByteArray() { if(size_ == 0) return QByteArray(); QByteArray out; while(!bufs_.isEmpty()) { if(offset_ > 0) { out += bufs_.first().mid(offset_); offset_ = 0; bufs_.removeFirst(); } else out += bufs_.takeFirst(); } // keep the rewritten buffer as the only buffer bufs_ += out; return out; } pushpin-1.41.0/src/core/bufferlist.h000066400000000000000000000024201504671364300173270ustar00rootroot00000000000000/* * Copyright (C) 2013 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #ifndef BUFFERLIST_H #define BUFFERLIST_H #include #include class BufferList { public: BufferList(); int size() const { return size_; } bool isEmpty() const { return size_ == 0; } QByteArray mid(int pos, int size = -1) const; void clear(); void append(const QByteArray &buf); QByteArray take(int size = -1); QByteArray toByteArray(); // non-const because we rewrite the list BufferList & operator+=(const QByteArray &buf) { append(buf); return *this; } private: QList bufs_; int size_; int offset_; void findPos(int pos, int *bufferIndex, int *offset) const; }; #endif pushpin-1.41.0/src/core/callback.h000066400000000000000000000055141504671364300167250ustar00rootroot00000000000000/* * Copyright (C) 2023 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef CALLBACK_H #define CALLBACK_H #include #include template class Callback { public: typedef void (*CallbackFunc)(void *data, T); Callback() : activeCalls_(0), destroyed_(0) { } ~Callback() { if(destroyed_) *destroyed_ = true; } void add(CallbackFunc cb, void *data) { targets_ += Target(cb, data); } void remove(void *data) { // mark for removal, but don't actually remove for(int n = 0; n < targets_.count(); ++n) { Target &t = targets_[n]; if(t.second == data) { t.second = 0; } } // only actually remove if not in the middle of a call if(activeCalls_ == 0) { removeMarked(); } } void call(T value) { activeCalls_ += 1; for(int n = 0; n < targets_.count(); ++n) { const Target &t = targets_[n]; // skip if marked for removal if(!t.second) { continue; } CallbackFunc f = t.first; void *data = t.second; bool *prevDestroyed = destroyed_; bool destroyed = false; destroyed_ = &destroyed; f(data, value); if(destroyed) { if(prevDestroyed) { *prevDestroyed = true; } return; } destroyed_ = prevDestroyed; } assert(activeCalls_ >= 1); activeCalls_ -= 1; if(activeCalls_ == 0) { removeMarked(); } } private: typedef QPair Target; QList targets_; bool activeCalls_; bool *destroyed_; void removeMarked() { assert(activeCalls_ == 0); for(int n = 0; n < targets_.count(); ++n) { if(!targets_[n].second) { targets_.removeAt(n); --n; // adjust position } } } }; #endif pushpin-1.41.0/src/core/channel.rs000066400000000000000000001446001504671364300167760ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::core::arena; use crate::core::event; use crate::core::list; use crate::core::reactor::CustomEvented; use crate::core::task::get_reactor; use slab::Slab; use std::cell::RefCell; use std::collections::VecDeque; use std::future::Future; use std::mem; use std::pin::Pin; use std::rc::Rc; use std::sync::atomic::{AtomicBool, Ordering}; use std::sync::mpsc; use std::sync::Arc; use std::task::{Context, Poll}; pub const REGISTRATIONS_PER_CHANNEL: usize = 1; pub struct Sender { sender: Option>, read_set_readiness: event::SetReadiness, write_registration: event::Registration, cts: Option>, } impl Sender { // NOTE: only makes sense for rendezvous channels pub fn can_send(&self) -> bool { match &self.cts { Some(cts) => cts.load(Ordering::Relaxed), None => true, } } pub fn get_write_registration(&self) -> &event::Registration { &self.write_registration } pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { if let Some(cts) = &self.cts { if cts .compare_exchange(true, false, Ordering::Relaxed, Ordering::Relaxed) .is_err() { return Err(mpsc::TrySendError::Full(t)); } // cts will only be true if a read was performed while the queue // was empty, and this function is the only place where the queue // is written to. this means the try_send call below will only // fail if the receiver disconnected } match self.sender.as_ref().unwrap().try_send(t) { Ok(_) => { self.read_set_readiness .set_readiness(mio::Interest::READABLE) .unwrap(); Ok(()) } Err(e) => Err(e), } } pub fn send(&self, t: T) -> Result<(), mpsc::SendError> { if self.cts.is_some() { panic!("blocking send with rendezvous channel not supported") } match self.sender.as_ref().unwrap().send(t) { Ok(_) => { self.read_set_readiness .set_readiness(mio::Interest::READABLE) .unwrap(); Ok(()) } Err(e) => Err(e), } } } impl Drop for Sender { fn drop(&mut self) { mem::drop(self.sender.take().unwrap()); self.read_set_readiness .set_readiness(mio::Interest::READABLE) .unwrap(); } } pub struct Receiver { receiver: mpsc::Receiver, read_registration: event::Registration, write_set_readiness: event::SetReadiness, cts: Option>, } impl Receiver { pub fn get_read_registration(&self) -> &event::Registration { &self.read_registration } pub fn try_recv(&self) -> Result { match self.receiver.try_recv() { Ok(t) => { if self.cts.is_none() { self.write_set_readiness .set_readiness(mio::Interest::WRITABLE) .unwrap(); } Ok(t) } Err(mpsc::TryRecvError::Empty) if self.cts.is_some() => { let cts = self.cts.as_ref().unwrap(); if cts .compare_exchange(false, true, Ordering::Relaxed, Ordering::Relaxed) .is_ok() { self.write_set_readiness .set_readiness(mio::Interest::WRITABLE) .unwrap(); } Err(mpsc::TryRecvError::Empty) } Err(e) => Err(e), } } pub fn recv(&self) -> Result { let t = self.receiver.recv()?; if self.cts.is_none() { self.write_set_readiness .set_readiness(mio::Interest::WRITABLE) .unwrap(); } Ok(t) } } pub fn channel(bound: usize) -> (Sender, Receiver) { let (read_reg, read_sr) = event::Registration::new(); let (write_reg, write_sr) = event::Registration::new(); // rendezvous channel if bound == 0 { let (s, r) = mpsc::sync_channel::(1); let cts = Arc::new(AtomicBool::new(false)); let sender = Sender { sender: Some(s), read_set_readiness: read_sr, write_registration: write_reg, cts: Some(Arc::clone(&cts)), }; let receiver = Receiver { receiver: r, read_registration: read_reg, write_set_readiness: write_sr, cts: Some(Arc::clone(&cts)), }; (sender, receiver) } else { let (s, r) = mpsc::sync_channel::(bound); let sender = Sender { sender: Some(s), 
read_set_readiness: read_sr, write_registration: write_reg, cts: None, }; let receiver = Receiver { receiver: r, read_registration: read_reg, write_set_readiness: write_sr, cts: None, }; // channel is immediately writable receiver .write_set_readiness .set_readiness(mio::Interest::WRITABLE) .unwrap(); (sender, receiver) } } struct LocalSenderData { notified: bool, write_set_readiness: event::LocalSetReadiness, } struct LocalSenders { nodes: Slab>, waiting: list::List, } struct LocalChannel { queue: RefCell>, senders: RefCell, read_set_readiness: RefCell>, } impl LocalChannel { fn senders_is_empty(&self) -> bool { self.senders.borrow().nodes.is_empty() } fn add_sender(&self, write_sr: event::LocalSetReadiness) -> Result { let mut senders = self.senders.borrow_mut(); if senders.nodes.len() == senders.nodes.capacity() { return Err(()); } let key = senders.nodes.insert(list::Node::new(LocalSenderData { notified: false, write_set_readiness: write_sr, })); Ok(key) } fn remove_sender(&self, key: usize) { let senders = &mut *self.senders.borrow_mut(); senders.waiting.remove(&mut senders.nodes, key); senders.nodes.remove(key); if senders.nodes.is_empty() { if let Some(read_sr) = &*self.read_set_readiness.borrow() { // notify for disconnect read_sr.set_readiness(mio::Interest::READABLE).unwrap(); } } } fn set_sender_waiting(&self, key: usize) { let senders = &mut *self.senders.borrow_mut(); // add if not already present if senders.nodes[key].prev.is_none() && senders.waiting.head != Some(key) { senders.waiting.push_back(&mut senders.nodes, key); } } fn notify_one_sender(&self) { let senders = &mut *self.senders.borrow_mut(); // notify next waiting sender, if any if let Some(key) = senders.waiting.pop_front(&mut senders.nodes) { let sender = &mut senders.nodes[key].value; sender.notified = true; sender .write_set_readiness .set_readiness(mio::Interest::WRITABLE) .unwrap(); } } fn sender_is_notified(&self, key: usize) -> bool { self.senders.borrow().nodes[key].value.notified } fn clear_sender_notified(&self, key: usize) { self.senders.borrow_mut().nodes[key].value.notified = false; } } pub struct LocalSender { channel: Rc>, key: usize, write_registration: event::LocalRegistration, } impl LocalSender { pub fn get_write_registration(&self) -> &event::LocalRegistration { &self.write_registration } // if this returns true, then the next call to try_send() by any sender // is guaranteed to not return TrySendError::Full. 
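// (the guarantee holds because the channel is single-threaded and only
// try_send() pushes onto the queue, so nothing can claim the free slot
// before that next call.)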
// if this returns false, the sender is added to the wait list pub fn check_send(&self) -> bool { let queue = self.channel.queue.borrow(); let can_send = queue.len() < queue.capacity(); if !can_send { self.channel.set_sender_waiting(self.key); } can_send } pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { // we are acting, so clear the notified flag self.channel.clear_sender_notified(self.key); let read_sr = &*self.channel.read_set_readiness.borrow(); let read_sr = match read_sr { Some(sr) => sr, None => { // receiver is disconnected return Err(mpsc::TrySendError::Disconnected(t)); } }; let mut queue = self.channel.queue.borrow_mut(); if queue.len() < queue.capacity() { queue.push_back(t); read_sr.set_readiness(mio::Interest::READABLE).unwrap(); Ok(()) } else { self.channel.set_sender_waiting(self.key); Err(mpsc::TrySendError::Full(t)) } } pub fn cancel(&self) { // if we were notified but never acted on it, notify the next waiting sender, if any if self.channel.sender_is_notified(self.key) { self.channel.clear_sender_notified(self.key); self.channel.notify_one_sender(); } } // NOTE: if the receiver is dropped while there are multiple senders, // only one of the senders will be notified of the disconnect #[allow(clippy::result_unit_err)] pub fn try_clone( &self, memory: &Rc>, ) -> Result { let (write_reg, write_sr) = event::LocalRegistration::new(memory); let key = self.channel.add_sender(write_sr)?; Ok(Self { channel: self.channel.clone(), key, write_registration: write_reg, }) } // returns error if a receiver already exists #[allow(clippy::result_unit_err)] pub fn make_receiver( &self, memory: &Rc>, ) -> Result, ()> { if self.channel.read_set_readiness.borrow().is_some() { return Err(()); } let (read_reg, read_sr) = event::LocalRegistration::new(memory); *self.channel.read_set_readiness.borrow_mut() = Some(read_sr); Ok(LocalReceiver { channel: self.channel.clone(), read_registration: read_reg, }) } } impl Drop for LocalSender { fn drop(&mut self) { self.cancel(); self.channel.remove_sender(self.key); } } pub struct LocalReceiver { channel: Rc>, read_registration: event::LocalRegistration, } impl LocalReceiver { pub fn get_read_registration(&self) -> &event::LocalRegistration { &self.read_registration } pub fn try_recv(&self) -> Result { let mut queue = self.channel.queue.borrow_mut(); if queue.is_empty() { if self.channel.senders_is_empty() { return Err(mpsc::TryRecvError::Disconnected); } return Err(mpsc::TryRecvError::Empty); } let value = queue.pop_front().unwrap(); self.channel.notify_one_sender(); Ok(value) } pub fn clear(&self) { // loop over try_recv() in order to notify senders while self.try_recv().is_ok() {} } } impl Drop for LocalReceiver { fn drop(&mut self) { *self.channel.read_set_readiness.borrow_mut() = None; self.channel.notify_one_sender(); } } pub fn local_channel( bound: usize, max_senders: usize, memory: &Rc>, ) -> (LocalSender, LocalReceiver) { let (read_reg, read_sr) = event::LocalRegistration::new(memory); let (write_reg, write_sr) = event::LocalRegistration::new(memory); // no support for rendezvous channels assert!(bound > 0); // need to support at least one sender assert!(max_senders > 0); let channel = Rc::new(LocalChannel { queue: RefCell::new(VecDeque::with_capacity(bound)), senders: RefCell::new(LocalSenders { nodes: Slab::with_capacity(max_senders), waiting: list::List::default(), }), read_set_readiness: RefCell::new(Some(read_sr)), }); let key = channel.add_sender(write_sr).unwrap(); let sender = LocalSender { channel: channel.clone(), key, 
write_registration: write_reg, }; let receiver = LocalReceiver { channel, read_registration: read_reg, }; (sender, receiver) } pub struct AsyncSender { evented: CustomEvented, inner: Sender, } impl AsyncSender { pub fn new(s: Sender) -> Self { let evented = CustomEvented::new( s.get_write_registration(), mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); // assume we can write, unless can_send() returns false. note that // if can_send() returns true, it doesn't mean we can actually write evented.registration().set_ready(s.can_send()); Self { evented, inner: s } } pub fn is_writable(&self) -> bool { self.evented.registration().is_ready() } pub fn wait_writable(&self) -> WaitWritableFuture<'_, T> { WaitWritableFuture { s: self } } pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { match self.inner.try_send(t) { Ok(_) => { // if can_send() returns false, then we know we can't write if !self.inner.can_send() { self.evented.registration().set_ready(false); } Ok(()) } Err(mpsc::TrySendError::Full(t)) => { self.evented.registration().set_ready(false); Err(mpsc::TrySendError::Full(t)) } Err(mpsc::TrySendError::Disconnected(t)) => Err(mpsc::TrySendError::Disconnected(t)), } } pub fn send(&self, t: T) -> SendFuture<'_, T> { SendFuture { s: self, t: Some(t), } } } pub struct AsyncReceiver { evented: CustomEvented, inner: Receiver, } impl AsyncReceiver { pub fn new(r: Receiver) -> Self { let evented = CustomEvented::new( r.get_read_registration(), mio::Interest::READABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: r } } pub fn recv(&self) -> RecvFuture<'_, T> { RecvFuture { r: self } } } pub struct AsyncLocalSender { evented: CustomEvented, inner: LocalSender, } impl AsyncLocalSender { pub fn new(s: LocalSender) -> Self { let evented = CustomEvented::new_local( s.get_write_registration(), mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: s } } pub fn into_inner(self) -> LocalSender { // normally, the poll registration would be deregistered when the // sender drops, but here we are keeping the sender alive, so we need // to explicitly deregister self.evented .registration() .deregister_custom_local(self.inner.get_write_registration()) .unwrap(); self.inner } pub fn send(&self, t: T) -> LocalSendFuture<'_, T> { LocalSendFuture { s: self, t: Some(t), } } // after polling/awaiting the returned future, you must call try_send() // or cancel(), or drop the sender, in order to ensure proper // coordination when there are multiple senders. prefer using // wait_sendable() which guards against misuse. // it's okay to run multiple instances of this future within the same // task. see the comment on the CheckSendFuture struct. pub fn check_send(&self) -> CheckSendFuture<'_, T> { CheckSendFuture { s: self } } pub fn try_send(&self, t: T) -> Result<(), mpsc::TrySendError> { self.inner.try_send(t) } pub fn cancel(&self) { self.inner.cancel(); } // waits for sendability in order to perform a non-blocking send // afterward. basically a less error-prone version check_send() + // try_send(). the returned future calls cancel() if dropped before // completion. the output of the future is a SendOnce, which offers a // method for sending one value. SendOnce calls cancel() if dropped // without sending anything. // it's okay to run multiple instances of this future within the same // task. see the comment on the WaitSendableFuture struct. 
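    // illustrative usage sketch (not from the original source; the variable
    // names and the use of unwrap() are assumptions for demonstration):
    //
    //     // wait for room in the queue, then perform one non-blocking send
    //     let send_once = sender.wait_sendable().await;
    //     send_once.try_send(value).unwrap();
    //
    // dropping the returned future, or the SendOnce, without sending calls
    // cancel(), so other waiting senders can still be notified.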
pub fn wait_sendable(&self) -> WaitSendableFuture<'_, T> { WaitSendableFuture { evented: &self.evented, sender: Some(&self.inner), } } } pub struct AsyncLocalReceiver { evented: CustomEvented, inner: LocalReceiver, } impl AsyncLocalReceiver { pub fn new(r: LocalReceiver) -> Self { let evented = CustomEvented::new_local( r.get_read_registration(), mio::Interest::READABLE, &get_reactor(), ) .unwrap(); evented.registration().set_ready(true); Self { evented, inner: r } } pub fn into_inner(self) -> LocalReceiver { // normally, the poll registration would be deregistered when the // receiver drops, but here we are keeping the receiver alive, so we // need to explicitly deregister self.evented .registration() .deregister_custom_local(self.inner.get_read_registration()) .unwrap(); self.inner } pub fn recv(&self) -> LocalRecvFuture<'_, T> { LocalRecvFuture { r: self } } } // allows one send attempt by calling try_send() which consumes self. // if struct is dropped instead, the send interest is canceled. pub struct SendOnce<'a, T> { sender: Option<&'a LocalSender>, } impl SendOnce<'_, T> { pub fn try_send(mut self, t: T) -> Result<(), mpsc::TrySendError> { self.sender.take().unwrap().try_send(t) } } impl Drop for SendOnce<'_, T> { fn drop(&mut self) { if let Some(sender) = &self.sender { sender.cancel(); } } } pub struct WaitWritableFuture<'a, T> { s: &'a AsyncSender, } impl Future for WaitWritableFuture<'_, T> { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); // if can_send() returns false, then we know we can't write. this // check prevents spurious wakups of a rendezvous channel from // indicating writability when the channel is not actually writable if !f.s.inner.can_send() { f.s.evented.registration().set_ready(false); } if !f.s.evented.registration().is_ready() { return Poll::Pending; } Poll::Ready(()) } } impl Drop for WaitWritableFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct SendFuture<'a, T> { s: &'a AsyncSender, t: Option, } impl Future for SendFuture<'_, T> where T: Unpin, { type Output = Result<(), mpsc::SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } if !f.s.evented.registration().pull_from_budget() { return Poll::Pending; } let t = f.t.take().unwrap(); // try_send will update the registration readiness, so we don't need // to do that here match f.s.try_send(t) { Ok(()) => Poll::Ready(Ok(())), Err(mpsc::TrySendError::Full(t)) => { f.t = Some(t); Poll::Pending } Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), } } } impl Drop for SendFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct RecvFuture<'a, T> { r: &'a AsyncReceiver, } impl Future for RecvFuture<'_, T> { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; f.r.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f.r.evented.registration().is_ready() { return Poll::Pending; } if !f.r.evented.registration().pull_from_budget() { return Poll::Pending; } match f.r.inner.try_recv() { Ok(v) => Poll::Ready(Ok(v)), Err(mpsc::TryRecvError::Empty) => { f.r.evented.registration().set_ready(false); Poll::Pending } 
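                // all senders have been dropped and the queue is drained, so
                // surface the disconnect to the awaiting task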
Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), } } } impl Drop for RecvFuture<'_, T> { fn drop(&mut self) { self.r.evented.registration().clear_waker(); } } pub struct LocalSendFuture<'a, T> { s: &'a AsyncLocalSender, t: Option, } impl Future for LocalSendFuture<'_, T> where T: Unpin, { type Output = Result<(), mpsc::SendError>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } if !f.s.evented.registration().pull_from_budget() { return Poll::Pending; } let t = f.t.take().unwrap(); match f.s.inner.try_send(t) { Ok(()) => Poll::Ready(Ok(())), Err(mpsc::TrySendError::Full(t)) => { f.s.evented.registration().set_ready(false); f.t = Some(t); Poll::Pending } Err(mpsc::TrySendError::Disconnected(t)) => Poll::Ready(Err(mpsc::SendError(t))), } } } impl Drop for LocalSendFuture<'_, T> { fn drop(&mut self) { self.s.inner.cancel(); self.s.evented.registration().clear_waker(); } } // it's okay to maintain multiple instances of this future at the same time // within the same task. calling poll() won't negatively affect other // instances. the drop() method clears the waker on the shared registration, // which may look problematic. however, whenever any instance is (re-)polled, // the waker will be reinstated. // // notably, these scenarios work: // // * creating two instances and awaiting them sequentially // * creating two instances and selecting on them in a loop. both will // eventually complete // * creating one instance, polling it to pending, then creating a second // instance and polling it to completion, then polling on the first // instance again pub struct CheckSendFuture<'a, T> { s: &'a AsyncLocalSender, } impl Future for CheckSendFuture<'_, T> where T: Unpin, { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.s.inner.check_send() { f.s.evented.registration().set_ready(false); return Poll::Pending; } Poll::Ready(()) } } impl Drop for CheckSendFuture<'_, T> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } // it's okay to maintain multiple instances of this future at the same time // within the same task, because the logic is same as CheckSendFuture pub struct WaitSendableFuture<'a, T> { evented: &'a CustomEvented, sender: Option<&'a LocalSender>, } impl<'a, T> Future for WaitSendableFuture<'a, T> where T: Unpin, { type Output = SendOnce<'a, T>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.sender.unwrap().check_send() { f.evented.registration().set_ready(false); return Poll::Pending; } let sender = f.sender.take(); assert!(sender.is_some()); Poll::Ready(SendOnce { sender }) } } impl Drop for WaitSendableFuture<'_, T> { fn drop(&mut self) { if let Some(sender) = &self.sender { sender.cancel(); } self.evented.registration().clear_waker(); } } pub struct LocalRecvFuture<'a, T> { r: &'a AsyncLocalReceiver, } impl Future for LocalRecvFuture<'_, T> { type Output = Result; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &*self; f.r.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f.r.evented.registration().is_ready() { return Poll::Pending; } if 
!f.r.evented.registration().pull_from_budget() { return Poll::Pending; } match f.r.inner.try_recv() { Ok(v) => Poll::Ready(Ok(v)), Err(mpsc::TryRecvError::Empty) => { f.r.evented.registration().set_ready(false); Poll::Pending } Err(mpsc::TryRecvError::Disconnected) => Poll::Ready(Err(mpsc::RecvError)), } } } impl Drop for LocalRecvFuture<'_, T> { fn drop(&mut self) { self.r.evented.registration().clear_waker(); } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use crate::core::reactor::Reactor; use crate::core::task::poll_async; use std::cell::Cell; use std::time; #[test] fn send_recv_bound0() { let (sender, receiver) = channel(0); assert_eq!(sender.can_send(), false); let result = sender.try_send(42); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TrySendError::Full(42)); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); assert_eq!(sender.can_send(), true); let result = sender.try_send(42); assert_eq!(result.is_ok(), true); assert_eq!(sender.can_send(), false); let result = receiver.try_recv(); assert_eq!(result.is_ok(), true); let v = result.unwrap(); assert_eq!(v, 42); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); mem::drop(sender); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected); } #[test] fn send_recv_bound1() { let (sender, receiver) = channel(1); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); let result = sender.try_send(42); assert_eq!(result.is_ok(), true); let result = sender.try_send(42); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TrySendError::Full(42)); let result = receiver.try_recv(); assert_eq!(result.is_ok(), true); let v = result.unwrap(); assert_eq!(v, 42); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); mem::drop(sender); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Disconnected); } #[test] fn notify_bound0() { let (sender, receiver) = channel(0); let mut poller = event::Poller::new(2).unwrap(); poller .register_custom( sender.get_write_registration(), mio::Token(1), mio::Interest::WRITABLE, ) .unwrap(); poller .register_custom( receiver.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); assert_eq!(sender.can_send(), false); poller.poll(Some(time::Duration::from_millis(0))).unwrap(); assert_eq!(poller.iter_events().next(), None); let result = receiver.try_recv(); assert_eq!(result.is_err(), true); assert_eq!(result.unwrap_err(), mpsc::TryRecvError::Empty); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(1)); assert_eq!(event.is_writable(), true); assert_eq!(it.next(), None); assert_eq!(sender.can_send(), true); sender.try_send(42).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let v = receiver.try_recv().unwrap(); assert_eq!(v, 42); mem::drop(sender); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = 
it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let e = receiver.try_recv().unwrap_err(); assert_eq!(e, mpsc::TryRecvError::Disconnected); } #[test] fn notify_bound1() { let (sender, receiver) = channel(1); let mut poller = event::Poller::new(2).unwrap(); poller .register_custom( sender.get_write_registration(), mio::Token(1), mio::Interest::WRITABLE, ) .unwrap(); poller .register_custom( receiver.get_read_registration(), mio::Token(2), mio::Interest::READABLE, ) .unwrap(); poller.poll(Some(time::Duration::from_millis(0))).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(1)); assert_eq!(event.is_writable(), true); assert_eq!(it.next(), None); sender.try_send(42).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let v = receiver.try_recv().unwrap(); assert_eq!(v, 42); mem::drop(sender); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), mio::Token(2)); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); let e = receiver.try_recv().unwrap_err(); assert_eq!(e, mpsc::TryRecvError::Disconnected); } #[test] fn local_send_recv() { let poller = event::Poller::new(6).unwrap(); let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory()); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty)); assert_eq!(sender1.try_send(1), Ok(())); assert_eq!(receiver.try_recv(), Ok(1)); let sender2 = sender1 .try_clone(poller.local_registration_memory()) .unwrap(); assert_eq!(sender1.try_send(2), Ok(())); let channel = sender2.channel.clone(); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(sender2.try_send(3), Err(mpsc::TrySendError::Full(3))); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(2)); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, true ); assert_eq!(sender2.try_send(3), Ok(())); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(3)); mem::drop(sender1); mem::drop(sender2); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Disconnected)); } #[test] fn local_send_disc() { let poller = event::Poller::new(4).unwrap(); let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory()); mem::drop(receiver); assert_eq!(sender.try_send(1), Err(mpsc::TrySendError::Disconnected(1))); } #[test] fn local_cancel() { let poller = event::Poller::new(6).unwrap(); let (sender1, receiver) = local_channel(1, 2, poller.local_registration_memory()); let sender2 = sender1 .try_clone(poller.local_registration_memory()) .unwrap(); let channel = sender2.channel.clone(); assert_eq!(sender1.try_send(1), Ok(())); assert_eq!(sender2.try_send(2), Err(mpsc::TrySendError::Full(2))); assert_eq!(sender1.try_send(3), Err(mpsc::TrySendError::Full(3))); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!( 
channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(1)); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, true ); sender2.cancel(); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, true ); assert_eq!( channel.senders.borrow().nodes[sender2.key].value.notified, false ); assert_eq!(sender1.try_send(3), Ok(())); assert_eq!( channel.senders.borrow().nodes[sender1.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(3)); } #[test] fn local_check_send() { let poller = event::Poller::new(4).unwrap(); let (sender, receiver) = local_channel(1, 1, poller.local_registration_memory()); assert_eq!(receiver.try_recv(), Err(mpsc::TryRecvError::Empty)); let channel = sender.channel.clone(); assert_eq!(sender.check_send(), true); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(sender.try_send(1), Ok(())); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(sender.check_send(), false); assert_eq!(channel.senders.borrow().waiting.is_empty(), false); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(1)); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, true ); assert_eq!(sender.try_send(2), Ok(())); assert_eq!(channel.senders.borrow().waiting.is_empty(), true); assert_eq!( channel.senders.borrow().nodes[sender.key].value.notified, false ); assert_eq!(receiver.try_recv(), Ok(2)); } #[test] fn async_send_bound0() { let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = channel::(0); let s = AsyncSender::new(s); let r = AsyncReceiver::new(r); executor .spawn(async move { s.send(1).await.unwrap(); assert_eq!(s.is_writable(), false); }) .unwrap(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_send_bound1() { let reactor = Reactor::new(2); let executor = Executor::new(1); let (s, r) = channel::(1); let s = AsyncSender::new(s); let r = AsyncReceiver::new(r); executor .spawn(async move { s.send(1).await.unwrap(); assert_eq!(s.is_writable(), true); }) .unwrap(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), false); executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_recv() { let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = channel::(0); let s = AsyncSender::new(s); let r = AsyncReceiver::new(r); executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); executor .spawn(async move { s.send(1).await.unwrap(); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } 
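    // Illustrative sketch (not one of the original tests) of the sender
    // coordination protocol that check_send()/cancel() implement: a failed
    // try_send() on a full queue puts the sender on the wait list, and a
    // later recv notifies exactly one waiting sender. The bound of 1 and
    // max_senders of 2 are assumptions chosen for demonstration.
    //
    //     let poller = event::Poller::new(6).unwrap();
    //     let (s1, r) = local_channel(1, 2, poller.local_registration_memory());
    //     let s2 = s1.try_clone(poller.local_registration_memory()).unwrap();
    //
    //     s1.try_send(1).unwrap();             // fills the queue
    //     assert!(s2.try_send(2).is_err());    // s2 joins the wait list
    //     assert_eq!(r.try_recv(), Ok(1));     // notifies s2
    //     s2.try_send(2).unwrap();             // s2 acts on the notification
    //
    // if s2 chose not to send after being notified, it would need to call
    // cancel() (or be dropped) so that another waiting sender is notified
    // instead.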
#[test] fn async_writable() { let reactor = Reactor::new(1); let executor = Executor::new(1); let (s, r) = channel::(0); let s = AsyncSender::new(s); executor .spawn(async move { assert_eq!(s.is_writable(), false); s.wait_writable().await; }) .unwrap(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); // attempting to receive on a rendezvous channel will make the // sender writable assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_local_channel() { let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); let s = AsyncLocalSender::new(s); let r = AsyncLocalReceiver::new(r); executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); executor .spawn(async move { s.send(1).await.unwrap(); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_check_send_sequential() { // create two instances and await them sequentially let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); let state = Rc::new(Cell::new(0)); { let state = state.clone(); executor .spawn(async move { let s = AsyncLocalSender::new(s); // fill the queue s.send(1).await.unwrap(); state.set(1); // create two instances and await them sequentially let fut1 = s.check_send(); let fut2 = s.check_send(); fut1.await; s.send(2).await.unwrap(); state.set(2); fut2.await; state.set(3); }) .unwrap(); } reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 1); assert_eq!(r.try_recv(), Ok(1)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 2); assert_eq!(r.try_recv(), Ok(2)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 3); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_check_send_alternating() { // create one instance, poll it to pending, then create a second // instance and poll it to completion, then poll the first again let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); let state = Rc::new(Cell::new(0)); { let state = state.clone(); executor .spawn(async move { let s = AsyncLocalSender::new(s); // fill the queue s.send(1).await.unwrap(); // create one instance let mut fut1 = s.check_send(); // poll it to pending assert_eq!(poll_async(&mut fut1).await, Poll::Pending); state.set(1); // create a second instance and poll it to completion s.check_send().await; s.send(2).await.unwrap(); state.set(2); // poll the first again fut1.await; state.set(3); }) .unwrap(); } reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 1); assert_eq!(r.try_recv(), Ok(1)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 2); assert_eq!(r.try_recv(), Ok(2)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); 
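        // the queue is empty again, so awaiting the first check_send future
        // can now complete (state 3)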
executor.run_until_stalled(); assert_eq!(state.get(), 3); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_wait_sendable_sequential() { // create two instances and await them sequentially let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); let state = Rc::new(Cell::new(0)); { let state = state.clone(); executor .spawn(async move { let s = AsyncLocalSender::new(s); // fill the queue s.send(1).await.unwrap(); state.set(1); // create two instances and await them sequentially let fut1 = s.wait_sendable(); let fut2 = s.wait_sendable(); fut1.await.try_send(2).unwrap(); state.set(2); fut2.await; state.set(3); }) .unwrap(); } reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 1); assert_eq!(r.try_recv(), Ok(1)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 2); assert_eq!(r.try_recv(), Ok(2)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 3); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_wait_sendable_alternating() { // create one instance, poll it to pending, then create a second // instance and poll it to completion, then poll the first again let reactor = Reactor::new(2); let executor = Executor::new(2); let (s, r) = local_channel::(1, 1, &reactor.local_registration_memory()); let state = Rc::new(Cell::new(0)); { let state = state.clone(); executor .spawn(async move { let s = AsyncLocalSender::new(s); // fill the queue s.send(1).await.unwrap(); // create one instance let mut fut1 = s.wait_sendable(); // poll it to pending assert!(poll_async(&mut fut1).await.is_pending()); state.set(1); // create a second instance and poll it to completion s.wait_sendable().await.try_send(2).unwrap(); state.set(2); // poll the first again fut1.await; state.set(3); }) .unwrap(); } reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 1); assert_eq!(r.try_recv(), Ok(1)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 2); assert_eq!(r.try_recv(), Ok(2)); assert_eq!(r.try_recv(), Err(mpsc::TryRecvError::Empty)); reactor.poll_nonblocking(reactor.now()).unwrap(); executor.run_until_stalled(); assert_eq!(state.get(), 3); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn budget_unlimited() { let reactor = Reactor::new(1); let executor = Executor::new(1); let (s, r) = channel::(3); s.send(1).unwrap(); s.send(2).unwrap(); s.send(3).unwrap(); mem::drop(s); let r = AsyncReceiver::new(r); executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Ok(2)); assert_eq!(r.recv().await, Ok(3)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); let mut park_count = 0; executor .run(|timeout| { park_count += 1; reactor.poll(timeout) }) .unwrap(); assert_eq!(park_count, 0); } #[test] fn budget_1() { let reactor = Reactor::new(1); let executor = Executor::new(1); { let reactor = reactor.clone(); executor.set_pre_poll(move || { reactor.set_budget(Some(1)); }); } let (s, r) = channel::(3); s.send(1).unwrap(); s.send(2).unwrap(); s.send(3).unwrap(); mem::drop(s); let r = AsyncReceiver::new(r); 
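        // with the per-poll budget capped at 1, the receiving task has to
        // yield back to the reactor between receives; the run loop below is
        // expected to park three times while draining the queue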
executor .spawn(async move { assert_eq!(r.recv().await, Ok(1)); assert_eq!(r.recv().await, Ok(2)); assert_eq!(r.recv().await, Ok(3)); assert_eq!(r.recv().await, Err(mpsc::RecvError)); }) .unwrap(); let mut park_count = 0; executor .run(|timeout| { park_count += 1; reactor.poll(timeout) }) .unwrap(); assert_eq!(park_count, 3); } } pushpin-1.41.0/src/core/config.cpp000066400000000000000000000021431504671364300167640ustar00rootroot00000000000000/* * Copyright (C) 2023-2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "config.h" #include "rust/bindings.h" namespace Config { static thread_local Config *g_config = 0; Config & get() { if(!g_config) { Config *c = new Config; ffi::BuildConfig *bc = ffi::build_config_new(); c->version = QString(bc->version); c->configDir = QString(bc->config_dir); c->libDir = QString(bc->lib_dir); ffi::build_config_destroy(bc); g_config = c; } return *g_config; } } pushpin-1.41.0/src/core/config.h000066400000000000000000000016231504671364300164330ustar00rootroot00000000000000/* * Copyright (C) 2023-2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef CONFIG_H #define CONFIG_H #include namespace Config { class Config { public: QString version; QString configDir; QString libDir; }; // value is thread local Config & get(); } #endif pushpin-1.41.0/src/core/config.rs000066400000000000000000000467641504671364300166470ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use config::{Config, ConfigError}; use serde::Deserialize; use std::env; use std::error::Error; use std::path::{Path, PathBuf}; #[cfg(not(test))] use config::File; #[derive(Debug, Deserialize, Default)] pub struct Global { pub include: String, pub rundir: String, pub libdir: String, pub ipc_prefix: String, pub port_offset: i32, pub stats_connection_ttl: i32, pub stats_connection_send: bool, } impl From for config::ValueKind { fn from(global: Global) -> Self { let mut properties = std::collections::HashMap::new(); properties.insert("include".to_string(), config::Value::from(global.include)); properties.insert("rundir".to_string(), config::Value::from(global.rundir)); properties.insert("libdir".to_string(), config::Value::from(global.libdir)); properties.insert( "ipc_prefix".to_string(), config::Value::from(global.ipc_prefix), ); properties.insert( "port_offset".to_string(), config::Value::from(global.port_offset), ); properties.insert( "stats_connection_ttl".to_string(), config::Value::from(global.stats_connection_ttl), ); properties.insert( "stats_connection_send".to_string(), config::Value::from(global.stats_connection_send), ); Self::Table(properties) } } #[derive(serde::Deserialize, Eq, PartialEq, Debug, Default)] pub struct Runner { //runner rundir is deprecated pub rundir: String, pub services: String, pub http_port: String, pub https_ports: String, pub local_ports: String, pub logdir: String, pub log_level: String, pub client_buffer_size: i32, pub client_maxconn: i32, pub allow_compression: bool, } impl From for config::ValueKind { fn from(runner: Runner) -> Self { let mut properties = std::collections::HashMap::new(); properties.insert("rundir".to_string(), config::Value::from(runner.rundir)); properties.insert("services".to_string(), config::Value::from(runner.services)); properties.insert( "http_port".to_string(), config::Value::from(runner.http_port), ); properties.insert( "https_ports".to_string(), config::Value::from(runner.https_ports), ); properties.insert( "local_ports".to_string(), config::Value::from(runner.local_ports), ); properties.insert("logdir".to_string(), config::Value::from(runner.logdir)); properties.insert( "log_level".to_string(), config::Value::from(runner.log_level), ); properties.insert("client_buffer_size".to_string(), config::Value::from(8192)); properties.insert("client_maxconn".to_string(), config::Value::from(50000)); properties.insert( "allow_compression".to_string(), config::Value::from(runner.allow_compression), ); Self::Table(properties) } } #[derive(Debug, Deserialize, Default)] pub struct Proxy { pub routesfile: String, pub debug: bool, pub auto_cross_origin: bool, pub accept_x_forwarded_protocol: bool, pub set_x_forwarded_protocol: String, pub x_forwarded_for: String, pub x_forwarded_for_trusted: String, pub orig_headers_need_mark: String, pub accept_pushpin_route: bool, pub cdn_loop: String, pub log_from: bool, pub log_user_agent: bool, pub sig_iss: String, pub sig_key: String, pub upstream_key: String, pub sockjs_url: String, pub updates_check: String, pub organization_name: String, } impl From for config::ValueKind { fn from(proxy: Proxy) -> Self { let mut properties = std::collections::HashMap::new(); properties.insert( "routesfile".to_string(), config::Value::from(proxy.routesfile), ); properties.insert("debug".to_string(), config::Value::from(proxy.debug)); properties.insert( "auto_cross_origin".to_string(), config::Value::from(proxy.auto_cross_origin), ); properties.insert( "accept_x_forwarded_protocol".to_string(), 
config::Value::from(proxy.accept_x_forwarded_protocol), ); properties.insert( "set_x_forwarded_protocol".to_string(), config::Value::from(proxy.set_x_forwarded_protocol), ); properties.insert( "x_forwarded_for".to_string(), config::Value::from(proxy.x_forwarded_for), ); properties.insert( "x_forwarded_for_trusted".to_string(), config::Value::from(proxy.x_forwarded_for_trusted), ); properties.insert( "orig_headers_need_mark".to_string(), config::Value::from(proxy.orig_headers_need_mark), ); properties.insert( "accept_pushpin_route".to_string(), config::Value::from(proxy.accept_pushpin_route), ); properties.insert("cdn_loop".to_string(), config::Value::from(proxy.cdn_loop)); properties.insert("log_from".to_string(), config::Value::from(proxy.log_from)); properties.insert( "log_user_agent".to_string(), config::Value::from(proxy.log_user_agent), ); properties.insert("sig_iss".to_string(), config::Value::from(proxy.sig_iss)); properties.insert("sig_key".to_string(), config::Value::from(proxy.sig_key)); properties.insert( "upstream_key".to_string(), config::Value::from(proxy.upstream_key), ); properties.insert( "sockjs_url".to_string(), config::Value::from(proxy.sockjs_url), ); properties.insert( "updates_check".to_string(), config::Value::from(proxy.updates_check), ); properties.insert( "organization_name".to_string(), config::Value::from(proxy.organization_name), ); Self::Table(properties) } } #[derive(Debug, Deserialize, Default)] pub struct Handler { pub ipc_file_mode: u16, pub push_in_spec: String, pub push_in_sub_specs: String, pub push_in_sub_connect: bool, pub push_in_http_addr: String, pub push_in_http_port: u16, pub push_in_http_max_headers_size: u32, pub push_in_http_max_body_size: u32, pub stats_spec: String, pub command_spec: String, pub message_rate: u32, pub message_hwm: u32, pub message_block_size: u32, pub message_wait: u32, pub id_cache_ttl: u32, pub connection_subscription_max: u32, pub subscription_linger: u32, pub stats_subscription_ttl: u32, pub stats_report_interval: u32, pub stats_format: String, pub prometheus_port: String, pub prometheus_prefix: String, } impl From for config::ValueKind { fn from(handler: Handler) -> Self { let mut properties = std::collections::HashMap::new(); properties.insert( "ipc_file_mode".to_string(), config::Value::from(handler.ipc_file_mode), ); properties.insert( "push_in_spec".to_string(), config::Value::from(handler.push_in_spec), ); properties.insert( "push_in_sub_specs".to_string(), config::Value::from(handler.push_in_sub_specs), ); properties.insert( "push_in_sub_connect".to_string(), config::Value::from(handler.push_in_sub_connect), ); properties.insert( "push_in_http_addr".to_string(), config::Value::from(handler.push_in_http_addr), ); properties.insert( "push_in_http_port".to_string(), config::Value::from(handler.push_in_http_port), ); properties.insert( "push_in_http_max_headers_size".to_string(), config::Value::from(handler.push_in_http_max_headers_size), ); properties.insert( "push_in_http_max_body_size".to_string(), config::Value::from(handler.push_in_http_max_body_size), ); properties.insert( "stats_spec".to_string(), config::Value::from(handler.stats_spec), ); properties.insert( "command_spec".to_string(), config::Value::from(handler.command_spec), ); properties.insert( "message_rate".to_string(), config::Value::from(handler.message_rate), ); properties.insert( "message_hwm".to_string(), config::Value::from(handler.message_hwm), ); properties.insert( "message_block_size".to_string(), 
config::Value::from(handler.message_block_size), ); properties.insert( "message_wait".to_string(), config::Value::from(handler.message_wait), ); properties.insert( "id_cache_ttl".to_string(), config::Value::from(handler.id_cache_ttl), ); properties.insert( "connection_subscription_max".to_string(), config::Value::from(handler.connection_subscription_max), ); properties.insert( "subscription_linger".to_string(), config::Value::from(handler.subscription_linger), ); properties.insert( "stats_subscription_ttl".to_string(), config::Value::from(handler.stats_subscription_ttl), ); properties.insert( "stats_report_interval".to_string(), config::Value::from(handler.stats_report_interval), ); properties.insert( "stats_format".to_string(), config::Value::from(handler.stats_format), ); properties.insert( "prometheus_port".to_string(), config::Value::from(handler.prometheus_port), ); properties.insert( "prometheus_prefix".to_string(), config::Value::from(handler.prometheus_prefix), ); Self::Table(properties) } } #[derive(Debug, Deserialize, Default)] pub struct CustomConfig { pub global: Global, pub runner: Runner, pub proxy: Proxy, pub handler: Handler, } impl CustomConfig { #[cfg(not(test))] pub fn new(config_file: &str) -> Result { let config = Config::builder() .add_source(File::with_name(config_file).format(config::FileFormat::Ini)) .set_default("global", Global::default())? .set_default("runner", Runner::default())? .set_default("proxy", Proxy::default())? .set_default("handler", Handler::default())? .build()?; config.try_deserialize() } #[cfg(test)] pub fn new(_config_file: &str) -> Result { let config = Config::builder() .set_default( "global", Global { include: String::from("{libdir}/internal.conf"), rundir: String::from("run"), ipc_prefix: String::from("pushpin-"), port_offset: 0, stats_connection_ttl: 120, stats_connection_send: true, libdir: String::new(), }, )? .set_default( "runner", Runner { rundir: String::new(), services: String::from("connmgr,proxy,handler"), http_port: String::from("7999"), https_ports: String::from("443"), local_ports: String::from("{rundir}/{ipc_prefix}server"), logdir: String::from("log"), log_level: String::from("2"), client_buffer_size: 8192, client_maxconn: 50000, allow_compression: false, }, )? .set_default( "proxy", Proxy { routesfile: String::from("routes"), debug: false, auto_cross_origin: false, accept_x_forwarded_protocol: false, set_x_forwarded_protocol: String::from("proto-only"), x_forwarded_for: String::new(), x_forwarded_for_trusted: String::new(), orig_headers_need_mark: String::new(), accept_pushpin_route: false, cdn_loop: String::new(), log_from: false, log_user_agent: false, sig_iss: String::from("pushpin"), sig_key: String::from("changeme"), upstream_key: String::new(), sockjs_url: String::from("http://cdn.jsdelivr.net/sockjs/0.3.4/sockjs.min.js"), updates_check: String::from("report"), organization_name: String::new(), }, )? 
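            // built-in defaults for the [handler] section (used only by the
            // test-mode constructor)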
.set_default( "handler", Handler { ipc_file_mode: 777, push_in_spec: String::from("tcp://127.0.0.1:5560"), push_in_sub_specs: String::from("tcp://127.0.0.1:5562"), push_in_sub_connect: false, push_in_http_addr: String::from("127.0.0.1"), push_in_http_port: 5561, push_in_http_max_headers_size: 10000, push_in_http_max_body_size: 1000000, stats_spec: String::from("ipc://{rundir}/{ipc_prefix}stats"), command_spec: String::from("tcp://127.0.0.1:5563"), message_rate: 2500, message_hwm: 25000, message_block_size: 0, message_wait: 5000, id_cache_ttl: 60, connection_subscription_max: 20, subscription_linger: 60, stats_subscription_ttl: 60, stats_report_interval: 10, stats_format: String::from("tnetstring"), prometheus_port: String::new(), prometheus_prefix: String::new(), }, )? .build()?; config.try_deserialize() } } pub fn get_config_file( work_dir: &Path, arg_config: Option, ) -> Result> { let mut config_files: Vec = vec![]; match arg_config { Some(x) => config_files.push(x), None => { // ./config config_files.push(work_dir.join("config").join("pushpin.conf")); // same dir as executable (NOTE: deprecated) config_files.push(work_dir.join("pushpin.conf")); // ./examples/config config_files.push( work_dir .join("examples") .join("config") .join("pushpin.conf"), ); // default config_files.push(PathBuf::from(format!( "{}/pushpin.conf", env!("CONFIG_DIR") ))); } } let mut config_file = ""; for cf in config_files.iter() { if cf.is_file() { config_file = cf.to_str().unwrap_or(""); break; } } if config_file.is_empty() { return Err(format!( "no configuration file found. Tried: {}", config_files .iter() .map(|path_buf| path_buf.display().to_string()) .collect::>() .join(" ") ) .into()); } match Path::new(config_file).try_exists() { Ok(true) => {} Ok(false) => { return Err(format!("failed to open {}", config_file).into()); } Err(e) => { return Err(format!("failed to open {}, with error: {:?}", config_file, e).into()); } } Ok(config_file.into()) } mod ffi { use crate::core::version; use std::env; use std::ffi::CString; #[repr(C)] pub struct BuildConfig { version: *mut libc::c_char, config_dir: *mut libc::c_char, lib_dir: *mut libc::c_char, } #[no_mangle] pub extern "C" fn build_config_new() -> *mut BuildConfig { let lib_dir = env!("LIB_DIR"); let config_dir = env!("CONFIG_DIR"); let c = BuildConfig { version: CString::new(version()).unwrap().into_raw(), config_dir: CString::new(config_dir).unwrap().into_raw(), lib_dir: CString::new(lib_dir).unwrap().into_raw(), }; Box::into_raw(Box::new(c)) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn build_config_destroy(c: *mut BuildConfig) { let c = match c.as_mut() { Some(c) => Box::from_raw(c), None => return, }; drop(CString::from_raw(c.version)); drop(CString::from_raw(c.config_dir)); drop(CString::from_raw(c.lib_dir)); } } #[cfg(test)] mod tests { use super::*; use crate::core::{ensure_example_config, test_dir}; use std::error::Error; use std::path::PathBuf; struct TestArgs { name: &'static str, work_dir: PathBuf, input: Option, output: Result>, } #[test] fn it_works() { let test_dir = test_dir(); ensure_example_config(&test_dir); let test_args: Vec = vec![TestArgs { name: "no input", work_dir: test_dir.clone(), input: None, output: Ok(test_dir .join("examples") .join("config") .join("pushpin.conf")), }]; for test_arg in test_args.iter() { assert_eq!( get_config_file(&test_arg.work_dir, test_arg.input.clone()).unwrap(), test_arg.output.as_deref().unwrap(), "{}", test_arg.name ); } } #[test] fn it_fails() { let test_args: Vec = 
vec![TestArgs { name: "invalid config file", work_dir: test_dir(), input: Some(PathBuf::from("no/such/file")), output: Err("no configuration file found. Tried: no/such/file".into()), }]; for test_arg in test_args.iter() { match get_config_file(&test_arg.work_dir, test_arg.input.clone()) { Ok(x) => panic!( "Test case {} should fail, but its passing with this output {:?}", test_arg.name, x ), Err(e) => { assert_eq!( e.to_string(), test_arg.output.as_deref().unwrap_err().to_string() ); } } } } } pushpin-1.41.0/src/core/core.pri000066400000000000000000000046731504671364300164710ustar00rootroot00000000000000HEADERS += \ $$PWD/qzmqcontext.h \ $$PWD/qzmqsocket.h \ $$PWD/qzmqvalve.h \ $$PWD/qzmqreqmessage.h \ $$PWD/qzmqreprouter.h SOURCES += \ $$PWD/qzmqcontext.cpp \ $$PWD/qzmqsocket.cpp \ $$PWD/qzmqvalve.cpp \ $$PWD/qzmqreprouter.cpp HEADERS += $$PWD/processquit.h SOURCES += $$PWD/processquit.cpp HEADERS += \ $$PWD/test.h \ $$PWD/tnetstring.h \ $$PWD/httpheaders.h \ $$PWD/zhttprequestpacket.h \ $$PWD/zhttpresponsepacket.h \ $$PWD/log.h \ $$PWD/bufferlist.h \ $$PWD/layertracker.h SOURCES += \ $$PWD/test.cpp \ $$PWD/tnetstring.cpp \ $$PWD/httpheaders.cpp \ $$PWD/zhttprequestpacket.cpp \ $$PWD/zhttpresponsepacket.cpp \ $$PWD/log.cpp \ $$PWD/bufferlist.cpp \ $$PWD/layertracker.cpp HEADERS += \ $$PWD/packet/httprequestdata.h \ $$PWD/packet/httpresponsedata.h \ $$PWD/packet/retryrequestpacket.h \ $$PWD/packet/wscontrolpacket.h \ $$PWD/packet/statspacket.h \ $$PWD/packet/zrpcrequestpacket.h \ $$PWD/packet/zrpcresponsepacket.h SOURCES += \ $$PWD/packet/retryrequestpacket.cpp \ $$PWD/packet/wscontrolpacket.cpp \ $$PWD/packet/statspacket.cpp \ $$PWD/packet/zrpcrequestpacket.cpp \ $$PWD/packet/zrpcresponsepacket.cpp HEADERS += \ $$PWD/callback.h \ $$PWD/config.h \ $$PWD/timerwheel.h \ $$PWD/jwt.h \ $$PWD/timer.h \ $$PWD/defercall.h \ $$PWD/socketnotifier.h \ $$PWD/event.h \ $$PWD/eventloop.h \ $$PWD/readwrite.h \ $$PWD/tcplistener.h \ $$PWD/tcpstream.h \ $$PWD/unixlistener.h \ $$PWD/unixstream.h \ $$PWD/filewatcher.h \ $$PWD/logutil.h \ $$PWD/uuidutil.h \ $$PWD/zutil.h \ $$PWD/httprequest.h \ $$PWD/websocket.h \ $$PWD/zhttpmanager.h \ $$PWD/zhttprequest.h \ $$PWD/zwebsocket.h \ $$PWD/zrpcmanager.h \ $$PWD/zrpcrequest.h \ $$PWD/statusreasons.h \ $$PWD/inspectdata.h \ $$PWD/cors.h \ $$PWD/simplehttpserver.h \ $$PWD/stats.h \ $$PWD/statsmanager.h \ $$PWD/settings.h SOURCES += \ $$PWD/config.cpp \ $$PWD/timerwheel.cpp \ $$PWD/jwt.cpp \ $$PWD/timer.cpp \ $$PWD/defercall.cpp \ $$PWD/socketnotifier.cpp \ $$PWD/event.cpp \ $$PWD/eventloop.cpp \ $$PWD/tcplistener.cpp \ $$PWD/tcpstream.cpp \ $$PWD/unixlistener.cpp \ $$PWD/unixstream.cpp \ $$PWD/filewatcher.cpp \ $$PWD/logutil.cpp \ $$PWD/uuidutil.cpp \ $$PWD/zutil.cpp \ $$PWD/zhttpmanager.cpp \ $$PWD/zhttprequest.cpp \ $$PWD/zwebsocket.cpp \ $$PWD/zrpcmanager.cpp \ $$PWD/zrpcrequest.cpp \ $$PWD/statusreasons.cpp \ $$PWD/cors.cpp \ $$PWD/simplehttpserver.cpp \ $$PWD/stats.cpp \ $$PWD/statsmanager.cpp \ $$PWD/settings.cpp pushpin-1.41.0/src/core/cors.cpp000066400000000000000000000064231504671364300164720ustar00rootroot00000000000000/* * Copyright (C) 2015 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "cors.h" #include "httpheaders.h" namespace Cors { static bool isSimpleHeader(const QByteArray &in) { return (qstricmp(in.data(), "Cache-Control") == 0 || qstricmp(in.data(), "Content-Language") == 0 || qstricmp(in.data(), "Content-Length") == 0 || qstricmp(in.data(), "Content-Type") == 0 || qstricmp(in.data(), "Expires") == 0 || qstricmp(in.data(), "Last-Modified") == 0 || qstricmp(in.data(), "Pragma") == 0); } static bool headerNamesContains(const QList &names, const QByteArray &name) { foreach(const QByteArray &i, names) { if(qstricmp(name.data(), i.data()) == 0) return true; } return false; } static bool headerNameStartsWith(const QByteArray &name, const char *value) { return (qstrnicmp(name.data(), value, qstrlen(value)) == 0); } void applyCorsHeaders(const HttpHeaders &requestHeaders, HttpHeaders *responseHeaders) { if(!responseHeaders->contains("Access-Control-Allow-Methods")) { QByteArray method = requestHeaders.get("Access-Control-Request-Method"); if(!method.isEmpty()) *responseHeaders += HttpHeader("Access-Control-Allow-Methods", method); else *responseHeaders += HttpHeader("Access-Control-Allow-Methods", "OPTIONS, HEAD, GET, POST, PUT, DELETE"); } if(!responseHeaders->contains("Access-Control-Allow-Headers")) { QList allowHeaders; foreach(const QByteArray &h, requestHeaders.getAll("Access-Control-Request-Headers", true)) { if(!h.isEmpty()) allowHeaders += h; } if(!allowHeaders.isEmpty()) *responseHeaders += HttpHeader("Access-Control-Allow-Headers", HttpHeaders::join(allowHeaders)); } if(!responseHeaders->contains("Access-Control-Expose-Headers")) { QList exposeHeaders; foreach(const HttpHeader &h, *responseHeaders) { if(!isSimpleHeader(h.first) && !headerNameStartsWith(h.first, "Access-Control-") && !headerNameStartsWith(h.first, "Grip-") && !headerNamesContains(exposeHeaders, h.first)) exposeHeaders += h.first; } if(!exposeHeaders.isEmpty()) *responseHeaders += HttpHeader("Access-Control-Expose-Headers", HttpHeaders::join(exposeHeaders)); } if(!responseHeaders->contains("Access-Control-Allow-Credentials")) *responseHeaders += HttpHeader("Access-Control-Allow-Credentials", "true"); if(!responseHeaders->contains("Access-Control-Allow-Origin")) { QByteArray origin = requestHeaders.get("Origin"); if(origin.isEmpty()) origin = "*"; *responseHeaders += HttpHeader("Access-Control-Allow-Origin", origin); } if(!responseHeaders->contains("Access-Control-Max-Age")) *responseHeaders += HttpHeader("Access-Control-Max-Age", "3600"); } } pushpin-1.41.0/src/core/cors.h000066400000000000000000000015451504671364300161370ustar00rootroot00000000000000/* * Copyright (C) 2015 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef CORS_H #define CORS_H class HttpHeaders; namespace Cors { void applyCorsHeaders(const HttpHeaders &requestHeaders, HttpHeaders *responseHeaders); } #endif pushpin-1.41.0/src/core/defer.rs000066400000000000000000000015541504671364300164530ustar00rootroot00000000000000/* * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ pub struct Defer { f: Option, } impl Defer { pub fn new(f: T) -> Self { Self { f: Some(f) } } } impl Drop for Defer { fn drop(&mut self) { let f = self.f.take().unwrap(); f(); } } pushpin-1.41.0/src/core/defercall.cpp000066400000000000000000000151341504671364300174440ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "defercall.h" #include #include #include #include "timer.h" #include "event.h" #include "eventloop.h" namespace { class ThreadWake : public QObject { Q_OBJECT public: ThreadWake() : customRegId_(-1), wakeQueued_(false) { EventLoop *loop = EventLoop::instance(); if(loop) { auto [regId, sr] = loop->registerCustom(ThreadWake::cb_ready, this); assert(regId >= 0); customRegId_ = regId; sr_ = std::move(sr); } } ~ThreadWake() { if(customRegId_ >= 0) EventLoop::instance()->deregister(customRegId_); } // requests the awake signal to be emitted from the object's event loop // at the next opportunity. this is safe to call from another thread. if // this is called multiple times before the signal is emitted, the signal // will only be emitted once. 
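	// (DeferCall::Manager::add() relies on this when a call is queued from
	// a thread other than the one the manager lives on.)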
void wake() { std::lock_guard guard(mutex_); if(wakeQueued_) return; wakeQueued_ = true; if(sr_) sr_->setReadiness(Event::Readable); else QMetaObject::invokeMethod(this, "slotReady", Qt::QueuedConnection); } boost::signals2::signal awake; private slots: void slotReady() { ready(); } private: int customRegId_; std::unique_ptr sr_; std::mutex mutex_; bool wakeQueued_; static void cb_ready(void *ctx, uint8_t readiness) { Q_UNUSED(readiness); ThreadWake *self = (ThreadWake *)ctx; self->ready(); } void ready() { { std::lock_guard guard(mutex_); wakeQueued_ = false; } awake(); } }; } class DeferCall::Manager { public: Manager() : thread_(std::this_thread::get_id()) { timer_.setSingleShot(true); timer_.timeout.connect(boost::bind(&Manager::timer_timeout, this)); threadWake_.awake.connect(boost::bind(&Manager::threadWake_awake, this)); } void add(const std::weak_ptr &c) { std::lock_guard guard(callsMutex_); calls_.push_back(c); if(std::this_thread::get_id() == thread_) { if(!timer_.isActive()) timer_.start(0); } else { threadWake_.wake(); } } void flush() { while(!isCallsEmpty()) process(); } private: std::thread::id thread_; Timer timer_; ThreadWake threadWake_; std::mutex callsMutex_; std::list> calls_; // thread-safe bool isCallsEmpty() { std::lock_guard guard(callsMutex_); return calls_.empty(); } // thread-safe void process() { std::list> ready; // lock to take list { std::lock_guard guard(callsMutex_); // process all calls queued so far, but not any that may get queued // during processing ready.swap(calls_); } // process list while not locked for(auto c : ready) { if(auto p = c.lock()) { auto source = p->source.lock(); // if call is valid then its source will be too assert(source); source->erase(p->sourceElement); p->handler(); } } } void timer_timeout() { process(); // no need to re-arm the timer. 
if new calls were queued during // processing, add() will have taken care of that } void threadWake_awake() { process(); } }; std::list>::size_type DeferCall::CallsList::size() const { std::lock_guard guard(mutex); return l.size(); } std::list>::iterator DeferCall::CallsList::append(const std::shared_ptr &c) { std::lock_guard guard(mutex); l.push_back(c); // get an iterator to the element that was pushed auto it = l.end(); --it; return it; } void DeferCall::CallsList::erase(std::list>::iterator position) { std::lock_guard guard(mutex); l.erase(position); } DeferCall::DeferCall() : thread_(std::this_thread::get_id()), deferredCalls_(std::make_shared()) { if(!localManager) { localManager = std::make_shared(); std::lock_guard guard(managerByThreadMutex); managerByThread[thread_] = localManager; EventLoop *loop = EventLoop::instance(); if(loop) { // we use the manager pointer to uniquely identify the handler // registration even though the handler function doesn't do // anything with it loop->addCleanupHandler(eventloop_cleanup_handler, localManager.get()); } } } DeferCall::~DeferCall() = default; void DeferCall::defer(std::function handler) { std::shared_ptr c = std::make_shared(); c->handler = handler; c->source = deferredCalls_; c->sourceElement = deferredCalls_->append(c); Manager *manager = localManager.get(); if(std::this_thread::get_id() != thread_) { std::lock_guard guard(managerByThreadMutex); auto it = managerByThread.find(thread_); assert(it != managerByThread.end()); manager = it->second.get(); } // manager keeps a weak pointer, so we can invalidate pending calls by // simply deleting them manager->add(c); } DeferCall *DeferCall::global() { if(!localInstance) localInstance = std::make_unique(); return localInstance.get(); } void DeferCall::cleanup() { if(localManager) localManager->flush(); localInstance.reset(); if(localManager) { EventLoop *loop = EventLoop::instance(); if(loop) loop->removeCleanupHandler(eventloop_cleanup_handler, localManager.get()); std::lock_guard guard(managerByThreadMutex); managerByThread.erase(std::this_thread::get_id()); localManager.reset(); } } void DeferCall::eventloop_cleanup_handler(void *) { cleanup(); } thread_local std::shared_ptr DeferCall::localManager = std::shared_ptr(); thread_local std::unique_ptr DeferCall::localInstance = std::unique_ptr(); std::unordered_map> DeferCall::managerByThread = std::unordered_map>(); std::mutex DeferCall::managerByThreadMutex = std::mutex(); #include "defercall.moc" pushpin-1.41.0/src/core/defercall.h000066400000000000000000000047741504671364300171210ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef DEFERCALL_H #define DEFERCALL_H #include #include #include #include #include #include // queues calls to be run after returning to the event loop class DeferCall { public: DeferCall(); ~DeferCall(); // queue handler to be called after returning to the event loop. 
if // handler contains references, they must outlive DeferCall. the // recommended usage is for each object needing to perform deferred calls // to keep a DeferCall as a member variable, and only refer to the // object's own data in the handler. that way, any references are // guaranteed to live long enough. void defer(std::function handler); int pendingCount() const { return deferredCalls_->size(); } static DeferCall *global(); static void cleanup(); template static void deleteLater(T *p) { global()->defer([=] { delete p; }); } private: class Call; class CallsList { public: // all methods thread-safe std::list>::size_type size() const; std::list>::iterator append(const std::shared_ptr &c); void erase(std::list>::iterator position); private: mutable std::mutex mutex; std::list> l; }; class Call { public: std::function handler; std::weak_ptr source; std::list>::iterator sourceElement; }; class Manager; friend class Manager; std::thread::id thread_; std::shared_ptr deferredCalls_; static thread_local std::shared_ptr localManager; static thread_local std::unique_ptr localInstance; static std::unordered_map> managerByThread; static std::mutex managerByThreadMutex; static void eventloop_cleanup_handler(void *ctx); }; #endif pushpin-1.41.0/src/core/defercalltest.cpp000066400000000000000000000056441504671364300203510ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
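// Editorial sketch (not part of the original sources): a short usage example
// for the DeferCall API declared above, following its recommended pattern of
// keeping a DeferCall member and only touching the owning object's own data
// from the handler. Session and example() are hypothetical names.
#include "defercall.h"

class Session
{
public:
	void handleInput()
	{
		// run finish() after returning to the event loop. deferCall_ is a
		// member, so any still-pending handlers are dropped if the Session
		// is destroyed first
		deferCall_.defer([this] { finish(); });
	}

private:
	DeferCall deferCall_;
	int result_ = 0;

	void finish() { result_ = 1; }
};

void example()
{
	Session *s = new Session;
	s->handleInput();

	// schedule deletion via the thread's global DeferCall instance
	DeferCall::deleteLater(s);

	// ... run the event loop here so the deferred calls execute ...

	// before the thread (and its event loop) goes away, flush remaining
	// calls and tear down the thread-local state
	DeferCall::cleanup();
}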
* * $FANOUT_END_LICENSE$ */ #include #include "test.h" #include "defercall.h" #include "eventloop.h" // loop_advance should process enough events to cause the calls to run, // without sleeping, in order to prove the calls are run immediately static std::tuple runDeferCall(std::function loop_advance) { DeferCall deferCall; int count = 0; deferCall.defer([&] { ++count; deferCall.defer([&] { ++count; }); }); loop_advance(); return {deferCall.pendingCount(), count}; } // spawns a thread, triggers the deferCall from it, then waits for thread to // finish static void callNonLocal(DeferCall *deferCall, std::function handler) { std::thread thread([=] { deferCall->defer(handler); }); thread.join(); } // loop_advance should process enough events to cause the calls to run, // without sleeping, in order to prove the calls are run immediately static std::tuple runNonLocal(std::function loop_advance) { DeferCall deferCall; int count = 0; callNonLocal(&deferCall, [&] { ++count; }); loop_advance(); return {deferCall.pendingCount(), count}; } static void deferCall() { EventLoop loop(2); auto [pendingCount, count] = runDeferCall([&] { // run the first call and queue the second loop.step(); // run the second loop.step(); }); TEST_ASSERT_EQ(pendingCount, 0); TEST_ASSERT_EQ(count, 2); } static void nonLocal() { EventLoop loop(2); auto [pendingCount, count] = runNonLocal([&] { // run the first call loop.step(); }); TEST_ASSERT_EQ(pendingCount, 0); TEST_ASSERT_EQ(count, 1); } static void retract() { EventLoop loop(2); bool called = false; { DeferCall deferCall; deferCall.defer([&] { called = true; }); } DeferCall::cleanup(); TEST_ASSERT(!called); } static void managerCleanup() { EventLoop loop(2); int count = 0; DeferCall::global()->defer([&] { ++count; DeferCall::global()->defer([&] { ++count; }); }); // cleanup should process deferred calls queued so far as well as // those queued during processing DeferCall::cleanup(); TEST_ASSERT_EQ(count, 2); } extern "C" int defercall_test(ffi::TestException *out_ex) { TEST_CATCH(deferCall()); TEST_CATCH(nonLocal()); TEST_CATCH(retract()); TEST_CATCH(managerCleanup()); return 0; } pushpin-1.41.0/src/core/event.cpp000066400000000000000000000016111504671364300166370ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "event.h" namespace Event { SetReadiness::SetReadiness(ffi::SetReadiness *inner) : inner_(inner) { } SetReadiness::~SetReadiness() { ffi::set_readiness_destroy(inner_); } int SetReadiness::setReadiness(uint8_t readiness) { return ffi::set_readiness_set_readiness(inner_, readiness); } } pushpin-1.41.0/src/core/event.h000066400000000000000000000021741504671364300163110ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
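// Editorial sketch (not part of the original sources): the Event::SetReadiness
// wrapper in event.cpp above follows a simple RAII pattern for opaque FFI
// handles: own the raw pointer, free it in the destructor, delete copying,
// and forward calls. A generic illustration of the same pattern; the
// ffi_widget_* functions are stand-ins defined here for the example and are
// not part of Pushpin's bindings.
struct ffi_widget { int state; };
static ffi_widget *ffi_widget_create() { return new ffi_widget{0}; }
static void ffi_widget_destroy(ffi_widget *w) { delete w; }
static int ffi_widget_poke(ffi_widget *w, int value) { w->state = value; return 0; }

class Widget
{
public:
	Widget() : inner_(ffi_widget_create()) {}
	~Widget() { ffi_widget_destroy(inner_); }

	// disable copying so the handle is freed exactly once
	Widget(const Widget &) = delete;
	Widget &operator=(const Widget &) = delete;

	int poke(int value) { return ffi_widget_poke(inner_, value); }

private:
	ffi_widget *inner_;
};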
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef EVENT_H #define EVENT_H #include "rust/bindings.h" class EventLoop; namespace Event { enum Interest { Readable = ffi::READABLE, Writable = ffi::WRITABLE, }; class SetReadiness { public: ~SetReadiness(); // disable copying SetReadiness(const SetReadiness &) = delete; SetReadiness & operator=(const SetReadiness &) = delete; // pass a non-zero set of Interest flags int setReadiness(uint8_t readiness); private: friend class ::EventLoop; SetReadiness(ffi::SetReadiness *inner); ffi::SetReadiness *inner_; }; } #endif pushpin-1.41.0/src/core/event.rs000066400000000000000000000555751504671364300165230ustar00rootroot00000000000000/* * Copyright (C) 2021-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::arena; use crate::core::list; use mio::event::Source; use mio::{Events, Interest, Poll, Token, Waker}; use slab::Slab; use std::cell::{Cell, RefCell}; use std::io; use std::rc::Rc; use std::sync::{Arc, Mutex}; use std::time::Duration; const EVENTS_MAX: usize = 1024; const LOCAL_BUDGET: u32 = 10; pub fn can_move_mio_sockets_between_threads() -> bool { // on unix platforms, mio always uses epoll or kqueue, which support // this. 
mio makes no guarantee about supporting this on non-unix // platforms cfg!(unix) } pub type Readiness = Option; pub trait ReadinessExt { fn contains_any(&self, readiness: Interest) -> bool; fn merge(&mut self, readiness: Interest); } impl ReadinessExt for Readiness { fn contains_any(&self, readiness: Interest) -> bool { match *self { Some(cur) => { (readiness.is_readable() && cur.is_readable()) || (readiness.is_writable() && cur.is_writable()) } None => false, } } fn merge(&mut self, readiness: Interest) { match *self { Some(cur) => *self = Some(cur.add(readiness)), None => *self = Some(readiness), } } } struct SourceItem { subtoken: Token, interests: Interest, readiness: Readiness, } struct RegisteredSources { nodes: Slab>, ready: list::List, } struct LocalSources { registered_sources: RefCell, } impl LocalSources { fn new(max_sources: usize) -> Self { Self { registered_sources: RefCell::new(RegisteredSources { nodes: Slab::with_capacity(max_sources), ready: list::List::default(), }), } } fn register(&self, subtoken: Token, interests: Interest) -> Result { let sources = &mut *self.registered_sources.borrow_mut(); if sources.nodes.len() == sources.nodes.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Ok(sources.nodes.insert(list::Node::new(SourceItem { subtoken, interests, readiness: None, }))) } fn deregister(&self, key: usize) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.borrow_mut(); if sources.nodes.contains(key) { sources.ready.remove(&mut sources.nodes, key); sources.nodes.remove(key); } Ok(()) } fn set_readiness(&self, key: usize, readiness: Interest) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.borrow_mut(); if !sources.nodes.contains(key) { return Err(io::Error::from(io::ErrorKind::NotFound)); } let item = &mut sources.nodes[key].value; if !((item.interests.is_readable() && readiness.is_readable()) || (item.interests.is_writable() && readiness.is_writable())) { // not of interest return Ok(()); } let orig = item.readiness; item.readiness.merge(readiness); if item.readiness != orig { sources.ready.remove(&mut sources.nodes, key); sources.ready.push_back(&mut sources.nodes, key); } Ok(()) } fn has_events(&self) -> bool { let sources = &*self.registered_sources.borrow(); !sources.ready.is_empty() } fn next_event(&self) -> Option<(Token, Interest)> { let sources = &mut *self.registered_sources.borrow_mut(); match sources.ready.pop_front(&mut sources.nodes) { Some(key) => { let item = &mut sources.nodes[key].value; let readiness = item.readiness.take().unwrap(); Some((item.subtoken, readiness)) } None => None, } } } struct SyncSources { registered_sources: Mutex, waker: Waker, } impl SyncSources { fn new(max_sources: usize, waker: Waker) -> Self { Self { registered_sources: Mutex::new(RegisteredSources { nodes: Slab::with_capacity(max_sources), ready: list::List::default(), }), waker, } } fn register(&self, subtoken: Token, interests: Interest) -> Result { let sources = &mut *self.registered_sources.lock().unwrap(); if sources.nodes.len() == sources.nodes.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } Ok(sources.nodes.insert(list::Node::new(SourceItem { subtoken, interests, readiness: None, }))) } fn deregister(&self, key: usize) -> Result<(), io::Error> { let sources = &mut *self.registered_sources.lock().unwrap(); if sources.nodes.contains(key) { sources.ready.remove(&mut sources.nodes, key); sources.nodes.remove(key); } Ok(()) } fn set_readiness(&self, key: usize, readiness: Interest) -> 
Result<(), io::Error> { let sources = &mut *self.registered_sources.lock().unwrap(); if !sources.nodes.contains(key) { return Err(io::Error::from(io::ErrorKind::NotFound)); } let item = &mut sources.nodes[key].value; if !((item.interests.is_readable() && readiness.is_readable()) || (item.interests.is_writable() && readiness.is_writable())) { // not of interest return Ok(()); } let orig = item.readiness; item.readiness.merge(readiness); if item.readiness != orig { let need_wake = sources.ready.is_empty(); sources.ready.remove(&mut sources.nodes, key); sources.ready.push_back(&mut sources.nodes, key); if need_wake { self.waker.wake()?; } } Ok(()) } fn has_events(&self) -> bool { let sources = &*self.registered_sources.lock().unwrap(); !sources.ready.is_empty() } fn next_event(&self) -> Option<(Token, Interest)> { let sources = &mut *self.registered_sources.lock().unwrap(); match sources.ready.pop_front(&mut sources.nodes) { Some(key) => { let item = &mut sources.nodes[key].value; let readiness = item.readiness.take().unwrap(); Some((item.subtoken, readiness)) } None => None, } } } struct CustomSources { local: Rc, sync: Arc, next_local_only: Cell, } impl CustomSources { fn new(poll: &Poll, token: Token, max_sources: usize) -> Result { let waker = Waker::new(poll.registry(), token)?; Ok(Self { local: Rc::new(LocalSources::new(max_sources)), sync: Arc::new(SyncSources::new(max_sources, waker)), next_local_only: Cell::new(false), }) } fn set_next_local_only(&self, enabled: bool) { self.next_local_only.set(enabled); } fn register_local( &self, registration: &LocalRegistration, subtoken: Token, interests: Interest, ) -> Result<(), io::Error> { let mut reg = registration.entry.get().data.borrow_mut(); if reg.data.is_none() { let key = self.local.register(subtoken, interests)?; reg.data = Some((key, self.local.clone())); if let Some(readiness) = reg.readiness { self.local.set_readiness(key, readiness).unwrap(); reg.readiness = None; } } Ok(()) } fn deregister_local(&self, registration: &LocalRegistration) -> Result<(), io::Error> { let mut reg = registration.entry.get().data.borrow_mut(); if let Some((key, _)) = reg.data { self.local.deregister(key)?; reg.data = None; } Ok(()) } fn register( &self, registration: &Registration, subtoken: Token, interests: Interest, ) -> Result<(), io::Error> { let mut reg = registration.inner.lock().unwrap(); if reg.data.is_none() { let key = self.sync.register(subtoken, interests)?; reg.data = Some((key, self.sync.clone())); if let Some(readiness) = reg.readiness { self.sync.set_readiness(key, readiness).unwrap(); reg.readiness = None; } } Ok(()) } fn deregister(&self, registration: &Registration) -> Result<(), io::Error> { let mut reg = registration.inner.lock().unwrap(); if let Some((key, _)) = reg.data { self.sync.deregister(key)?; reg.data = None; } Ok(()) } fn has_local_events(&self) -> bool { self.local.has_events() } fn has_events(&self) -> bool { if self.local.has_events() { return true; } if self.next_local_only.get() { return false; } self.sync.has_events() } fn next_event(&self) -> Option<(Token, Interest)> { if let Some(e) = self.local.next_event() { return Some(e); } if self.next_local_only.get() { return None; } if let Some(e) = self.sync.next_event() { return Some(e); } None } } struct RegistrationInner { data: Option<(usize, Arc)>, readiness: Readiness, } pub struct Registration { inner: Arc>, } impl Registration { pub fn new() -> (Self, SetReadiness) { let reg = Arc::new(Mutex::new(RegistrationInner { data: None, readiness: None, })); let 
registration = Self { inner: reg.clone() }; let set_readiness = SetReadiness { inner: reg }; (registration, set_readiness) } } impl Drop for Registration { fn drop(&mut self) { let mut reg = self.inner.lock().unwrap(); if let Some((key, sources)) = ®.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct SetReadiness { inner: Arc>, } impl SetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.inner.lock().unwrap(); match ®.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } struct LocalRegistrationData { data: Option<(usize, Rc)>, readiness: Readiness, } pub struct LocalRegistrationEntry { data: RefCell, } pub struct LocalRegistration { entry: arena::Rc, } impl LocalRegistration { pub fn new(memory: &Rc>) -> (Self, LocalSetReadiness) { let reg = arena::Rc::new( LocalRegistrationEntry { data: RefCell::new(LocalRegistrationData { data: None, readiness: None, }), }, memory, ) .unwrap(); let registration = Self { entry: arena::Rc::clone(®), }; let set_readiness = LocalSetReadiness { entry: reg }; (registration, set_readiness) } } impl Drop for LocalRegistration { fn drop(&mut self) { let mut reg = self.entry.get().data.borrow_mut(); if let Some((key, sources)) = ®.data { sources.deregister(*key).unwrap(); reg.data = None; } } } pub struct LocalSetReadiness { entry: arena::Rc, } impl LocalSetReadiness { pub fn set_readiness(&self, readiness: Interest) -> Result<(), io::Error> { let mut reg = self.entry.get().data.borrow_mut(); match ®.data { Some((key, sources)) => sources.set_readiness(*key, readiness)?, None => reg.readiness.merge(readiness), } Ok(()) } } #[derive(Debug, PartialEq)] pub struct Event { token: Token, readiness: Interest, } impl Event { pub fn token(&self) -> Token { self.token } pub fn readiness(&self) -> Interest { self.readiness } pub fn is_readable(&self) -> bool { self.readiness.is_readable() } pub fn is_writable(&self) -> bool { self.readiness.is_writable() } } pub struct Poller { poll: Poll, events: Events, custom_sources: CustomSources, local_registration_memory: Rc>, local_budget: u32, } impl Poller { pub fn new(max_custom_sources: usize) -> Result { let poll = Poll::new()?; let events = Events::with_capacity(EVENTS_MAX); let custom_sources = CustomSources::new(&poll, Token(0), max_custom_sources)?; Ok(Self { poll, events, custom_sources, local_registration_memory: Rc::new(arena::RcMemory::new(max_custom_sources)), local_budget: LOCAL_BUDGET, }) } pub fn register( &self, source: &mut S, token: Token, interests: Interest, ) -> Result<(), io::Error> where S: Source + ?Sized, { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.poll.registry().register(source, token, interests) } pub fn deregister(&self, source: &mut S) -> Result<(), io::Error> where S: Source + ?Sized, { self.poll.registry().deregister(source) } pub fn register_custom( &self, registration: &Registration, token: Token, interests: Interest, ) -> Result<(), io::Error> { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.custom_sources.register(registration, token, interests) } pub fn deregister_custom(&self, registration: &Registration) -> Result<(), io::Error> { self.custom_sources.deregister(registration) } pub fn local_registration_memory(&self) -> &Rc> { &self.local_registration_memory } pub fn register_custom_local( &self, registration: &LocalRegistration, token: Token, interests: 
Interest, ) -> Result<(), io::Error> { if token == Token(0) { return Err(io::Error::from(io::ErrorKind::InvalidInput)); } self.custom_sources .register_local(registration, token, interests) } pub fn deregister_custom_local( &self, registration: &LocalRegistration, ) -> Result<(), io::Error> { self.custom_sources.deregister_local(registration) } pub fn poll(&mut self, timeout: Option) -> Result<(), io::Error> { if self.custom_sources.has_local_events() && self.local_budget > 0 { self.local_budget -= 1; self.custom_sources.set_next_local_only(true); self.events.clear(); // don't reread previous mio events return Ok(()); } self.local_budget = LOCAL_BUDGET; self.custom_sources.set_next_local_only(false); let timeout = if self.custom_sources.has_events() { Some(Duration::from_millis(0)) } else { timeout }; loop { match self.poll.poll(&mut self.events, timeout) { Err(e) if e.kind() == io::ErrorKind::Interrupted => {} ret => break ret, } } } pub fn iter_events(&self) -> EventsIterator<'_, '_> { EventsIterator { events: self.events.iter(), custom_sources: &self.custom_sources, custom_left: EVENTS_MAX, } } } pub struct EventsIterator<'a, 'b> { events: mio::event::Iter<'b>, custom_sources: &'a CustomSources, custom_left: usize, } impl Iterator for EventsIterator<'_, '_> { type Item = Event; fn next(&mut self) -> Option { for event in self.events.by_ref() { if event.token() == Token(0) { continue; } let mut readiness = None; if event.is_readable() { readiness.merge(Interest::READABLE); } if event.is_writable() { readiness.merge(Interest::WRITABLE); } if let Some(readiness) = readiness { return Some(Event { token: event.token(), readiness, }); } } if self.custom_left > 0 { self.custom_left -= 1; if let Some((token, readiness)) = self.custom_sources.next_event() { return Some(Event { token, readiness }); } } None } } pub mod ffi { pub const READABLE: u8 = 0x01; pub const WRITABLE: u8 = 0x02; pub struct InterestError; pub fn interest_int_to_mio(interest: u8) -> Result { let interest = if interest & READABLE != 0 && interest & WRITABLE != 0 { mio::Interest::READABLE | mio::Interest::WRITABLE } else if interest & READABLE != 0 { mio::Interest::READABLE } else if interest & WRITABLE != 0 { mio::Interest::WRITABLE } else { // must specify at least one of READABLE or WRITABLE return Err(InterestError); }; Ok(interest) } pub struct SetReadiness(pub super::SetReadiness); #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn set_readiness_destroy(sr: *mut SetReadiness) { if !sr.is_null() { drop(Box::from_raw(sr)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn set_readiness_set_readiness( sr: *const SetReadiness, readiness: u8, ) -> libc::c_int { let sr = sr.as_ref().unwrap(); if let Ok(readiness) = interest_int_to_mio(readiness) { if sr.0.set_readiness(readiness).is_err() { return -1; } } 0 } } #[cfg(test)] mod tests { use super::*; use std::time::Duration; #[test] fn test_readiness() { let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = Registration::new(); sources .register(®, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); assert!(events.is_empty()); sr.set_readiness(Interest::READABLE).unwrap(); 'poll: loop { poll.poll(&mut events, None).unwrap(); for event in 
&events { if event.token() == token { break 'poll; } } } assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE))); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_readiness_early() { let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = Registration::new(); sr.set_readiness(Interest::READABLE).unwrap(); sources .register(®, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); let event = events.iter().next(); assert!(event.is_some()); let event = event.unwrap(); assert_eq!(event.token(), token); assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE))); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_readiness_local() { let poller = Poller::new(1).unwrap(); let token = Token(123); let subtoken = Token(456); let mut poll = Poll::new().unwrap(); let sources = CustomSources::new(&poll, token, 1).unwrap(); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); let (reg, sr) = LocalRegistration::new(poller.local_registration_memory()); sources .register_local(®, subtoken, Interest::READABLE) .unwrap(); let mut events = Events::with_capacity(1024); poll.poll(&mut events, Some(Duration::from_millis(0))) .unwrap(); assert!(events.is_empty()); sr.set_readiness(Interest::READABLE).unwrap(); assert_eq!(sources.has_events(), true); assert_eq!(sources.next_event(), Some((subtoken, Interest::READABLE))); assert_eq!(sources.has_events(), false); assert_eq!(sources.next_event(), None); } #[test] fn test_poller() { let token = Token(123); let mut poller = Poller::new(1).unwrap(); assert_eq!(poller.iter_events().next(), None); let (reg, sr) = Registration::new(); poller .register_custom(®, token, Interest::READABLE) .unwrap(); poller.poll(Some(Duration::from_millis(0))).unwrap(); assert_eq!(poller.iter_events().next(), None); sr.set_readiness(Interest::READABLE).unwrap(); poller.poll(None).unwrap(); let mut it = poller.iter_events(); let event = it.next().unwrap(); assert_eq!(event.token(), token); assert_eq!(event.is_readable(), true); assert_eq!(it.next(), None); } } pushpin-1.41.0/src/core/eventloop.cpp000066400000000000000000000052731504671364300175410ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #include "eventloop.h" #include static thread_local EventLoop *g_instance = nullptr; EventLoop::EventLoop(int capacity) { // only one per thread allowed assert(!g_instance); inner_ = ffi::event_loop_create(capacity); g_instance = this; } EventLoop::~EventLoop() { while(!cleanupHandlers_.empty()) { CleanupHandler h = cleanupHandlers_.front(); cleanupHandlers_.pop_front(); h.handler(h.ctx); } ffi::event_loop_destroy(inner_); g_instance = nullptr; } std::optional EventLoop::step() { std::optional code; int x; if(ffi::event_loop_step(inner_, &x) == 0) code = x; return code; } int EventLoop::exec() { return ffi::event_loop_exec(inner_); } void EventLoop::exit(int code) { ffi::event_loop_exit(inner_, code); } int EventLoop::registerFd(int fd, uint8_t interest, void (*cb)(void *, uint8_t), void *ctx) { size_t id; if(ffi::event_loop_register_fd(inner_, fd, interest, cb, ctx, &id) != 0) return -1; return (int)id; } int EventLoop::registerTimer(int timeout, void (*cb)(void *, uint8_t), void *ctx) { size_t id; if(ffi::event_loop_register_timer(inner_, timeout, cb, ctx, &id) != 0) return -1; return (int)id; } std::tuple> EventLoop::registerCustom(void (*cb)(void *, uint8_t), void *ctx) { size_t id; ffi::SetReadiness *srRaw = nullptr; if(ffi::event_loop_register_custom(inner_, cb, ctx, &id, &srRaw) != 0) return std::tuple>(); std::unique_ptr sr(new Event::SetReadiness(srRaw)); return std::tuple>({(int)id, std::move(sr)}); } void EventLoop::deregister(int id) { assert(ffi::event_loop_deregister(inner_, id) == 0); } void EventLoop::addCleanupHandler(void (*handler)(void *), void *ctx) { CleanupHandler h; h.handler = handler; h.ctx = ctx; cleanupHandlers_.push_front(h); } void EventLoop::removeCleanupHandler(void (*handler)(void *), void *ctx) { CleanupHandler h; h.handler = handler; h.ctx = ctx; cleanupHandlers_.remove(h); } EventLoop *EventLoop::instance() { return g_instance; } pushpin-1.41.0/src/core/eventloop.h000066400000000000000000000033041504671364300171770ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ #ifndef EVENTLOOP_H #define EVENTLOOP_H #include #include #include #include "event.h" #include "rust/bindings.h" class EventLoop { public: EventLoop(int capacity); ~EventLoop(); // disable copying EventLoop(const EventLoop &) = delete; EventLoop & operator=(const EventLoop &) = delete; std::optional step(); int exec(); void exit(int code); int registerFd(int fd, uint8_t interest, void (*cb)(void *, uint8_t), void *ctx); int registerTimer(int timeout, void (*cb)(void *, uint8_t), void *ctx); std::tuple> registerCustom(void (*cb)(void *, uint8_t), void *ctx); void deregister(int id); void addCleanupHandler(void (*handler)(void *), void *ctx); void removeCleanupHandler(void (*handler)(void *), void *ctx); static EventLoop *instance(); private: class CleanupHandler { public: void (*handler)(void *); void *ctx; bool operator==(const CleanupHandler &other) const { return (other.handler == handler && other.ctx == ctx); } }; ffi::EventLoopRaw *inner_; std::list cleanupHandlers_; }; #endif pushpin-1.41.0/src/core/eventloop.rs000066400000000000000000000556141504671364300174070ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::event::{self, ReadinessExt}; use crate::core::list; use crate::core::reactor; use crate::core::waker; use slab::Slab; use std::cell::{Cell, RefCell}; use std::future::Future; use std::os::fd::RawFd; use std::pin::Pin; use std::rc::{Rc, Weak}; use std::task::{Context, Poll, Waker}; use std::time::Duration; pub trait Callback { fn call(&mut self, readiness: event::Readiness); } impl Callback for Box { fn call(&mut self, readiness: event::Readiness) { (**self).call(readiness); } } pub struct FnCallback(T); impl Callback for FnCallback { fn call(&mut self, readiness: event::Readiness) { self.0(readiness); } } enum Evented { Fd(reactor::FdEvented), Timer(reactor::TimerEvented), Custom { evented: reactor::CustomEvented, _reg: event::Registration, }, } impl Evented { fn registration(&self) -> &reactor::Registration { match self { Self::Fd(e) => e.registration(), Self::Timer(e) => e.registration(), Self::Custom { evented, .. 
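// Editorial sketch (not part of the original sources): the
// addCleanupHandler/removeCleanupHandler hooks declared above let per-thread
// helpers tear themselves down before the loop is destroyed, which is how
// DeferCall hooks in. A minimal illustration of the same hook; ThreadCache
// is a hypothetical per-thread helper.
#include "eventloop.h"

class ThreadCache
{
public:
	ThreadCache()
	{
		if(EventLoop *loop = EventLoop::instance())
			loop->addCleanupHandler(cleanup_cb, this);
	}

	~ThreadCache()
	{
		// instance() returns null once the loop is gone, so this is safe
		// in either destruction order
		if(EventLoop *loop = EventLoop::instance())
			loop->removeCleanupHandler(cleanup_cb, this);
	}

private:
	static void cleanup_cb(void *ctx)
	{
		// flush or release per-thread state before the loop goes away
		static_cast<ThreadCache *>(ctx)->flush();
	}

	void flush() {}
};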
} => evented.registration(), } } } struct Registration { evented: Evented, activated: bool, callback: Option, } struct RegistrationsData { nodes: Slab>>, activated: list::List, waker: Option, } #[derive(Debug)] struct RegistrationsError; struct Registrations { data: RefCell>, } impl Registrations { fn new(capacity: usize) -> Self { Self { data: RefCell::new(RegistrationsData { nodes: Slab::with_capacity(capacity), activated: list::List::default(), waker: None, }), } } fn add( &self, evented: Evented, interest: mio::Interest, get_waker: W, callback: C, ) -> Result where W: FnOnce(usize) -> Waker, C: Callback, { let data = &mut *self.data.borrow_mut(); if data.nodes.len() == data.nodes.capacity() { return Err(RegistrationsError); } let entry = data.nodes.vacant_entry(); let nkey = entry.key(); evented.registration().set_waker_persistent(true); evented.registration().set_waker(&get_waker(nkey), interest); let reg = Registration { evented, activated: false, callback: Some(callback), }; entry.insert(list::Node::new(reg)); Ok(nkey) } fn remove(&self, reg_id: usize) -> Result<(), RegistrationsError> { let nkey = reg_id; let data = &mut *self.data.borrow_mut(); if !data.nodes.contains(nkey) { return Err(RegistrationsError); } data.activated.remove(&mut data.nodes, nkey); data.nodes.remove(nkey); Ok(()) } fn activate(&self, reg_id: usize) { let nkey = reg_id; let data = &mut *self.data.borrow_mut(); let reg = &mut data.nodes[nkey].value; if reg.activated { return; } reg.activated = true; data.activated.push_back(&mut data.nodes, nkey); if let Some(waker) = data.waker.take() { waker.wake(); } } fn dispatch_activated(&self) { // call the callback of each activated registration, ensuring we // release borrows before each call. this way, callbacks can access // the eventloop, for example to add or remove registrations loop { let (nkey, mut callback, readiness) = { let data = &mut *self.data.borrow_mut(); let nkey = match data.activated.pop_front(&mut data.nodes) { Some(nkey) => nkey, None => break, }; let reg = &mut data.nodes[nkey].value; let callback = reg .callback .take() .expect("registration should have a callback"); let readiness = reg.evented.registration().readiness(); let nkey = if let Evented::Timer(_) = ®.evented { // remove timer registrations after activation data.nodes.remove(nkey); None } else { reg.activated = false; reg.evented .registration() .clear_readiness(mio::Interest::READABLE | mio::Interest::WRITABLE); Some(nkey) }; (nkey, callback, readiness) }; callback.call(readiness); if let Some(nkey) = nkey { let data = &mut *self.data.borrow_mut(); // if the registration still exists, restore its callback if let Some(n) = &mut data.nodes.get_mut(nkey) { let reg = &mut n.value; // only set the callback field on the registration if // it's the same registration we took the callback from // and not a new registration that happened to reuse the // same slot. if the callback field is none, then it's // the same registration. 
if reg.callback.is_none() { reg.callback = Some(callback); } } } } } fn set_waker(&self, waker: &Waker) { let data = &mut *self.data.borrow_mut(); if let Some(current_waker) = &data.waker { if !waker.will_wake(current_waker) { // replace data.waker = Some(waker.clone()); } } else { // set data.waker = Some(waker.clone()); } } } struct Activator { regs: Weak>, reg_id: usize, } impl waker::RcWake for Activator { fn wake(self: Rc) { if let Some(regs) = self.regs.upgrade() { regs.activate(self.reg_id); } } } #[derive(Debug)] pub struct EventLoopError; pub struct EventLoop { reactor: reactor::Reactor, exit_code: Cell>, regs: Rc>, } impl EventLoop { // will create a reactor if one does not exist in the current thread. if // one already exists, registrations_max should be <= the max configured // in the reactor. pub fn new(registrations_max: usize) -> Self { let reactor = if let Some(reactor) = reactor::Reactor::current() { // use existing reactor if available reactor } else { reactor::Reactor::new(registrations_max) }; Self { reactor, exit_code: Cell::new(None), regs: Rc::new(Registrations::new(registrations_max)), } } pub fn step(&self) -> Option { self.poll_and_dispatch(Some(Duration::from_millis(0))) } pub fn exec(&self) -> i32 { loop { if let Some(code) = self.poll_and_dispatch(None) { break code; } } } pub fn exec_async(&self) -> Exec<'_, C> { Exec { l: self } } pub fn exit(&self, code: i32) { self.exit_code.set(Some(code)); } pub fn register_fd( &self, fd: RawFd, interest: mio::Interest, callback: C, ) -> Result { let evented = match reactor::FdEvented::new(fd, interest, &self.reactor) { Ok(evented) => evented, Err(_) => return Err(EventLoopError), }; let regs = Rc::downgrade(&self.regs); let get_waker = |reg_id| { let activator = Rc::new(Activator { regs, reg_id }); waker::into_std(activator) }; Ok(self .regs .add(Evented::Fd(evented), interest, get_waker, callback) .expect("slab should have capacity")) } pub fn register_timer(&self, timeout: Duration, callback: C) -> Result { let expires = self.reactor.now() + timeout; let evented = match reactor::TimerEvented::new(expires, &self.reactor) { Ok(evented) => evented, Err(_) => return Err(EventLoopError), }; let regs = Rc::downgrade(&self.regs); let get_waker = |reg_id| { let activator = Rc::new(Activator { regs, reg_id }); waker::into_std(activator) }; Ok(self .regs .add( Evented::Timer(evented), mio::Interest::READABLE, get_waker, callback, ) .expect("slab should have capacity")) } pub fn register_custom( &self, callback: C, ) -> Result<(usize, event::SetReadiness), EventLoopError> { let (reg, sr) = event::Registration::new(); let evented = match reactor::CustomEvented::new( ®, mio::Interest::READABLE | mio::Interest::WRITABLE, &self.reactor, ) { Ok(evented) => evented, Err(_) => return Err(EventLoopError), }; let regs = Rc::downgrade(&self.regs); let get_waker = |reg_id| { let activator = Rc::new(Activator { regs, reg_id }); waker::into_std(activator) }; let id = self .regs .add( Evented::Custom { evented, _reg: reg }, mio::Interest::READABLE | mio::Interest::WRITABLE, get_waker, callback, ) .expect("slab should have capacity"); Ok((id, sr)) } pub fn deregister(&self, id: usize) -> Result<(), EventLoopError> { self.regs.remove(id).map_err(|_| EventLoopError) } fn poll_and_dispatch(&self, timeout: Option) -> Option { // if exit code set, do a non-blocking poll let timeout = if self.exit_code.get().is_some() { Some(Duration::from_millis(0)) } else { timeout }; self.reactor.poll(timeout).unwrap(); self.regs.dispatch_activated(); 
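// Editorial sketch (not part of the original sources): register_custom above
// backs EventLoop::registerCustom on the C++ side: a registration with no
// file descriptor whose readiness is driven by SetReadiness, which can be
// called from another thread. An illustration that wakes the loop from a
// worker thread; Waker and example() are hypothetical names.
#include <memory>
#include <thread>
#include <tuple>
#include "event.h"
#include "eventloop.h"

struct Waker
{
	EventLoop *loop = nullptr;
	int id = -1;
	std::unique_ptr<Event::SetReadiness> sr;

	static void cb(void *ctx, uint8_t readiness)
	{
		// runs on the loop thread once the worker signals readiness
		Waker *self = static_cast<Waker *>(ctx);
		if(readiness & Event::Readable)
			self->loop->exit(0);
	}
};

void example()
{
	EventLoop loop(8);

	Waker w;
	w.loop = &loop;
	std::tie(w.id, w.sr) = loop.registerCustom(&Waker::cb, &w);

	std::thread worker([&w] {
		// signal from another thread; the loop wakes and runs Waker::cb
		w.sr->setReadiness(Event::Readable);
	});

	loop.exec();
	worker.join();

	loop.deregister(w.id);
}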
self.exit_code.get() } } pub struct Exec<'a, C> { l: &'a EventLoop, } impl Future for Exec<'_, C> { type Output = i32; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let l = self.l; l.regs.dispatch_activated(); if let Some(code) = l.exit_code.get() { return Poll::Ready(code); } l.regs.set_waker(cx.waker()); Poll::Pending } } mod ffi { use super::*; use event::ffi::{interest_int_to_mio, READABLE, WRITABLE}; use std::ops::Deref; pub struct RawCallback { // SAFETY: must be called with the associated ctx value f: unsafe extern "C" fn(*mut libc::c_void, u8), ctx: *mut libc::c_void, } impl RawCallback { // SAFETY: caller must ensure f is safe to call for the lifetime // of the registration pub unsafe fn new( f: unsafe extern "C" fn(*mut libc::c_void, u8), ctx: *mut libc::c_void, ) -> Self { Self { f, ctx } } } impl Callback for RawCallback { fn call(&mut self, readiness: event::Readiness) { let readiness = { let mut r = 0; if readiness.contains_any(mio::Interest::READABLE) { r |= READABLE; } if readiness.contains_any(mio::Interest::WRITABLE) { r |= WRITABLE; } r }; // SAFETY: we are passing the ctx value that was provided unsafe { (self.f)(self.ctx, readiness); } } } pub struct EventLoopRaw(EventLoop); impl Deref for EventLoopRaw { type Target = EventLoop; fn deref(&self) -> &Self::Target { &self.0 } } #[no_mangle] pub extern "C" fn event_loop_create(capacity: libc::c_uint) -> *mut EventLoopRaw { let l = EventLoopRaw(EventLoop::new(capacity as usize)); Box::into_raw(Box::new(l)) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_destroy(l: *mut EventLoopRaw) { if !l.is_null() { drop(Box::from_raw(l)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_step( l: *mut EventLoopRaw, out_code: *mut libc::c_int, ) -> libc::c_int { let l = l.as_mut().unwrap(); match l.step() { Some(code) => { unsafe { out_code.write(code) }; 0 } None => -1, } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_exec(l: *mut EventLoopRaw) -> libc::c_int { let l = l.as_mut().unwrap(); l.exec() as libc::c_int } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_exit(l: *mut EventLoopRaw, code: libc::c_int) { let l = l.as_mut().unwrap(); l.exit(code); } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_register_fd( l: *mut EventLoopRaw, fd: std::os::raw::c_int, interest: u8, cb: unsafe extern "C" fn(*mut libc::c_void, u8), ctx: *mut libc::c_void, out_id: *mut libc::size_t, ) -> libc::c_int { let l = l.as_mut().unwrap(); let Ok(interest) = interest_int_to_mio(interest) else { return -1; }; // SAFETY: we assume caller guarantees that the callback is safe to // call for the lifetime of the registration let cb = unsafe { RawCallback::new(cb, ctx) }; let id = match l.register_fd(fd, interest, cb) { Ok(id) => id, Err(_) => return -1, }; out_id.write(id); 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_register_timer( l: *mut EventLoopRaw, timeout: u64, cb: unsafe extern "C" fn(*mut libc::c_void, u8), ctx: *mut libc::c_void, out_id: *mut libc::size_t, ) -> libc::c_int { let l = l.as_mut().unwrap(); // SAFETY: we assume caller guarantees that the callback is safe to // call for the lifetime of the registration let cb = unsafe { RawCallback::new(cb, ctx) }; let id = match l.register_timer(Duration::from_millis(timeout), cb) { Ok(id) => id, Err(_) => return -1, }; out_id.write(id); 0 } 
#[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_register_custom( l: *mut EventLoopRaw, cb: unsafe extern "C" fn(*mut libc::c_void, u8), ctx: *mut libc::c_void, out_id: *mut libc::size_t, out_set_readiness: *mut *mut event::ffi::SetReadiness, ) -> libc::c_int { let l = l.as_mut().unwrap(); // SAFETY: we assume caller guarantees that the callback is safe to // call for the lifetime of the registration let cb = unsafe { RawCallback::new(cb, ctx) }; let (id, sr) = match l.register_custom(cb) { Ok(id) => id, Err(_) => return -1, }; out_id.write(id); out_set_readiness.write(Box::into_raw(Box::new(event::ffi::SetReadiness(sr)))); 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn event_loop_deregister( l: *mut EventLoopRaw, id: libc::size_t, ) -> libc::c_int { let l = l.as_mut().unwrap(); if l.deregister(id).is_err() { return -1; } 0 } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use crate::core::reactor::Reactor; use std::cell::Cell; use std::io; use std::os::fd::AsRawFd; use std::rc::Rc; use std::thread; struct NoopCallback; impl Callback for NoopCallback { fn call(&mut self, _readiness: event::Readiness) {} } #[test] fn exec() { { let l = EventLoop::::new(1); assert_eq!(l.step(), None); l.exit(123); assert_eq!(l.step(), Some(123)); } { let l = EventLoop::::new(1); l.exit(124); assert_eq!(l.exec(), 124); } } #[test] fn fd() { let l = Rc::new(EventLoop::>::new(1)); let listener = Rc::new(std::net::TcpListener::bind("127.0.0.1:0").unwrap()); listener.set_nonblocking(true).unwrap(); let addr = listener.local_addr().unwrap(); let fd = listener.as_raw_fd(); let count = Rc::new(Cell::new(0)); let cb = { let l = Rc::clone(&l); let listener = Rc::clone(&listener); let count = Rc::clone(&count); Box::new(FnCallback(move |readiness: event::Readiness| { assert!(readiness.contains_any(mio::Interest::READABLE)); let _stream = listener.accept().unwrap(); let e = listener.accept().unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::WouldBlock); count.set(count.get() + 1); if count.get() == 2 { l.exit(0); } })) }; let id = l.register_fd(fd, mio::Interest::READABLE, cb).unwrap(); { // non-blocking connect attempt to trigger listener let _stream = mio::net::TcpStream::connect(addr); while count.get() < 1 { l.step(); thread::sleep(Duration::from_millis(10)); } } { // non-blocking connect attempt to trigger listener let _stream = mio::net::TcpStream::connect(addr); while count.get() < 2 { l.step(); thread::sleep(Duration::from_millis(10)); } } assert_eq!(l.exec(), 0); l.deregister(id).unwrap(); } #[test] fn timer() { let l = Rc::new(EventLoop::>::new(1)); let cb = { let l = Rc::clone(&l); Box::new(FnCallback(move |readiness: event::Readiness| { assert!(readiness.contains_any(mio::Interest::READABLE)); l.exit(0); })) }; let id = l.register_timer(Duration::from_millis(0), cb).unwrap(); // no space assert!(l .register_timer(Duration::from_millis(0), Box::new(NoopCallback)) .is_err()); assert_eq!(l.exec(), 0); // activated timers automatically deregister l.deregister(id).unwrap_err(); let id = l .register_timer(Duration::from_millis(0), Box::new(NoopCallback)) .unwrap(); l.deregister(id).unwrap(); } #[test] fn custom() { let l = Rc::new(EventLoop::>::new(1)); let cb = { let l = Rc::clone(&l); Box::new(FnCallback(move |readiness: event::Readiness| { assert!(readiness.contains_any(mio::Interest::READABLE)); l.exit(0); })) }; let (id, sr) = l.register_custom(cb).unwrap(); sr.set_readiness(mio::Interest::READABLE).unwrap(); 
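// Editorial sketch (not part of the original sources): the SAFETY notes above
// put the burden on the C++ caller to keep the callback and its ctx valid
// until the registration is removed. One way to hold that invariant is a
// small RAII guard that deregisters in its destructor. RegistrationGuard is
// a hypothetical helper, not part of the existing API, and is meant for fd
// and custom registrations; activated timers deregister themselves (see the
// tests above), so guarding a fired timer id would trip the assert in
// EventLoop::deregister.
#include "eventloop.h"

class RegistrationGuard
{
public:
	RegistrationGuard(EventLoop *loop, int id) :
		loop_(loop),
		id_(id)
	{
	}

	~RegistrationGuard()
	{
		if(id_ >= 0)
			loop_->deregister(id_);
	}

	// non-copyable so the registration is removed exactly once
	RegistrationGuard(const RegistrationGuard &) = delete;
	RegistrationGuard &operator=(const RegistrationGuard &) = delete;

private:
	EventLoop *loop_;
	int id_;
};

// usage: declare the guard as the last member of the object that owns the
// callback ctx, so it is destroyed first and the registration is removed
// before the rest of the object is torn down.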
assert_eq!(l.exec(), 0); l.deregister(id).unwrap(); } #[test] fn deregister_within_callback() { let l = Rc::new(EventLoop::>::new(1)); let listener = Rc::new(std::net::TcpListener::bind("127.0.0.1:0").unwrap()); listener.set_nonblocking(true).unwrap(); let addr = listener.local_addr().unwrap(); let fd = listener.as_raw_fd(); let id = Rc::new(Cell::new(None)); let cb = { let l = Rc::clone(&l); let listener = Rc::clone(&listener); let id = Rc::clone(&id); Box::new(FnCallback(move |readiness: event::Readiness| { assert!(readiness.contains_any(mio::Interest::READABLE)); let _stream = listener.accept().unwrap(); let e = listener.accept().unwrap_err(); assert_eq!(e.kind(), io::ErrorKind::WouldBlock); // this is allowed l.deregister(id.get().unwrap()).unwrap(); l.exit(0); })) }; id.set(Some( l.register_fd(fd, mio::Interest::READABLE, cb).unwrap(), )); // non-blocking connect attempt to trigger listener let _stream = mio::net::TcpStream::connect(addr); assert_eq!(l.exec(), 0); } #[test] fn exec_async() { let reactor = Reactor::new(1); let executor = Executor::new(1); executor .spawn(async { let l = Rc::new(EventLoop::>::new(1)); let listener = Rc::new(std::net::TcpListener::bind("127.0.0.1:0").unwrap()); listener.set_nonblocking(true).unwrap(); let addr = listener.local_addr().unwrap(); let fd = listener.as_raw_fd(); let cb = { let l = Rc::clone(&l); let listener = Rc::clone(&listener); Box::new(FnCallback(move |readiness: event::Readiness| { assert!(readiness.contains_any(mio::Interest::READABLE)); let _stream = listener.accept().unwrap(); l.exit(0); })) }; let id = l.register_fd(fd, mio::Interest::READABLE, cb).unwrap(); // non-blocking connect attempt to trigger listener let _stream = mio::net::TcpStream::connect(addr); assert_eq!(l.exec_async().await, 0); l.deregister(id).unwrap(); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } } pushpin-1.41.0/src/core/eventlooptest.cpp000066400000000000000000000051351504671364300204360ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include #include #include "test.h" #include "defercall.h" #include "eventloop.h" #include "socketnotifier.h" #include "timer.h" static void socketNotifier() { EventLoop loop(1); int fds[2]; TEST_ASSERT_EQ(pipe(fds), 0); SocketNotifier *sn = new SocketNotifier(fds[0], SocketNotifier::Read); sn->clearReadiness(SocketNotifier::Read); int activatedFd = -1; uint8_t activatedReadiness = -1; sn->activated.connect([&](int fd, uint8_t readiness) { activatedFd = fd; activatedReadiness = readiness; loop.exit(123); }); unsigned char c = 1; TEST_ASSERT_EQ(write(fds[1], &c, 1), 1); TEST_ASSERT_EQ(loop.exec(), 123); TEST_ASSERT_EQ(activatedFd, fds[0]); TEST_ASSERT_EQ(activatedReadiness, SocketNotifier::Read); delete sn; close(fds[1]); close(fds[0]); } static void timer() { EventLoop loop(2); Timer *t1 = new Timer; Timer *t2 = new Timer; int timeoutCount = 0; t1->timeout.connect([&] { ++timeoutCount; }); t2->timeout.connect([&] { ++timeoutCount; loop.exit(123); }); t1->setSingleShot(true); t1->start(0); t2->setSingleShot(true); t2->start(0); TEST_ASSERT_EQ(loop.exec(), 123); TEST_ASSERT_EQ(timeoutCount, 2); delete t2; delete t1; } static void custom() { class State { public: EventLoop loop; uint8_t activatedReadiness; State() : loop(EventLoop(1)), activatedReadiness(-1) { } }; State state; auto [id, sr] = state.loop.registerCustom([](void *ctx, uint8_t readiness) { State *state = (State *)ctx; state->activatedReadiness = readiness; state->loop.exit(123); }, (void *)&state); TEST_ASSERT(id >= 0); TEST_ASSERT_EQ(sr->setReadiness(Event::Readable), 0); TEST_ASSERT_EQ(state.loop.exec(), 123); TEST_ASSERT_EQ(state.activatedReadiness, Event::Readable); state.loop.deregister(id); } extern "C" int eventloop_test(ffi::TestException *out_ex) { TEST_CATCH(socketNotifier()); TEST_CATCH(timer()); TEST_CATCH(custom()); return 0; } pushpin-1.41.0/src/core/executor.rs000066400000000000000000000471541504671364300172320ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::list; use crate::core::waker; use log::debug; use slab::Slab; use std::cell::RefCell; use std::future::Future; use std::io; use std::mem; use std::pin::Pin; use std::rc::{Rc, Weak}; use std::task::{Context, Waker}; use std::time::Duration; thread_local! 
{ static EXECUTOR: RefCell>> = const { RefCell::new(None) }; } type BoxFuture = Pin>>; struct TaskWaker { tasks: Weak, task_id: usize, } impl waker::RcWake for TaskWaker { fn wake(self: Rc) { if let Some(tasks) = self.tasks.upgrade() { tasks.wake(self.task_id, false); } } } struct TaskResumeWaker { tasks: Weak, task_id: usize, } impl waker::RcWake for TaskResumeWaker { fn wake(self: Rc) { if let Some(tasks) = self.tasks.upgrade() { tasks.wake(self.task_id, true); } } } fn poll_fut(fut: &mut BoxFuture, waker: Waker) -> bool { // convert from Pin to Pin<&mut> let fut: Pin<&mut dyn Future> = fut.as_mut(); let mut cx = Context::from_waker(&waker); fut.poll(&mut cx).is_ready() } struct Task { fut: Option>>>, wakeable: bool, low: bool, } struct TasksData { nodes: Slab>, next: list::List, next_low: list::List, wakers: Vec>, current_task: Option, } struct Tasks { data: RefCell, pre_poll: RefCell>>, } impl Tasks { fn new(max: usize) -> Rc { let data = TasksData { nodes: Slab::with_capacity(max), next: list::List::default(), next_low: list::List::default(), wakers: Vec::with_capacity(max), current_task: None, }; let tasks = Rc::new(Self { data: RefCell::new(data), pre_poll: RefCell::new(None), }); { let data = &mut *tasks.data.borrow_mut(); for task_id in 0..data.nodes.capacity() { data.wakers.push(Rc::new(TaskWaker { tasks: Rc::downgrade(&tasks), task_id, })); } } tasks } fn is_empty(&self) -> bool { self.data.borrow().nodes.is_empty() } fn have_next(&self) -> bool { !self.data.borrow().next.is_empty() || !self.data.borrow().next_low.is_empty() } fn add(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { let data = &mut *self.data.borrow_mut(); if data.nodes.len() == data.nodes.capacity() { return Err(()); } let entry = data.nodes.vacant_entry(); let nkey = entry.key(); let task = Task { fut: Some(Box::pin(fut)), wakeable: false, low: false, }; entry.insert(list::Node::new(task)); data.next.push_back(&mut data.nodes, nkey); Ok(()) } fn remove(&self, task_id: usize) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let task = &mut data.nodes[nkey].value; // drop the future. 
this should cause it to drop any owned wakers task.fut = None; // at this point, we should be the only remaining owner assert_eq!(Rc::strong_count(&data.wakers[nkey]), 1); if task.low { data.next_low.remove(&mut data.nodes, nkey); } else { data.next.remove(&mut data.nodes, nkey); } data.nodes.remove(nkey); } fn current_task(&self) -> Option { self.data.borrow().current_task } fn set_current_task(&self, task_id: Option) { self.data.borrow_mut().current_task = task_id; } fn take_next_list(&self, low: bool) -> list::List { let data = &mut *self.data.borrow_mut(); let mut l = list::List::default(); if low { l.concat(&mut data.nodes, &mut data.next_low); } else { l.concat(&mut data.nodes, &mut data.next); } l } fn take_task(&self, l: &mut list::List) -> Option<(usize, BoxFuture, Waker)> { let nkey = l.head?; let data = &mut *self.data.borrow_mut(); l.remove(&mut data.nodes, nkey); let task = &mut data.nodes[nkey].value; // both of these are cheap let fut = task.fut.take().unwrap(); let waker = waker::into_std(data.wakers[nkey].clone()); task.wakeable = true; Some((nkey, fut, waker)) } fn process_next(&self, low: bool) { let mut l = self.take_next_list(low); while let Some((task_id, mut fut, waker)) = self.take_task(&mut l) { self.set_current_task(Some(task_id)); self.pre_poll(); let done = poll_fut(&mut fut, waker); self.set_current_task(None); // take_task() took the future out of the task, so we // could poll it without having to maintain a borrow of // the tasks set. we'll put it back now self.set_fut(task_id, fut); if done { self.remove(task_id); } } } fn set_fut(&self, task_id: usize, fut: BoxFuture) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let task = &mut data.nodes[nkey].value; task.fut = Some(fut); } fn wake(&self, task_id: usize, resume: bool) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); let task = &mut data.nodes[nkey].value; if !task.wakeable && !resume { return; } task.wakeable = false; if data.current_task == Some(task_id) || resume { // if a task triggers its own waker, queue with low priority in // order to achieve a yielding effect. 
do the same when waking // with resume mode, to achieve a yielding effect even when the // wake occurs during events processing task.low = true; data.next_low.push_back(&mut data.nodes, nkey); } else { task.low = false; data.next.push_back(&mut data.nodes, nkey); } } fn ignore_wakes(&self, task_id: usize) { let nkey = task_id; let data = &mut *self.data.borrow_mut(); // tasks other than the current task may be in a temporary list // during task processing, in which case removal to prevent wakes is // not possible assert_eq!( data.current_task, Some(task_id), "ignore_wakes can only be self-applied" ); let task = &mut data.nodes[nkey].value; task.wakeable = false; if task.low { data.next_low.remove(&mut data.nodes, nkey); } else { data.next.remove(&mut data.nodes, nkey); } } fn set_pre_poll(&self, pre_poll_fn: F) where F: FnMut() + 'static, { *self.pre_poll.borrow_mut() = Some(Box::new(pre_poll_fn)); } fn pre_poll(&self) { let pre_poll = &mut *self.pre_poll.borrow_mut(); if let Some(f) = pre_poll { f(); } } } #[derive(Debug)] pub struct CurrentTaskError; pub struct Executor { tasks: Rc, } impl Executor { pub fn new(tasks_max: usize) -> Self { let tasks = Tasks::new(tasks_max); EXECUTOR.with(|ex| { if ex.borrow().is_some() { panic!("thread already has an Executor"); } ex.replace(Some(Rc::downgrade(&tasks))); }); Self { tasks } } #[allow(clippy::result_unit_err)] pub fn spawn(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { debug!("spawning future with size {}", mem::size_of::()); self.tasks.add(fut) } pub fn set_pre_poll(&self, pre_poll_fn: F) where F: FnMut() + 'static, { self.tasks.set_pre_poll(pre_poll_fn); } pub fn have_tasks(&self) -> bool { !self.tasks.is_empty() } pub fn run_until_stalled(&self) { while self.tasks.have_next() { self.tasks.process_next(false); self.tasks.process_next(true); } } pub fn run(&self, mut park: F) -> Result<(), io::Error> where F: FnMut(Option) -> Result<(), io::Error>, { loop { // run normal priority only self.tasks.process_next(false); if !self.have_tasks() { break; } let timeout = if self.tasks.have_next() { // some tasks trigger their own waker and return Pending in // order to achieve a yielding effect. in that case they will // already be queued up for processing again. 
use a timeout // of 0 when parking so we can quickly resume them let timeout = Duration::from_millis(0); Some(timeout) } else { None }; park(timeout)?; // run normal priority again, in case the park triggered wakers self.tasks.process_next(false); // finally, run low priority (mainly yielding tasks) self.tasks.process_next(true); } Ok(()) } pub fn current() -> Option { EXECUTOR.with(|ex| { (*ex.borrow_mut()).as_mut().map(|tasks| Self { tasks: tasks.upgrade().unwrap(), }) }) } pub fn spawner(&self) -> Spawner { Spawner { tasks: Rc::downgrade(&self.tasks), } } pub fn create_resume_waker_for_current_task(&self) -> Result { match self.tasks.current_task() { Some(task_id) => { let waker = Rc::new(TaskResumeWaker { tasks: Rc::downgrade(&self.tasks), task_id, }); Ok(waker::into_std(waker)) } None => Err(CurrentTaskError), } } pub fn ignore_wakes_for_current_task(&self) -> Result<(), CurrentTaskError> { match self.tasks.current_task() { Some(task_id) => { self.tasks.ignore_wakes(task_id); Ok(()) } None => Err(CurrentTaskError), } } } impl Drop for Executor { fn drop(&mut self) { EXECUTOR.with(|ex| { if Rc::strong_count(&self.tasks) == 1 { ex.replace(None); } }); } } pub struct Spawner { tasks: Weak, } impl Spawner { #[allow(clippy::result_unit_err)] pub fn spawn(&self, fut: F) -> Result<(), ()> where F: Future + 'static, { let tasks = match self.tasks.upgrade() { Some(tasks) => tasks, None => return Err(()), }; let ex = Executor { tasks }; ex.spawn(fut) } } #[cfg(test)] mod tests { use super::*; use std::cell::Cell; use std::mem; use std::task::Poll; struct TestFutureData { ready: bool, waker: Option, } struct TestFuture { data: Rc>, } impl TestFuture { fn new() -> Self { let data = TestFutureData { ready: false, waker: None, }; Self { data: Rc::new(RefCell::new(data)), } } fn handle(&self) -> TestHandle { TestHandle { data: Rc::clone(&self.data), } } } impl Future for TestFuture { type Output = (); fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let mut data = self.data.borrow_mut(); match data.ready { true => Poll::Ready(()), false => { data.waker = Some(cx.waker().clone()); Poll::Pending } } } } struct TestHandle { data: Rc>, } impl TestHandle { fn set_ready(&self) { let data = &mut *self.data.borrow_mut(); data.ready = true; if let Some(waker) = data.waker.take() { waker.wake(); } } } struct EarlyWakeFuture { done: bool, } impl EarlyWakeFuture { fn new() -> Self { Self { done: false } } } impl Future for EarlyWakeFuture { type Output = (); fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { if !self.done { self.done = true; cx.waker().wake_by_ref(); return Poll::Pending; } Poll::Ready(()) } } #[test] fn test_executor_step() { let executor = Executor::new(1); let fut1 = TestFuture::new(); let fut2 = TestFuture::new(); let handle1 = fut1.handle(); let handle2 = fut2.handle(); let started = Rc::new(Cell::new(false)); let fut1_done = Rc::new(Cell::new(false)); let finishing = Rc::new(Cell::new(false)); { let started = Rc::clone(&started); let fut1_done = Rc::clone(&fut1_done); let finishing = Rc::clone(&finishing); executor .spawn(async move { started.set(true); fut1.await; fut1_done.set(true); fut2.await; finishing.set(true); }) .unwrap(); } // not started yet, no progress assert_eq!(executor.have_tasks(), true); assert_eq!(started.get(), false); executor.run_until_stalled(); // started, but fut1 not ready assert_eq!(executor.have_tasks(), true); assert_eq!(started.get(), true); assert_eq!(fut1_done.get(), false); handle1.set_ready(); executor.run_until_stalled(); // fut1 
finished assert_eq!(executor.have_tasks(), true); assert_eq!(fut1_done.get(), true); assert_eq!(finishing.get(), false); handle2.set_ready(); executor.run_until_stalled(); // fut2 finished, and thus the task finished assert_eq!(finishing.get(), true); assert_eq!(executor.have_tasks(), false); } #[test] fn test_executor_run() { let executor = Executor::new(1); let fut = TestFuture::new(); let handle = fut.handle(); executor .spawn(async move { fut.await; }) .unwrap(); executor .run(|_| { handle.set_ready(); Ok(()) }) .unwrap(); assert_eq!(executor.have_tasks(), false); } #[test] fn test_executor_spawn_error() { let executor = Executor::new(1); assert!(executor.spawn(async {}).is_ok()); assert!(executor.spawn(async {}).is_err()); } #[test] fn test_executor_current() { assert!(Executor::current().is_none()); let executor = Executor::new(2); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); executor .spawn(async move { Executor::current() .unwrap() .spawn(async move { flag.set(true); }) .unwrap(); }) .unwrap(); } assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); let current = Executor::current().unwrap(); assert_eq!(executor.have_tasks(), false); assert!(current.spawn(async {}).is_ok()); assert_eq!(executor.have_tasks(), true); mem::drop(executor); assert!(Executor::current().is_some()); mem::drop(current); assert!(Executor::current().is_none()); } #[test] fn test_executor_spawner() { let executor = Executor::new(2); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); let spawner = executor.spawner(); executor .spawn(async move { spawner .spawn(async move { flag.set(true); }) .unwrap(); }) .unwrap(); } assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); } #[test] fn test_executor_early_wake() { let executor = Executor::new(1); let fut = EarlyWakeFuture::new(); executor .spawn(async move { fut.await; }) .unwrap(); let mut park_count = 0; executor .run(|_| { park_count += 1; Ok(()) }) .unwrap(); assert_eq!(park_count, 1); } #[test] fn test_executor_pre_poll() { let executor = Executor::new(1); let flag = Rc::new(Cell::new(false)); { let flag = flag.clone(); executor.set_pre_poll(move || { flag.set(true); }); } executor.spawn(async {}).unwrap(); assert_eq!(flag.get(), false); executor.run(|_| Ok(())).unwrap(); assert_eq!(flag.get(), true); } #[test] fn test_executor_ignore_resume_wakes() { let executor = Executor::new(1); // can't create a resume waker or ignore wakes outside of task assert!(executor.create_resume_waker_for_current_task().is_err()); assert!(executor.ignore_wakes_for_current_task().is_err()); let resume_waker: Rc>> = Rc::new(Cell::new(None)); let fut = TestFuture::new(); let handle = fut.handle(); { let resume_waker = Rc::clone(&resume_waker); executor .spawn(async move { let executor = Executor::current().unwrap(); resume_waker.set(Some( executor.create_resume_waker_for_current_task().unwrap(), )); executor.ignore_wakes_for_current_task().unwrap(); fut.await; }) .unwrap(); } executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); handle.set_ready(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), true); resume_waker.take().take().unwrap().wake(); executor.run_until_stalled(); assert_eq!(executor.have_tasks(), false); } } pushpin-1.41.0/src/core/filewatcher.cpp000066400000000000000000000032141504671364300200140ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. 
* * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #include "filewatcher.h" #include #include FileWatcher::FileWatcher() : inner_(nullptr) { } FileWatcher::~FileWatcher() { sn_.reset(); if(inner_) ffi::file_watcher_destroy(inner_); } bool FileWatcher::start(const QString &filePath) { assert(!inner_); inner_ = ffi::file_watcher_create(filePath.toUtf8().data()); if(!inner_) return false; int fd = ffi::file_watcher_as_raw_fd(inner_); sn_ = std::make_unique(fd, SocketNotifier::Read); sn_->activated.connect(boost::bind(&FileWatcher::sn_activated, this, boost::placeholders::_1, boost::placeholders::_2)); sn_->clearReadiness(SocketNotifier::Read); sn_->setReadEnabled(true); // in case the socket was activated before registering the notifier if(ffi::file_watcher_file_changed(inner_)) deferCall_.defer([=] { fileChanged(); }); return true; } void FileWatcher::sn_activated(int socket, uint8_t readiness) { Q_UNUSED(socket); Q_UNUSED(readiness); sn_->clearReadiness(SocketNotifier::Read); if(ffi::file_watcher_file_changed(inner_)) fileChanged(); } pushpin-1.41.0/src/core/filewatcher.h000066400000000000000000000020671504671364300174660ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef FILEWATCHER_H #define FILEWATCHER_H #include #include "socketnotifier.h" #include "defercall.h" #include "rust/bindings.h" class QString; class FileWatcher { public: FileWatcher(); ~FileWatcher(); bool start(const QString &filePath); boost::signals2::signal fileChanged; private: ffi::FileWatcher *inner_; std::unique_ptr sn_; DeferCall deferCall_; void sn_activated(int socket, uint8_t readiness); }; #endif pushpin-1.41.0/src/core/fs.rs000066400000000000000000000305651504671364300160020ustar00rootroot00000000000000/* * Copyright (C) 2023-2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use log::warn; use notify::Watcher; use std::ffi::CString; use std::io; use std::mem; use std::os::fd::{AsRawFd, RawFd}; use std::os::unix::ffi::OsStrExt; use std::path::{Path, PathBuf}; use std::ptr; use std::sync::{Arc, Mutex}; use thiserror::Error; fn try_with_increasing_buffer(starting_size: usize, f: T) -> Result where T: Fn(&mut [u8]) -> Result, { let mut buf = vec![0; starting_size]; loop { match f(&mut buf) { Ok(v) => return Ok(v), Err(e) if e.raw_os_error() == Some(libc::ERANGE) => buf.resize(buf.len() * 2, 0), Err(e) => return Err(e), } } } fn get_user_uid(name: &str) -> Result { let name = CString::new(name).unwrap(); try_with_increasing_buffer(1024, |buf| unsafe { let mut pwd = mem::MaybeUninit::uninit(); let mut passwd = ptr::null_mut(); if libc::getpwnam_r( name.as_ptr(), pwd.as_mut_ptr(), buf.as_mut_ptr() as *mut libc::c_char, buf.len(), &mut passwd, ) != 0 { return Err(io::Error::last_os_error()); } let passwd = match passwd.as_ref() { Some(r) => r, None => return Err(io::Error::from(io::ErrorKind::NotFound)), }; Ok(passwd.pw_uid) }) } fn get_group_gid(name: &str) -> Result { let name = CString::new(name).unwrap(); try_with_increasing_buffer(1024, |buf| unsafe { let mut grp = mem::MaybeUninit::uninit(); let mut group = ptr::null_mut(); if libc::getgrnam_r( name.as_ptr(), grp.as_mut_ptr(), buf.as_mut_ptr() as *mut libc::c_char, buf.len(), &mut group, ) != 0 { return Err(io::Error::last_os_error()); } let group = match group.as_ref() { Some(r) => r, None => return Err(io::Error::from(io::ErrorKind::NotFound)), }; Ok(group.gr_gid) }) } pub fn set_user(path: &Path, user: &str) -> Result<(), io::Error> { let uid = get_user_uid(user)?; unsafe { let path = CString::new(path.as_os_str().as_bytes()).unwrap(); if libc::chown(path.as_ptr(), uid, u32::MAX) != 0 { return Err(io::Error::last_os_error()); } } Ok(()) } pub fn set_group(path: &Path, group: &str) -> Result<(), io::Error> { let gid = get_group_gid(group)?; unsafe { let path = CString::new(path.as_os_str().as_bytes()).unwrap(); if libc::chown(path.as_ptr(), u32::MAX, gid) != 0 { return Err(io::Error::last_os_error()); } } Ok(()) } #[cfg(target_os = "macos")] fn get_errno() -> libc::c_int { // SAFETY: always safe to call unsafe { *libc::__error() } } #[cfg(not(target_os = "macos"))] fn get_errno() -> libc::c_int { // SAFETY: always safe to call unsafe { *libc::__errno_location() } } fn set_fd_nonblocking(fd: RawFd) -> Result<(), io::Error> { // SAFETY: always safe to call let flags = unsafe { libc::fcntl(fd, libc::F_GETFL, 0) }; if flags < 0 { return Err(io::Error::last_os_error()); } // SAFETY: always safe to call let ret = unsafe { libc::fcntl(fd, libc::F_SETFL, flags | libc::O_NONBLOCK) }; if ret != 0 { return Err(io::Error::last_os_error()); } Ok(()) } fn parent_dir(path: &Path) -> Result<&Path, io::Error> { match path.parent() { Some(p) => { if p.as_os_str().is_empty() { Ok(Path::new(".")) // empty parent, assume current dir } else { Ok(p) } } None => Err(io::Error::from(io::ErrorKind::InvalidInput)), } } struct FileWatcherState { watcher: notify::RecommendedWatcher, changed: bool, } struct FileWatcherData { file: PathBuf, read_fd: RawFd, write_fd: RawFd, state: Mutex>, } pub struct FileWatcher { data: Arc, } #[derive(Debug, Error)] pub enum FileWatcherError { #[error("no parent")] NoParent, #[error(transparent)] Notify(#[from] notify::Error), } impl FileWatcher { // `file_path` must be a path to a file, not a directory, and the file's // parent directory must exist pub fn new>(file_path: P) -> Result { let file = 
file_path.as_ref(); let Ok(dir) = parent_dir(file) else { return Err(FileWatcherError::NoParent); }; let mut fds = [0; 2]; // SAFETY: fds pointer is valid let ret = unsafe { libc::pipe(fds.as_mut_ptr()) }; assert_eq!(ret, 0); for fd in &fds { // should never fail on a descriptor we own set_fd_nonblocking(*fd).unwrap(); } let data = Arc::new(FileWatcherData { file: file.to_owned(), read_fd: fds[0], write_fd: fds[1], state: Mutex::new(None), }); let watcher = { let data = Arc::clone(&data); notify::recommended_watcher(move |event: Result| { let event = match event { Ok(event) => event, Err(e) => { warn!("file watcher error: {:?}", e); return; } }; if !event.paths.into_iter().any(|p| p == data.file) { // skip unrelated events return; } match event.kind { notify::EventKind::Create(_) | notify::EventKind::Modify(_) | notify::EventKind::Remove(_) => {} _ => return, // skip non-change events } let mut state = data .state .lock() .expect("failed to lock during notify event"); if let Some(state) = &mut *state { if !state.changed { state.changed = true; // non-blocking write to wake up the other side let buf: [u8; 1] = [0; 1]; // SAFETY: buf pointer and size are valid let ret = unsafe { libc::write(data.write_fd, buf.as_ptr() as *const libc::c_void, 1) }; assert!(ret == 1 || get_errno() == libc::EAGAIN); } } }) .expect("failed to create file watcher") }; { let mut state = data .state .lock() .expect("failed to lock during initialization"); let state = state.insert(FileWatcherState { watcher, changed: false, }); // watch the dir instead of the file, so we can detect file creates state .watcher .watch(dir, notify::RecursiveMode::NonRecursive)?; } Ok(Self { data }) } pub fn file_changed(&self) -> bool { let mut changed = false; let mut state = self .data .state .lock() .expect("failed to lock during check for changes"); if let Some(state) = &mut *state { // non-blocking read to clear let mut buf = [0u8; 128]; // SAFETY: buf pointer and size are valid let ret = unsafe { libc::read( self.data.read_fd, buf.as_mut_ptr() as *mut libc::c_void, buf.len(), ) }; assert!(ret >= 0 || get_errno() == libc::EAGAIN); changed = state.changed; state.changed = false; } changed } } impl Drop for FileWatcher { fn drop(&mut self) { let mut state = self.data.state.lock().expect("failed to lock during drop"); *state = None; for fd in [self.data.write_fd, self.data.read_fd] { // SAFETY: always safe to call unsafe { libc::close(fd) }; } } } impl AsRawFd for FileWatcher { // for monitoring for changes. the returned file descriptor can be // registered in a poller for readability events. no I/O should be // performed on the returned file descriptor. after a readability event // is received, call file_changed() to check for a change. 
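// a minimal usage sketch (the `poller` API, the file path, and the
// `reload_config` callback are hypothetical, not part of these
// sources); the tests below show the same pattern using this crate's
// mio-based Poller:
//
//     let watcher = FileWatcher::new("some.conf")?;
//     poller.register_readable(watcher.as_raw_fd())?;
//     loop {
//         poller.wait()?; // readable: the watcher wrote to its pipe
//         if watcher.file_changed() {
//             reload_config();
//         }
//     }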
fn as_raw_fd(&self) -> RawFd { self.data.read_fd } } mod ffi { use super::*; use std::ffi::{CStr, OsStr}; use std::os::raw::{c_char, c_int}; #[no_mangle] pub extern "C" fn file_watcher_create(path: *const c_char) -> *mut FileWatcher { let path = unsafe { CStr::from_ptr(path) }; let path = Path::new(OsStr::from_bytes(path.to_bytes())); let w = match FileWatcher::new(path) { Ok(w) => w, Err(_) => return ptr::null_mut(), }; Box::into_raw(Box::new(w)) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn file_watcher_destroy(w: *mut FileWatcher) { if !w.is_null() { drop(Box::from_raw(w)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn file_watcher_file_changed(w: *const FileWatcher) -> c_int { let w = w.as_ref().unwrap(); match w.file_changed() { true => 1, false => 0, } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn file_watcher_as_raw_fd(w: *const FileWatcher) -> c_int { let w = w.as_ref().unwrap(); w.as_raw_fd() } } #[cfg(test)] mod tests { use super::*; use crate::core::event::Poller; use crate::core::test_dir; use std::fs; use std::time::Duration; fn wait_readable(poller: &mut Poller, token: mio::Token) { poller.poll(None).unwrap(); let event = poller.iter_events().next().unwrap(); assert_eq!(event.token(), token); assert_eq!(event.is_readable(), true); // wait for potentially multiple events to get processed for the file // operation, so that file_changed() returns true only once std::thread::sleep(Duration::from_millis(50)); } #[test] fn watcher() { let file = test_dir().join("watch-file"); match fs::remove_file(&file) { Ok(()) => {} Err(e) if e.kind() == io::ErrorKind::NotFound => {} _ => panic!("failed to remove {}", file.display()), } let mut poller = Poller::new(1).unwrap(); let token = mio::Token(1); let watcher = FileWatcher::new(&file).unwrap(); poller .register( &mut mio::unix::SourceFd(&watcher.as_raw_fd()), token, mio::Interest::READABLE, ) .unwrap(); // no change yet poller.poll(Some(Duration::from_millis(0))).unwrap(); assert_eq!(poller.iter_events().next(), None); assert!(!watcher.file_changed()); // detect create fs::write(&file, "hello").unwrap(); wait_readable(&mut poller, token); assert!(watcher.file_changed()); assert!(!watcher.file_changed()); // detect modify fs::write(&file, "world").unwrap(); wait_readable(&mut poller, token); assert!(watcher.file_changed()); assert!(!watcher.file_changed()); // detect remove fs::remove_file(&file).unwrap(); wait_readable(&mut poller, token); assert!(watcher.file_changed()); assert!(!watcher.file_changed()); } } pushpin-1.41.0/src/core/http1/000077500000000000000000000000001504671364300160535ustar00rootroot00000000000000pushpin-1.41.0/src/core/http1/client.rs000066400000000000000000000502211504671364300176770ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::core::buffer::{Buffer, VecRingBuffer, VECTORED_MAX}; use crate::core::http1::error::Error; use crate::core::http1::protocol::{self, BodySize, Header, ParseScratch, ParseStatus}; use crate::core::http1::util::*; use crate::core::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, StdWriteWrapper, WriteHalf}; use crate::core::select::{select_2, Select2}; use std::cell::RefCell; use std::io::{self, Write}; use std::mem; use std::pin::pin; use std::pin::Pin; use std::str; pub struct Request<'a, R: AsyncRead, W: AsyncWrite> { r: ReadHalf<'a, R>, w: WriteHalf<'a, W>, hbuf: &'a mut VecRingBuffer, bbuf: &'a mut VecRingBuffer, } impl<'a, R: AsyncRead, W: AsyncWrite> Request<'a, R, W> { pub fn new( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut VecRingBuffer, buf2: &'a mut VecRingBuffer, ) -> Self { Self { r: stream.0, w: stream.1, hbuf: buf1, bbuf: buf2, } } #[allow(clippy::too_many_arguments)] pub fn prepare_header( self, method: &str, uri: &str, headers: &[Header<'_>], body_size: BodySize, websocket: bool, initial_body: &[u8], end: bool, ) -> Result, Error> { let req = protocol::ClientRequest::new(); let size_limit = self.hbuf.capacity(); let req_body = match req.send_header(self.hbuf, method, uri, headers, body_size, websocket) { Ok(ret) => ret, Err(_) => return Err(Error::RequestTooLarge(size_limit)), }; if self.bbuf.write_all(initial_body).is_err() { return Err(Error::BufferExceeded); } Ok(RequestHeader { r: self.r, w: self.w, hbuf: self.hbuf, bbuf: self.bbuf, req_body, end, }) } } pub struct RequestHeader<'a, R: AsyncRead, W: AsyncWrite> { r: ReadHalf<'a, R>, w: WriteHalf<'a, W>, hbuf: &'a mut VecRingBuffer, bbuf: &'a mut VecRingBuffer, req_body: protocol::ClientRequestBody, end: bool, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestHeader<'a, R, W> { pub async fn send(mut self) -> Result, Error> { while self.hbuf.len() > 0 { let size = self.w.write(Buffer::read_buf(self.hbuf)).await?; self.hbuf.read_commit(size); } let block_size = self.bbuf.capacity(); Ok(RequestBody { inner: RefCell::new(Some(RequestBodyInner { r: RefCell::new(RequestBodyRead { stream: self.r, buf: self.hbuf, }), w: RefCell::new(RequestBodyWrite { stream: self.w, buf: self.bbuf, req_body: Some(self.req_body), end: self.end, block_size, }), })), }) } } struct RequestBodyRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut VecRingBuffer, } struct RequestBodyWrite<'a, W: AsyncWrite> { stream: WriteHalf<'a, W>, buf: &'a mut VecRingBuffer, req_body: Option, end: bool, block_size: usize, } struct RequestBodyInner<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell>, w: RefCell>, } pub struct RequestBody<'a, R: AsyncRead, W: AsyncWrite> { inner: RefCell>>, } impl<'a, R: AsyncRead, W: AsyncWrite> RequestBody<'a, R, W> { pub fn prepare(&self, src: &[u8], end: bool) -> Result { if let Some(inner) = &*self.inner.borrow() { let w = &mut *inner.w.borrow_mut(); // call not allowed if the end has already been indicated if w.end { return Err(Error::FurtherInputNotAllowed); } let size = match w.buf.write(src) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, Err(e) => panic!("infallible buffer write failed: {}", e), }; assert!(size <= src.len()); if size == src.len() && end { w.end = true; } Ok(size) } else { Err(Error::Unusable) } } pub fn expand_write_buffer(&self, blocks_max: usize, reserve: F) -> Result where F: FnMut() -> bool, { if let Some(inner) = &*self.inner.borrow() { let w = &mut *inner.w.borrow_mut(); Ok(resize_write_buffer_if_full( w.buf, w.block_size, blocks_max, reserve, 
)) } else { Err(Error::Unusable) } } pub fn can_send(&self) -> bool { if let Some(inner) = &*self.inner.borrow() { let w = &*inner.w.borrow(); w.buf.len() > 0 || w.end } else { false } } pub async fn send(&self) -> SendStatus, (), Error> { if self.inner.borrow().is_none() { return SendStatus::Error((), Error::Unusable); } let status = loop { if let Some(inner) = self.take_inner_if_early_response() { let r = inner.r.into_inner(); let w = inner.w.into_inner(); let resp = w.req_body.unwrap().into_early_response(); w.buf.clear(); return SendStatus::EarlyResponse(Response { r: r.stream, rbuf: r.buf, wbuf: w.buf, inner: resp, }); } match self.process().await { Some(Ok(status)) => break status, Some(Err(e)) => return SendStatus::Error((), e), None => {} // received data. loop and check for early response } }; let mut inner = self.inner.borrow_mut(); assert!(inner.is_some()); match status { protocol::SendStatus::Complete(resp, size) => { let inner = inner.take().unwrap(); let r = inner.r.into_inner(); let w = inner.w.into_inner(); w.buf.read_commit(size); assert_eq!(w.buf.len(), 0); SendStatus::Complete(Response { r: r.stream, rbuf: r.buf, wbuf: w.buf, inner: resp, }) } protocol::SendStatus::Partial(req_body, size) => { let inner = inner.as_ref().unwrap(); let mut w = inner.w.borrow_mut(); w.req_body = Some(req_body); w.buf.read_commit(size); SendStatus::Partial((), size) } protocol::SendStatus::Error(req_body, e) => { let inner = inner.as_ref().unwrap(); inner.w.borrow_mut().req_body = Some(req_body); SendStatus::Error((), e.into()) } } } #[allow(clippy::await_holding_refcell_ref)] pub async fn fill_recv_buffer(&self) -> Error { if let Some(inner) = &*self.inner.borrow() { let r = &mut *inner.r.borrow_mut(); loop { if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever std::future::pending::<()>().await; } return e.into(); } } } else { Error::Unusable } } // assumes self.inner is Some #[allow(clippy::await_holding_refcell_ref)] async fn process( &self, ) -> Option< Result< protocol::SendStatus< protocol::ClientResponse, protocol::ClientRequestBody, protocol::Error, >, Error, >, > { let inner = self.inner.borrow(); let inner = inner.as_ref().unwrap(); let mut r = inner.r.borrow_mut(); let result = select_2( AsyncOperation::new( |cx| { let w = &mut *inner.w.borrow_mut(); if !w.stream.is_writable() { return None; } let req_body = w.req_body.take().unwrap(); // req_body.send() expects the input to leave room for at // least two more buffers in case chunked encoding is // used (for chunked header and footer) let mut buf_arr = [&b""[..]; VECTORED_MAX - 2]; let bufs = w.buf.read_bufs(&mut buf_arr); match req_body.send( &mut StdWriteWrapper::new(Pin::new(&mut w.stream), cx), bufs, w.end, None, ) { protocol::SendStatus::Error(req_body, protocol::Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => { w.req_body = Some(req_body); None } ret => Some(ret), } }, || inner.w.borrow_mut().stream.cancel(), ), pin!(async { let r = &mut *r; if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever std::future::pending::<()>().await; } return Err(Error::from(e)); } Ok(()) }), ) .await; match result { Select2::R1(ret) => match ret { protocol::SendStatus::Error(req_body, protocol::Error::Io(e)) if e.kind() == io::ErrorKind::BrokenPipe => { // if we get an error when trying to send, it could be // due to the server closing the connection 
after sending // an early response. here we'll check if the server left // us any data to read let w = &mut *inner.w.borrow_mut(); w.req_body = Some(req_body); if r.buf.len() == 0 { let r = &mut *r; match recv_nonzero(&mut r.stream, r.buf).await { Ok(()) => None, // received data Err(e) => Some(Err(e.into())), // error while receiving data } } else { None // we already received data } } ret => Some(Ok(ret)), }, Select2::R2(ret) => match ret { Ok(()) => None, // received data Err(e) => Some(Err(e)), // error while receiving data }, } } // assumes self.inner is Some fn take_inner_if_early_response(&self) -> Option> { let mut inner = self.inner.borrow_mut(); let inner_mut = inner.as_mut().unwrap(); if inner_mut.r.borrow().buf.len() > 0 { Some(inner.take().unwrap()) } else { None } } } pub struct Response<'a, R: AsyncRead> { r: ReadHalf<'a, R>, rbuf: &'a mut VecRingBuffer, wbuf: &'a mut VecRingBuffer, inner: protocol::ClientResponse, } impl<'a, R: AsyncRead> Response<'a, R> { pub async fn recv_header<'b, const N: usize>( mut self, mut scratch: &'b mut ParseScratch, ) -> Result< ( protocol::OwnedResponse<'b, N>, ResponseBodyKeepHeader<'a, R>, ), Error, > { let mut resp = self.inner; let (resp, resp_body) = loop { { let buf = self.rbuf.take_inner(); resp = match resp.recv_header(buf, scratch) { ParseStatus::Complete(ret) => break ret, ParseStatus::Incomplete(resp, buf, ret_scratch) => { // NOTE: after polonius it may not be necessary for // scratch to be returned scratch = ret_scratch; self.rbuf.set_inner(buf); resp } ParseStatus::Error(e, buf, _) => { self.rbuf.set_inner(buf); return Err(e.into()); } } } // take_inner aligns assert!(self.rbuf.is_readable_contiguous()); if let Err(e) = recv_nonzero(&mut self.r, self.rbuf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } }; // at this point, resp has taken rbuf's inner buffer, such that // rbuf has no inner buffer // put remaining readable bytes in wbuf self.wbuf.write_all(resp.remaining_bytes())?; // swap inner buffers, such that rbuf now contains the remaining // readable bytes, and wbuf is now the one with no inner buffer self.rbuf.swap_inner(self.wbuf); Ok(( resp, ResponseBodyKeepHeader { inner: ResponseBody { inner: RefCell::new(Some(ResponseBodyInner { r: self.r, closed: false, rbuf: self.rbuf, resp_body, })), }, wbuf: RefCell::new(Some(self.wbuf)), }, )) } } struct ResponseBodyInner<'a, R: AsyncRead> { r: ReadHalf<'a, R>, closed: bool, rbuf: &'a mut VecRingBuffer, resp_body: protocol::ClientResponseBody, } pub struct ResponseBody<'a, R: AsyncRead> { inner: RefCell>>, } impl ResponseBody<'_, R> { // on EOF and any subsequent calls, return success #[allow(clippy::await_holding_refcell_ref)] pub async fn add_to_buffer(&self) -> Result<(), Error> { if let Some(inner) = &mut *self.inner.borrow_mut() { if !inner.closed { match recv_nonzero(&mut inner.r, inner.rbuf).await { Ok(()) => {} Err(e) if e.kind() == io::ErrorKind::WriteZero => { return Err(Error::BufferExceeded) } Err(e) if e.kind() == io::ErrorKind::UnexpectedEof => inner.closed = true, Err(e) => return Err(e.into()), } } Ok(()) } else { Err(Error::Unusable) } } pub fn try_recv(&self, dest: &mut [u8]) -> Result, Error> { loop { let mut b_inner = self.inner.borrow_mut(); if let Some(inner) = b_inner.take() { let mut scratch = mem::MaybeUninit::<[httparse::Header; HEADERS_MAX]>::uninit(); let src = Buffer::read_buf(inner.rbuf); let end = src.len() == inner.rbuf.len() && inner.closed; match inner.resp_body.recv(src, dest, end, 
&mut scratch)? { protocol::RecvStatus::NeedBytes(resp_body) => { *b_inner = Some(ResponseBodyInner { r: inner.r, closed: inner.closed, rbuf: inner.rbuf, resp_body, }); let inner = b_inner.as_mut().unwrap(); if !inner.rbuf.is_readable_contiguous() { inner.rbuf.align(); continue; } return Ok(RecvStatus::NeedBytes(())); } protocol::RecvStatus::Complete(finished, read, written) => { inner.rbuf.read_commit(read); *b_inner = None; return Ok(RecvStatus::Complete(Finished { inner: finished }, written)); } protocol::RecvStatus::Read(resp_body, read, written) => { *b_inner = Some(ResponseBodyInner { r: inner.r, closed: inner.closed, rbuf: inner.rbuf, resp_body, }); let inner = b_inner.as_mut().unwrap(); inner.rbuf.read_commit(read); if read > 0 && written == 0 { // input consumed but no output produced, retry continue; } // written is only zero here if read is also zero assert!(written > 0 || read == 0); return Ok(RecvStatus::Read((), written)); } } } else { return Err(Error::Unusable); } } } } pub struct ResponseBodyKeepHeader<'a, R: AsyncRead> { inner: ResponseBody<'a, R>, wbuf: RefCell>, } impl<'a, R: AsyncRead> ResponseBodyKeepHeader<'a, R> { pub fn discard_header( self, resp: protocol::OwnedResponse, ) -> Result, Error> { if let Some(wbuf) = self.wbuf.borrow_mut().take() { wbuf.set_inner(resp.into_buf()); wbuf.clear(); Ok(self.inner) } else { Err(Error::Unusable) } } pub async fn add_to_buffer(&self) -> Result<(), Error> { self.inner.add_to_buffer().await } pub fn try_recv( &self, dest: &mut [u8], ) -> Result>, Error> { if !self.wbuf.borrow().is_some() { return Err(Error::Unusable); } match self.inner.try_recv(dest)? { RecvStatus::Complete(finished, written) => Ok(RecvStatus::Complete( FinishedKeepHeader { inner: finished, wbuf: self.wbuf.borrow_mut().take().unwrap(), }, written, )), RecvStatus::Read((), written) => Ok(RecvStatus::Read((), written)), RecvStatus::NeedBytes(()) => Ok(RecvStatus::NeedBytes(())), } } } pub struct Finished { inner: protocol::ClientFinished, } impl Finished { pub fn is_persistent(&self) -> bool { self.inner.persistent } } pub struct FinishedKeepHeader<'a> { inner: Finished, wbuf: &'a mut VecRingBuffer, } impl FinishedKeepHeader<'_> { pub fn discard_header(self, resp: protocol::OwnedResponse) -> Finished { self.wbuf.set_inner(resp.into_buf()); self.wbuf.clear(); self.inner } pub fn is_persistent(&self) -> bool { self.inner.is_persistent() } } pushpin-1.41.0/src/core/http1/error.rs000066400000000000000000000021101504671364300175440ustar00rootroot00000000000000/* * Copyright (C) 2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/

use crate::core::http1::protocol;
use std::io;

#[derive(Debug)]
pub enum Error {
    Io(io::Error),
    Protocol(protocol::Error),
    RequestTooLarge(usize),
    ResponseTooLarge(usize),
    ResponseDuringContinue,
    FurtherInputNotAllowed,
    BufferExceeded,
    Unusable,
}

impl From<io::Error> for Error {
    fn from(e: io::Error) -> Self {
        Self::Io(e)
    }
}

impl From<protocol::Error> for Error {
    fn from(e: protocol::Error) -> Self {
        Self::Protocol(e)
    }
}
pushpin-1.41.0/src/core/http1/mod.rs000066400000000000000000000015421504671364300172020ustar00rootroot00000000000000/*
 * Copyright (C) 2024 Fastly, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

mod error;
mod protocol;
mod util;

pub mod client;
pub mod server;

pub use error::*;
pub use protocol::{
    parse_header_value, BodySize, Header, HeaderParamsIterator, ParseScratch, Request, Response,
    EMPTY_HEADER,
};
pub use util::{RecvStatus, SendStatus};
pushpin-1.41.0/src/core/http1/protocol.rs000066400000000000000000005321421504671364300202710ustar00rootroot00000000000000/*
 * Copyright (C) 2020-2023 Fanout, Inc.
 * Copyright (C) 2024 Fastly, Inc.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
*/ #![allow(clippy::collapsible_if)] #![allow(clippy::collapsible_else_if)] use crate::core::buffer::{write_vectored_offset, FilledBuf, LimitBufs, VECTORED_MAX}; use arrayvec::ArrayVec; use std::cmp; use std::convert::TryFrom; use std::io; use std::io::{Read, Write}; use std::mem; use std::str; const CHUNK_SIZE_MAX: usize = 0xffff; const CHUNK_HEADER_SIZE_MAX: usize = 6; // ffff\r\n const CHUNK_FOOTER: &[u8] = b"\r\n"; fn parse_as_int(src: &[u8]) -> Result { let int_str = str::from_utf8(src); let int_str = match int_str { Ok(int_str) => int_str, Err(_) => { return Err(io::Error::from(io::ErrorKind::InvalidData)); } }; let x = int_str.parse(); let x = match x { Ok(x) => x, Err(_) => { return Err(io::Error::from(io::ErrorKind::InvalidData)); } }; Ok(x) } fn header_contains_param(value: &[u8], param: &[u8], ignore_case: bool) -> bool { let param_str = str::from_utf8(param); let param_str = match param_str { Ok(param_str) => param_str, Err(_) => { return false; } }; for part in value.split(|b| *b == b',') { let part_str = str::from_utf8(part); let part_str = match part_str { Ok(part_str) => part_str, Err(_) => { continue; } }; let part_str = part_str.trim(); if ignore_case { if part_str.eq_ignore_ascii_case(param_str) { return true; } } else { if part_str == param_str { return true; } } } false } fn find_one_of(s: &str, values: &[u8]) -> Option<(usize, u8)> { for (pos, &c) in s.as_bytes().iter().enumerate() { for v in values { if c == *v { return Some((pos, c)); } } } None } fn find_non_space(s: &str) -> Option { for (pos, c) in s.char_indices() { if !c.is_ascii_whitespace() { return Some(pos); } } None } // return (value, remainder) fn parse_quoted(s: &str) -> Result<(&str, &str), io::Error> { match s.find('"') { Some(pos) => Ok((&s[..pos], &s[(pos + 1)..])), None => Err(io::Error::from(io::ErrorKind::InvalidData)), } } // return (value, remainder). 
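// for illustration (following the rules described below), e.g.:
//
//     parse_param_value(" chunked ; q=1")   -> Ok(("chunked", "; q=1"))
//     parse_param_value(" \"a, b\" , next") -> Ok(("a, b", ", next"))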
// remainder will start at the first non-space character following the param, // or will be empty fn parse_param_value(s: &str) -> Result<(&str, &str), io::Error> { let s = match find_non_space(s) { Some(pos) => &s[pos..], None => return Ok(("", "")), }; if s.as_bytes()[0] == b'"' { let (s, remainder) = parse_quoted(&s[1..])?; let remainder = match find_non_space(remainder) { Some(pos) => &remainder[pos..], None => "", }; Ok((s, remainder)) } else { let (s, remainder) = match find_one_of(s, b";,") { Some((pos, _)) => (&s[..pos], &s[pos..]), None => (s, ""), }; Ok((s.trim(), remainder)) } } pub struct HeaderParamsIterator<'a> { s: &'a str, done: bool, } impl<'a> HeaderParamsIterator<'a> { fn new(s: &'a str) -> Self { Self { s, done: false } } fn empty() -> Self { Self { s: "", done: true } } } impl<'a> Iterator for HeaderParamsIterator<'a> { type Item = Result<(&'a str, &'a str), io::Error>; fn next(&mut self) -> Option { if self.done { return None; } let (k, v, remainder, done) = match find_one_of(self.s, b"=;,") { Some((pos, b'=')) => { let k = &self.s[..pos]; let (v, remainder) = match parse_param_value(&self.s[(pos + 1)..]) { Ok(ret) => ret, Err(e) => return Some(Err(e)), }; let (remainder, done) = if !remainder.is_empty() { match remainder.as_bytes()[0] { b';' => (&remainder[1..], false), b',' => (remainder, true), _ => return Some(Err(io::Error::from(io::ErrorKind::InvalidData))), } } else { ("", true) }; (k, v, remainder, done) } Some((pos, b';')) => (&self.s[..pos], "", &self.s[(pos + 1)..], false), Some((pos, b',')) => (&self.s[..pos], "", &self.s[pos..], true), Some(_) => unreachable!(), None => (self.s, "", "", true), }; let k = k.trim(); if k.is_empty() { return Some(Err(io::Error::from(io::ErrorKind::InvalidData))); } self.s = remainder; self.done = done; Some(Ok((k, v))) } } pub struct HeaderValueIterator<'a> { s: &'a str, done: bool, } impl<'a> Iterator for HeaderValueIterator<'a> { type Item = Result<(&'a str, HeaderParamsIterator<'a>), io::Error>; fn next(&mut self) -> Option { if self.done { return None; } let (first_part, params, remainder, done) = match find_one_of(self.s, b";,") { Some((pos, b';')) => { // make a temporary params iterator let mut params = HeaderParamsIterator::new(&self.s[(pos + 1)..]); // drive it to the end for p in params.by_ref() { if let Err(e) = p { return Some(Err(e)); } } // when HeaderParamsIterator completes, its remaining value // will either start with a comma or be empty let (remainder, done) = if params.s.starts_with(',') { (¶ms.s[1..], false) } else if params.s.is_empty() { ("", true) } else { unreachable!(); }; // prepare a fresh iterator for returning let params = HeaderParamsIterator::new(&self.s[(pos + 1)..]); (&self.s[..pos], params, remainder, done) } Some((pos, b',')) => ( &self.s[..pos], HeaderParamsIterator::empty(), &self.s[(pos + 1)..], false, ), Some(_) => unreachable!(), None => (self.s, HeaderParamsIterator::empty(), "", true), }; let first_part = first_part.trim(); if first_part.is_empty() { return Some(Err(io::Error::from(io::ErrorKind::InvalidData))); } self.s = remainder; self.done = done; Some(Ok((first_part, params))) } } // parse a header value into parts pub fn parse_header_value(s: &[u8]) -> HeaderValueIterator<'_> { match str::from_utf8(s) { Ok(s) => HeaderValueIterator { s, done: false }, Err(_) => HeaderValueIterator { s: "", done: false }, } } #[derive(Debug, PartialEq, Clone, Copy)] struct Chunk { header: [u8; CHUNK_HEADER_SIZE_MAX], header_len: usize, size: usize, sent: usize, } // writes src to dest as 
chunks. current chunk state is passed in fn write_chunk( content: &[&[u8]], footer: &[u8], dest: &mut W, chunk: &mut Option, max_size: usize, ) -> Result { assert!(max_size <= CHUNK_SIZE_MAX); let mut content_len = 0; for buf in content.iter() { content_len += buf.len(); } if chunk.is_none() { let size = cmp::min(content_len, max_size); let mut h = [0; CHUNK_HEADER_SIZE_MAX]; let h_len = { let mut c = io::Cursor::new(&mut h[..]); write!(&mut c, "{:x}\r\n", size).unwrap(); c.position() as usize }; *chunk = Some(Chunk { header: h, header_len: h_len, size, sent: 0, }); } let chunkv = chunk.as_mut().unwrap(); let cheader = &chunkv.header[..chunkv.header_len]; let data_size = chunkv.size; let total = cheader.len() + data_size + footer.len(); let mut content = ArrayVec::<&[u8], { VECTORED_MAX - 2 }>::try_from(content).unwrap(); let content = content.as_mut_slice().limit(data_size); let size = { let mut out = ArrayVec::<&[u8], VECTORED_MAX>::new(); out.push(cheader); for buf in content.as_slice() { out.push(buf); } out.push(footer); write_vectored_offset(dest, out.as_slice(), chunkv.sent)? }; chunkv.sent += size; if chunkv.sent < total { return Ok(0); } *chunk = None; Ok(data_size) } #[cfg(test)] pub fn write_headers(writer: &mut W, headers: &[Header]) -> Result<(), io::Error> { for h in headers.iter() { write!(writer, "{}: ", h.name)?; writer.write(h.value)?; writer.write(b"\r\n")?; } writer.write(b"\r\n")?; Ok(()) } #[derive(Clone, Copy)] pub struct Header<'a> { pub name: &'a str, pub value: &'a [u8], } pub const EMPTY_HEADER: Header<'static> = Header { name: "", value: b"", }; #[derive(Debug, PartialEq, Clone, Copy)] pub enum BodySize { NoBody, Known(usize), Unknown, } pub struct ParseScratch { headers: [httparse::Header<'static>; N], } #[allow(clippy::new_without_default)] impl ParseScratch { pub fn new() -> Self { Self { headers: [httparse::EMPTY_HEADER; N], } } fn clear(&mut self) { self.headers.fill(httparse::EMPTY_HEADER); } } pub enum ParseStatus<'a, T, I, E, const N: usize> { Complete(T), Incomplete(I, FilledBuf, &'a mut ParseScratch), Error(E, FilledBuf, &'a mut ParseScratch), } struct OwnedParsedInner<'s, T, const N: usize> { parsed: T, scratch: &'s mut ParseScratch, buf: FilledBuf, size: usize, } struct OwnedHttparseRequest<'s, const N: usize> { inner: Option, N>>, } impl<'s, const N: usize> OwnedHttparseRequest<'s, N> { // on success, takes ownership of the buffer/scratch // on incomplete/error, returns the buffer/scratch fn parse( buf: FilledBuf, scratch: &'s mut ParseScratch, ) -> ParseStatus<'s, Self, (), httparse::Error, N> { let buf_ref: &[u8] = buf.filled(); let headers_mut: &mut [httparse::Header<'static>] = scratch.headers.as_mut(); // SAFETY: Self will take ownership of buf, and the bytes referred to // by buf_ref are on the heap, and buf will not be modified or // dropped until Self is dropped, so the bytes referred to by buf_ref // will remain valid for the lifetime of Self let buf_ref: &'static [u8] = unsafe { mem::transmute(buf_ref) }; // SAFETY: Self borrows scratch, and the location // referred to by headers_mut is on the heap, and the borrow will not // be released until Self is dropped, so the location referred to by // headers_mut will remain valid for the lifetime of Self // // further, it is safe for httparse::Request::parse() to write // references to buf_ref into headers_mut, because we guarantee buf // lives as long as scratch, except if into_buf() is called in // which case we clear the content of scratch let headers_mut: &'static mut 
[httparse::Header<'static>] = unsafe { mem::transmute(headers_mut) }; let mut req = httparse::Request::new(headers_mut); let size = match req.parse(buf_ref) { Ok(httparse::Status::Complete(size)) => size, Ok(httparse::Status::Partial) => return ParseStatus::Incomplete((), buf, scratch), Err(e) => return ParseStatus::Error(e, buf, scratch), }; ParseStatus::Complete(Self { inner: Some(OwnedParsedInner { parsed: req, scratch, buf, size, }), }) } fn get<'a>(&'a self) -> &'a httparse::Request<'a, 'a> { let s = self.inner.as_ref().unwrap(); let req = &s.parsed; // SAFETY: here we simply reduce the inner lifetimes to that of the owning // object, which is fine let req: &'a httparse::Request<'a, 'a> = unsafe { mem::transmute(req) }; req } fn remaining_bytes(&self) -> &[u8] { let s = self.inner.as_ref().unwrap(); &s.buf.filled()[s.size..] } fn into_parts(mut self) -> (FilledBuf, &'s mut ParseScratch) { let OwnedParsedInner { buf, scratch, .. } = self.inner.take().unwrap(); // SAFETY: ensure there are no references to buf in scratch scratch.clear(); (buf, scratch) } } impl Drop for OwnedHttparseRequest<'_, N> { fn drop(&mut self) { // SAFETY: ensure there are no references to buf in scratch if let Some(s) = &mut self.inner { s.scratch.clear(); } } } struct OwnedHttparseResponse<'s, const N: usize> { inner: Option, N>>, } impl<'s, const N: usize> OwnedHttparseResponse<'s, N> { // on success, takes ownership of the buffer/scratch // on incomplete/error, returns the buffer/scratch fn parse( buf: FilledBuf, scratch: &'s mut ParseScratch, ) -> ParseStatus<'s, Self, (), httparse::Error, N> { let buf_ref: &[u8] = buf.filled(); let headers_mut: &mut [httparse::Header<'static>] = scratch.headers.as_mut(); // SAFETY: Self will take ownership of buf, and the bytes referred to // by buf_ref are on the heap, and buf will not be modified or // dropped until Self is dropped, so the bytes referred to by buf_ref // will remain valid for the lifetime of Self let buf_ref: &'static [u8] = unsafe { mem::transmute(buf_ref) }; // SAFETY: Self borrows scratch, and the location // referred to by headers_mut is on the heap, and the borrow will not // be released until Self is dropped, so the location referred to by // headers_mut will remain valid for the lifetime of Self // // further, it is safe for httparse::Response::parse() to write // references to buf_ref into headers_mut, because we guarantee buf // lives as long as scratch, except if into_buf() is called in // which case we clear the content of scratch let headers_mut: &'static mut [httparse::Header<'static>] = unsafe { mem::transmute(headers_mut) }; let mut resp = httparse::Response::new(headers_mut); let size = match resp.parse(buf_ref) { Ok(httparse::Status::Complete(size)) => size, Ok(httparse::Status::Partial) => return ParseStatus::Incomplete((), buf, scratch), Err(e) => return ParseStatus::Error(e, buf, scratch), }; ParseStatus::Complete(Self { inner: Some(OwnedParsedInner { parsed: resp, scratch, buf, size, }), }) } fn get<'a>(&'a self) -> &'a httparse::Response<'a, 'a> { let s = self.inner.as_ref().unwrap(); let resp = &s.parsed; // SAFETY: here we simply reduce the inner lifetimes to that of the owning // object, which is fine let resp: &'a httparse::Response<'a, 'a> = unsafe { mem::transmute(resp) }; resp } fn remaining_bytes(&self) -> &[u8] { let s = self.inner.as_ref().unwrap(); &s.buf.filled()[s.size..] } fn into_parts(mut self) -> (FilledBuf, &'s mut ParseScratch) { let OwnedParsedInner { buf, scratch, .. 
} = self.inner.take().unwrap(); // SAFETY: ensure there are no references to buf in scratch scratch.clear(); (buf, scratch) } } impl Drop for OwnedHttparseResponse<'_, N> { fn drop(&mut self) { // SAFETY: ensure there are no references to buf in scratch if let Some(s) = &mut self.inner { s.scratch.clear(); } } } #[derive(Debug, PartialEq)] pub struct Request<'buf, 'headers> { pub method: &'buf str, pub uri: &'buf str, pub headers: &'headers [httparse::Header<'buf>], pub body_size: BodySize, pub expect_100: bool, } pub struct OwnedRequest<'s, const N: usize> { req: OwnedHttparseRequest<'s, N>, body_size: BodySize, expect_100: bool, } impl OwnedRequest<'_, N> { pub fn get(&self) -> Request<'_, '_> { let req = self.req.get(); Request { method: req.method.unwrap(), uri: req.path.unwrap(), headers: req.headers, body_size: self.body_size, expect_100: self.expect_100, } } pub fn remaining_bytes(&self) -> &[u8] { self.req.remaining_bytes() } pub fn into_buf(self) -> FilledBuf { self.req.into_parts().0 } } #[derive(Debug, PartialEq)] pub struct Response<'buf, 'headers> { pub code: u16, pub reason: &'buf str, pub headers: &'headers [httparse::Header<'buf>], pub body_size: BodySize, } pub struct OwnedResponse<'s, const N: usize> { resp: OwnedHttparseResponse<'s, N>, body_size: BodySize, } impl OwnedResponse<'_, N> { pub fn get(&self) -> Response<'_, '_> { let resp = self.resp.get(); Response { code: resp.code.unwrap(), reason: resp.reason.unwrap(), headers: resp.headers, body_size: self.body_size, } } pub fn remaining_bytes(&self) -> &[u8] { self.resp.remaining_bytes() } pub fn into_buf(self) -> FilledBuf { self.resp.into_parts().0 } } #[derive(Debug, PartialEq, Clone, Copy)] pub enum ServerState { // call: recv_request // next: ReceivingRequest, ReceivingBody, AwaitingResponse ReceivingRequest, // call: recv_body // next: ReceivingBody, AwaitingResponse ReceivingBody, // call: send_response // next: SendingBody AwaitingResponse, // call: send_body // next: SendingBody, Finished SendingBody, // request/response has completed Finished, } #[derive(Debug, thiserror::Error)] pub enum Error { #[error(transparent)] Parse(#[from] httparse::Error), #[error("invalid content length")] InvalidContentLength, #[error("unsupported transfer encoding")] UnsupportedTransferEncoding, #[error(transparent)] Io(#[from] io::Error), #[error("invalid chunk size")] InvalidChunkSize, #[error("chunk too large")] ChunkTooLarge, #[error("invalid chunk suffix")] InvalidChunkSuffix, } pub struct ServerProtocol { state: ServerState, ver_min: u8, body_size: BodySize, chunk_left: Option, chunk_size: usize, persistent: bool, chunked: bool, sending_chunk: Option, } #[allow(clippy::new_without_default)] impl<'buf, 'headers> ServerProtocol { pub fn new() -> Self { Self { state: ServerState::ReceivingRequest, ver_min: 0, body_size: BodySize::NoBody, chunk_left: None, chunk_size: 0, persistent: false, chunked: false, sending_chunk: None, } } pub fn state(&self) -> ServerState { self.state } pub fn is_persistent(&self) -> bool { self.persistent } #[cfg(test)] pub fn recv_request( &mut self, rbuf: &mut io::Cursor<&'buf [u8]>, headers: &'headers mut [httparse::Header<'buf>], ) -> Option, Error>> { assert_eq!(self.state, ServerState::ReceivingRequest); let mut req = httparse::Request::new(headers); let buf = &rbuf.get_ref()[(rbuf.position() as usize)..]; let size = match req.parse(buf) { Ok(httparse::Status::Complete(size)) => size, Ok(httparse::Status::Partial) => return None, Err(e) => return Some(Err(Error::Parse(e))), }; let expect_100 = 
match self.process_request(&req) { Ok(ret) => ret, Err(e) => return Some(Err(e)), }; rbuf.set_position(rbuf.position() + (size as u64)); Some(Ok(Request { method: req.method.unwrap(), uri: req.path.unwrap(), headers: req.headers, body_size: self.body_size, expect_100, })) } pub fn recv_request_owned<'a, const N: usize>( &mut self, rbuf: FilledBuf, scratch: &'a mut ParseScratch, ) -> ParseStatus<'a, OwnedRequest<'a, N>, (), Error, N> { assert_eq!(self.state, ServerState::ReceivingRequest); let req = match OwnedHttparseRequest::parse(rbuf, scratch) { ParseStatus::Complete(req) => req, ParseStatus::Incomplete((), rbuf, scratch) => { return ParseStatus::Incomplete((), rbuf, scratch) } ParseStatus::Error(e, rbuf, scratch) => { return ParseStatus::Error(Error::Parse(e), rbuf, scratch) } }; let expect_100 = match self.process_request(req.get()) { Ok(ret) => ret, Err(e) => { let (buf, scratch) = req.into_parts(); return ParseStatus::Error(e, buf, scratch); } }; ParseStatus::Complete(OwnedRequest { req, body_size: self.body_size, expect_100, }) } pub fn skip_recv_request(&mut self) { assert_eq!(self.state, ServerState::ReceivingRequest); self.state = ServerState::AwaitingResponse; self.persistent = false; } #[allow(clippy::type_complexity)] pub fn recv_body( &mut self, rbuf: &mut io::Cursor<&'buf [u8]>, dest: &mut [u8], headers: &'headers mut [httparse::Header<'buf>], ) -> Result]>)>, Error> { assert_eq!(self.state, ServerState::ReceivingBody); match self.body_size { BodySize::Known(_) => { let mut chunk_left = self.chunk_left.unwrap(); let src_avail = cmp::min( chunk_left, rbuf.get_ref()[(rbuf.position() as usize)..].len(), ); let read_size = cmp::min(src_avail, dest.len()); // rbuf holds body as-is let size = rbuf.read(&mut dest[..read_size])?; chunk_left -= size; if chunk_left == 0 { self.chunk_left = None; self.state = ServerState::AwaitingResponse; } else { self.chunk_left = Some(chunk_left); // nothing to read? if src_avail == 0 { assert_eq!(size, 0); return Ok(None); } } Ok(Some((size, None))) } BodySize::Unknown => { if self.chunk_left.is_none() { let buf = &rbuf.get_ref()[(rbuf.position() as usize)..]; match httparse::parse_chunk_size(buf) { Ok(httparse::Status::Complete((pos, size))) => { let size = match u32::try_from(size) { Ok(size) => size, Err(_) => return Err(Error::ChunkTooLarge), }; let size = size as usize; rbuf.set_position(rbuf.position() + (pos as u64)); self.chunk_left = Some(size); self.chunk_size = size; } Ok(httparse::Status::Partial) => return Ok(None), Err(_) => return Err(Error::InvalidChunkSize), } } let mut chunk_left = self.chunk_left.unwrap(); if chunk_left > 0 { let src_avail = cmp::min( chunk_left, rbuf.get_ref()[(rbuf.position() as usize)..].len(), ); let read_size = cmp::min(src_avail, dest.len()); let size = rbuf.read(&mut dest[..read_size])?; chunk_left -= size; self.chunk_left = Some(chunk_left); // nothing to read? if src_avail == 0 { assert_eq!(size, 0); return Ok(None); } return Ok(Some((size, None))); } // done with content bytes. 
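// as a concrete illustration of the framing handled by this function,
// a chunked body such as
//
//     5\r\nhello\r\n
//     0\r\nTrailer: x\r\n\r\n
//
// is consumed as: a chunk-size header ("5\r\n"), five content bytes,
// and the per-chunk "\r\n" suffix; the final zero-size chunk is
// instead followed by optional trailing headers and a terminating
// blank line, which httparse::parse_headers consumes in one step.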
now to read the footer let mut trailing_headers = None; if chunk_left == 0 { let buf = &rbuf.get_ref()[(rbuf.position() as usize)..]; if self.chunk_size == 0 { // trailing headers match httparse::parse_headers(buf, headers) { Ok(httparse::Status::Complete((pos, headers))) => { rbuf.set_position(rbuf.position() + (pos as u64)); trailing_headers = Some(headers); } Ok(httparse::Status::Partial) => return Ok(None), Err(e) => return Err(Error::Parse(e)), } self.state = ServerState::AwaitingResponse; } else { if buf.len() < 2 { return Ok(None); } if &buf[..2] != b"\r\n" { return Err(Error::InvalidChunkSuffix); } rbuf.set_position(rbuf.position() + 2); } self.chunk_left = None; self.chunk_size = 0; } Ok(Some((0, trailing_headers))) } BodySize::NoBody => unreachable!(), } } pub fn send_100_continue(&mut self, writer: &mut W) -> Result<(), Error> { writer.write_all(b"HTTP/1.1 100 Continue\r\n\r\n")?; Ok(()) } pub fn send_response( &mut self, writer: &mut W, code: u16, reason: &str, headers: &[Header], body_size: BodySize, ) -> Result<(), Error> { assert!( self.state == ServerState::AwaitingResponse || self.state == ServerState::ReceivingBody ); let persistent = if self.state == ServerState::ReceivingBody { // when responding early, input stream may be broken false } else { self.persistent }; let mut body_size = body_size; // certain responses have no body match code { 100..=199 | 204 | 304 => { body_size = BodySize::NoBody; } _ => {} } let chunked = body_size == BodySize::Unknown && self.ver_min >= 1; if self.ver_min >= 1 { writer.write_all(b"HTTP/1.1 ")?; } else { writer.write_all(b"HTTP/1.0 ")?; } write!(writer, "{} {}\r\n", code, reason)?; for h in headers.iter() { // we'll override these headers if (h.name.eq_ignore_ascii_case("Connection") && code != 101) || h.name.eq_ignore_ascii_case("Content-Length") || h.name.eq_ignore_ascii_case("Transfer-Encoding") { continue; } write!(writer, "{}: ", h.name)?; writer.write_all(h.value)?; writer.write_all(b"\r\n")?; } // Connection header if persistent && self.ver_min == 0 { writer.write_all(b"Connection: keep-alive\r\n")?; } else if !persistent && self.ver_min >= 1 { writer.write_all(b"Connection: close\r\n")?; } if chunked { writer.write_all(b"Connection: Transfer-Encoding\r\n")?; } // Content-Length header if let BodySize::Known(x) = body_size { write!(writer, "Content-Length: {}\r\n", x)?; } // Transfer-Encoding header if chunked { writer.write_all(b"Transfer-Encoding: chunked\r\n")?; } writer.write_all(b"\r\n")?; self.state = ServerState::SendingBody; self.body_size = body_size; self.persistent = persistent; self.chunked = chunked; Ok(()) } pub fn send_body( &mut self, writer: &mut W, src: &[&[u8]], end: bool, headers: Option<&[u8]>, ) -> Result { assert_eq!(self.state, ServerState::SendingBody); let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } if let BodySize::NoBody = self.body_size { // ignore the data if end { self.state = ServerState::Finished; } return Ok(src_len); } if !self.chunked { let size = write_vectored_offset(writer, src, 0)?; if end && size >= src_len { self.state = ServerState::Finished; } return Ok(size); } // chunked let mut content_written = 0; if src_len > 0 { content_written = write_chunk( src, CHUNK_FOOTER, writer, &mut self.sending_chunk, CHUNK_SIZE_MAX, )?; } // if all content is written then we can send the closing chunk if end && content_written >= src_len { let footer = if let Some(headers) = headers { headers } else { CHUNK_FOOTER }; write_chunk( &[b""], footer, writer, &mut self.sending_chunk, 
CHUNK_SIZE_MAX, )?; if self.sending_chunk.is_none() { self.state = ServerState::Finished; } } Ok(content_written) } fn process_request(&mut self, req: &httparse::Request) -> Result { let version = req.version.unwrap(); let mut content_len = None; let mut chunked = false; let mut keep_alive = false; let mut close = false; let mut expect_100 = false; for i in 0..req.headers.len() { let h = req.headers[i]; if h.name.eq_ignore_ascii_case("Content-Length") { let len = parse_as_int(h.value); let len = match len { Ok(len) => len, Err(_) => return Err(Error::InvalidContentLength), }; content_len = Some(len); } else if h.name.eq_ignore_ascii_case("Transfer-Encoding") { if h.value == b"chunked" { chunked = true; } else { // unknown transfer encoding return Err(Error::UnsupportedTransferEncoding); } } else if h.name.eq_ignore_ascii_case("Connection") { if !keep_alive && header_contains_param(h.value, b"keep-alive", true) { keep_alive = true; } if !close && header_contains_param(h.value, b"close", false) { close = true; } } else if h.name.eq_ignore_ascii_case("Expect") { if header_contains_param(h.value, b"100-continue", false) && version >= 1 { expect_100 = true; } } } self.ver_min = version; if chunked { self.body_size = BodySize::Unknown; } else if let Some(len) = content_len { self.body_size = BodySize::Known(len); self.chunk_left = Some(len); } else { self.body_size = BodySize::NoBody; } if version >= 1 { self.persistent = !close; } else { self.persistent = keep_alive && !close; } self.state = match self.body_size { BodySize::Unknown | BodySize::Known(_) => ServerState::ReceivingBody, BodySize::NoBody => ServerState::AwaitingResponse, }; let expect_100 = expect_100 && self.body_size != BodySize::NoBody; Ok(expect_100) } } struct ClientState { ver_min: u8, body_size: BodySize, chunk_left: Option, chunk_size: usize, persistent: bool, chunked: bool, sending_chunk: Option, } #[allow(clippy::new_without_default)] impl ClientState { fn new() -> Self { Self { ver_min: 1, body_size: BodySize::NoBody, chunk_left: None, chunk_size: 0, persistent: true, chunked: false, sending_chunk: None, } } } pub struct ClientRequest { state: ClientState, } #[allow(clippy::new_without_default)] impl ClientRequest { pub fn new() -> Self { Self { state: ClientState::new(), } } pub fn send_header( mut self, writer: &mut W, method: &str, uri: &str, headers: &[Header], body_size: BodySize, websocket: bool, ) -> Result { let body_size = if websocket { BodySize::NoBody } else { body_size }; let chunked = body_size == BodySize::Unknown; write!(writer, "{} {} HTTP/1.1\r\n", method, uri)?; for h in headers.iter() { // we'll override these headers if (h.name.eq_ignore_ascii_case("Connection") && !websocket) || h.name.eq_ignore_ascii_case("Content-Length") || h.name.eq_ignore_ascii_case("Transfer-Encoding") { continue; } write!(writer, "{}: ", h.name)?; writer.write_all(h.value)?; writer.write_all(b"\r\n")?; } // Connection header if chunked { writer.write_all(b"Connection: Transfer-Encoding\r\n")?; } // Content-Length header if let BodySize::Known(x) = body_size { if x > 0 || !method.eq_ignore_ascii_case("OPTIONS") && !method.eq_ignore_ascii_case("GET") && !method.eq_ignore_ascii_case("HEAD") { write!(writer, "Content-Length: {}\r\n", x)?; } } // Transfer-Encoding header if chunked { writer.write_all(b"Transfer-Encoding: chunked\r\n")?; } writer.write_all(b"\r\n")?; self.state.body_size = body_size; self.state.chunked = chunked; Ok(ClientRequestBody { state: self.state }) } } pub enum SendStatus { Complete(T, usize), Partial(P, 
usize), Error(P, E), } pub enum RecvStatus { NeedBytes(T), Read(T, usize, usize), Complete(C, usize, usize), } pub struct ClientRequestBody { state: ClientState, } impl ClientRequestBody { pub fn into_early_response(mut self) -> ClientResponse { self.state.persistent = false; ClientResponse { state: self.state } } pub fn send( mut self, writer: &mut W, src: &[&[u8]], end: bool, headers: Option<&[u8]>, ) -> SendStatus { let state = &mut self.state; let mut src_len = 0; for buf in src.iter() { src_len += buf.len(); } if let BodySize::NoBody = state.body_size { // ignore the data if end { return SendStatus::Complete(ClientResponse { state: self.state }, 0); } return SendStatus::Partial(self, src_len); } if !state.chunked { let size = match write_vectored_offset(writer, src, 0) { Ok(ret) => ret, Err(e) => return SendStatus::Error(self, e.into()), }; assert!(size <= src_len); if end && size == src_len { return SendStatus::Complete(ClientResponse { state: self.state }, size); } return SendStatus::Partial(self, size); } // chunked let mut content_written = 0; if src_len > 0 { content_written = match write_chunk( src, CHUNK_FOOTER, writer, &mut state.sending_chunk, CHUNK_SIZE_MAX, ) { Ok(ret) => ret, Err(e) => return SendStatus::Error(self, e.into()), }; assert!(content_written <= src_len); } // if all content is written then we can send the closing chunk if end && content_written == src_len { let footer = if let Some(headers) = headers { headers } else { CHUNK_FOOTER }; match write_chunk( &[b""], footer, writer, &mut state.sending_chunk, CHUNK_SIZE_MAX, ) { Ok(ret) => ret, Err(e) => return SendStatus::Error(self, e.into()), }; if state.sending_chunk.is_none() { return SendStatus::Complete(ClientResponse { state: self.state }, content_written); } } SendStatus::Partial(self, content_written) } } pub struct ClientResponse { state: ClientState, } impl ClientResponse { pub fn recv_header( mut self, rbuf: FilledBuf, scratch: &mut ParseScratch, ) -> ParseStatus<'_, (OwnedResponse<'_, N>, ClientResponseBody), Self, Error, N> { let resp = match OwnedHttparseResponse::parse(rbuf, scratch) { ParseStatus::Complete(resp) => resp, ParseStatus::Incomplete((), rbuf, scratch) => { return ParseStatus::Incomplete(self, rbuf, scratch) } ParseStatus::Error(e, rbuf, scratch) => { return ParseStatus::Error(Error::Parse(e), rbuf, scratch) } }; if let Err(e) = self.process_response(resp.get()) { let (buf, scratch) = resp.into_parts(); return ParseStatus::Error(e, buf, scratch); } ParseStatus::Complete(( OwnedResponse { resp, body_size: self.state.body_size, }, ClientResponseBody { state: self.state }, )) } fn process_response(&mut self, resp: &httparse::Response) -> Result<(), Error> { let state = &mut self.state; let version = resp.version.unwrap(); let code = resp.code.unwrap(); let mut content_len = None; let mut chunked = false; let mut keep_alive = false; let mut close = false; for i in 0..resp.headers.len() { let h = resp.headers[i]; if h.name.eq_ignore_ascii_case("Content-Length") { let len = parse_as_int(h.value); let len = match len { Ok(len) => len, Err(_) => return Err(Error::InvalidContentLength), }; content_len = Some(len); } else if h.name.eq_ignore_ascii_case("Transfer-Encoding") { if h.value == b"chunked" { chunked = true; } else { // unknown transfer encoding return Err(Error::UnsupportedTransferEncoding); } } else if h.name.eq_ignore_ascii_case("Connection") { if !keep_alive && header_contains_param(h.value, b"keep-alive", true) { keep_alive = true; } if !close && header_contains_param(h.value, 
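/*
   Illustrative sketch, not part of this module's API: SendStatus and RecvStatus hand
   ownership back to the caller on Partial/NeedBytes so the same sender can be retried,
   and return the next protocol object on Complete. The types below are simplified
   stand-ins (the real enums carry different generic parameters); they only illustrate
   the retry loop a caller is expected to run.

   enum DemoSendStatus<P, T> {
       Complete(T, usize), // next-state object plus bytes of content written
       Partial(P, usize),  // sender handed back plus bytes written so far
   }

   struct DemoSender { remaining: usize }
   struct DemoDone;

   impl DemoSender {
       fn send(mut self, budget: usize) -> DemoSendStatus<DemoSender, DemoDone> {
           let n = budget.min(self.remaining);
           self.remaining -= n;
           if self.remaining == 0 {
               DemoSendStatus::Complete(DemoDone, n)
           } else {
               DemoSendStatus::Partial(self, n)
           }
       }
   }

   fn demo_retry_loop() {
       let mut sender = DemoSender { remaining: 5 };
       loop {
           match sender.send(2) {
               DemoSendStatus::Complete(_done, _n) => break,
               DemoSendStatus::Partial(s, _n) => sender = s, // retry with the returned sender
           }
       }
   }
*/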
b"close", false) { close = true; } } } state.ver_min = version; state.chunked = false; if chunked { state.body_size = BodySize::Unknown; state.chunked = true; } else if let Some(len) = content_len { state.body_size = BodySize::Known(len); state.chunk_left = Some(len); } else { state.body_size = match code { 100..=199 | 204 | 304 => BodySize::NoBody, _ => BodySize::Unknown, }; } let close_end = state.body_size == BodySize::Unknown && !chunked; if version >= 1 { state.persistent = !close && !close_end; } else { state.persistent = keep_alive && !close && !close_end; } Ok(()) } } pub struct ClientResponseBody { state: ClientState, } impl ClientResponseBody { #[cfg(test)] pub fn size(&self) -> BodySize { self.state.body_size } pub fn recv<'buf, const N: usize>( self, src: &'buf [u8], dest: &mut [u8], end: bool, scratch: &mut mem::MaybeUninit<[httparse::Header<'buf>; N]>, ) -> Result, Error> { match self.state.body_size { BodySize::Known(_) => self.process_known_size(src, dest, end), BodySize::Unknown => { if self.state.chunked { self.process_unknown_size_chunked(src, dest, end, scratch) } else { self.process_unknown_size(src, dest, end) } } BodySize::NoBody => Ok(RecvStatus::Complete( ClientFinished { _headers_range: None, persistent: self.state.persistent, }, 0, 0, )), } } fn process_known_size( mut self, src: &[u8], dest: &mut [u8], end: bool, ) -> Result, Error> { let state = &mut self.state; let mut chunk_left = state.chunk_left.unwrap(); let max_read = cmp::min(chunk_left, src.len()); let src = &src[..max_read]; // src holds body as-is let mut rbuf = io::Cursor::new(src); let size = rbuf.read(dest)?; chunk_left -= size; if chunk_left == 0 { state.chunk_left = None; return Ok(RecvStatus::Complete( ClientFinished { _headers_range: None, persistent: state.persistent, }, size, size, )); } // we are expecting more bytes state.chunk_left = Some(chunk_left); // nothing to read? if src.is_empty() { assert_eq!(size, 0); // if the input has ended, return error if end { return Err(Error::Io(io::Error::from(io::ErrorKind::UnexpectedEof))); } return Ok(RecvStatus::NeedBytes(self)); } // there was something to read. however, whether anything actually // got read depends on the length of dest Ok(RecvStatus::Read(self, size, size)) } fn process_unknown_size( self, src: &[u8], dest: &mut [u8], end: bool, ) -> Result, Error> { // src holds body as-is let mut rbuf = io::Cursor::new(src); let size = rbuf.read(dest)?; // we're done when we've consumed the entire input if size == src.len() && end { return Ok(RecvStatus::Complete( ClientFinished { _headers_range: None, persistent: self.state.persistent, }, size, size, )); } // nothing to read? if src.is_empty() { assert_eq!(size, 0); return Ok(RecvStatus::NeedBytes(self)); } // there was something to read. 
however, whether anything actually // got read depends on the length of dest Ok(RecvStatus::Read(self, size, size)) } fn process_unknown_size_chunked<'buf, const N: usize>( mut self, src: &'buf [u8], dest: &mut [u8], end: bool, scratch: &mut mem::MaybeUninit<[httparse::Header<'buf>; N]>, ) -> Result, Error> { let state = &mut self.state; let mut pos = if state.chunk_left.is_none() { match httparse::parse_chunk_size(src) { Ok(httparse::Status::Complete((pos, size))) => { let size = match u32::try_from(size) { Ok(size) => size, Err(_) => return Err(Error::ChunkTooLarge), }; let size = size as usize; state.chunk_left = Some(size); state.chunk_size = size; pos } Ok(httparse::Status::Partial) => { if end { return Err(Error::Io(io::Error::from(io::ErrorKind::UnexpectedEof))); } return Ok(RecvStatus::NeedBytes(self)); } Err(_) => { return Err(Error::InvalidChunkSize); } } } else { 0 }; let mut chunk_left = state.chunk_left.unwrap(); if chunk_left > 0 { let max_read = cmp::min(chunk_left, src.len() - pos); let src = &src[pos..(pos + max_read)]; let mut rbuf = io::Cursor::new(src); let size = rbuf.read(dest)?; pos += size; chunk_left -= size; state.chunk_left = Some(chunk_left); // nothing to read? if src.is_empty() { assert_eq!(size, 0); if end { return Err(Error::Io(io::Error::from(io::ErrorKind::UnexpectedEof))); } // if pos advanced we need to return it if pos > 0 { return Ok(RecvStatus::Read(self, pos, 0)); } return Ok(RecvStatus::NeedBytes(self)); } // there was something to read. however, whether anything actually // got read depends on the length of dest return Ok(RecvStatus::Read(self, pos, size)); } // done with content bytes. now to read the footer // final chunk? if state.chunk_size == 0 { let src = &src[pos..]; // trailing headers let scratch = unsafe { scratch.assume_init_mut() }; match httparse::parse_headers(src, scratch) { Ok(httparse::Status::Complete((x, _))) => { let headers_start = pos; let headers_end = pos + x; return Ok(RecvStatus::Complete( ClientFinished { _headers_range: Some((headers_start, headers_end)), persistent: state.persistent, }, headers_end, 0, )); } Ok(httparse::Status::Partial) => { if end { return Err(Error::Io(io::Error::from(io::ErrorKind::UnexpectedEof))); } // if pos advanced we need to return it if pos > 0 { return Ok(RecvStatus::Read(self, pos, 0)); } return Ok(RecvStatus::NeedBytes(self)); } Err(e) => return Err(Error::Parse(e)), } } // for chunks of non-zero size, pos for header/content will have // already been returned by previous calls assert_eq!(pos, 0); if src.len() < 2 { if end { return Err(Error::Io(io::Error::from(io::ErrorKind::UnexpectedEof))); } return Ok(RecvStatus::NeedBytes(self)); } if &src[..2] != b"\r\n" { return Err(Error::InvalidChunkSuffix); } state.chunk_left = None; state.chunk_size = 0; Ok(RecvStatus::Read(self, 2, 0)) } } pub struct ClientFinished { _headers_range: Option<(usize, usize)>, pub persistent: bool, } #[cfg(test)] mod tests { use super::*; const HEADERS_MAX: usize = 32; struct MyBuffer { data: Vec, max: usize, allow_partial: bool, } impl MyBuffer { fn new(cap: usize, allow_partial: bool) -> Self { Self { data: Vec::new(), max: cap, allow_partial, } } } impl Write for MyBuffer { fn write(&mut self, buf: &[u8]) -> Result { let size = cmp::min(buf.len(), self.max - self.data.len()); if (size == 0 && !buf.is_empty()) || (size < buf.len() && !self.allow_partial) { return Err(io::Error::from(io::ErrorKind::WriteZero)); } self.data.extend_from_slice(&buf[..size]); Ok(size) } fn write_vectored(&mut self, bufs: 
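/*
   Illustrative sketch, not part of this module's API: the chunked reader leans on two
   httparse calls, parse_chunk_size for the size line and parse_headers for the trailer
   block after the final "0\r\n" chunk. The standalone checks below mirror the
   "5\r\nhello\r\n" framing used by the tests in this module.

   fn demo_httparse_chunk_calls() {
       // Size line: Complete((bytes_consumed, declared_size)).
       match httparse::parse_chunk_size(b"5\r\nhello\r\n") {
           Ok(httparse::Status::Complete((pos, size))) => {
               assert_eq!(pos, 3);
               assert_eq!(size, 5);
           }
           _ => unreachable!(),
       }

       // Trailing headers, terminated by an empty line.
       let mut headers = [httparse::EMPTY_HEADER; 4];
       match httparse::parse_headers(b"Foo: Bar\r\n\r\n", &mut headers) {
           Ok(httparse::Status::Complete((pos, parsed))) => {
               assert_eq!(pos, 12);
               assert_eq!(parsed[0].name, "Foo");
               assert_eq!(parsed[0].value, b"Bar");
           }
           _ => unreachable!(),
       }
   }
*/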
&[io::IoSlice]) -> Result { let mut total = 0; for buf in bufs { let size = match self.write(buf.as_ref()) { Ok(size) => size, Err(e) => { if e.kind() == io::ErrorKind::WriteZero && total > 0 { return Ok(total); } return Err(e); } }; total += size; if size < buf.len() { break; } } Ok(total) } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } struct TestRequest { pub method: String, pub uri: String, pub headers: Vec<(String, Vec)>, pub body: Vec, pub trailing_headers: Vec<(String, Vec)>, pub persistent: bool, } impl TestRequest { fn new() -> Self { Self { method: String::new(), uri: String::new(), headers: Vec::new(), body: Vec::new(), trailing_headers: Vec::new(), persistent: false, } } } struct TestResponse { pub code: u16, pub reason: String, pub headers: Vec<(String, Vec)>, pub body: Vec, pub chunked: bool, pub trailing_headers: Vec<(String, Vec)>, } impl TestResponse { fn new() -> Self { Self { code: 0, reason: String::new(), headers: Vec::new(), body: Vec::new(), chunked: false, trailing_headers: Vec::new(), } } } fn read_req(p: &mut ServerProtocol, src: &[u8], read_size: usize) -> TestRequest { const READ_SIZE_MAX: usize = 1024; const LOOPS_MAX: u32 = 20; assert!(read_size <= READ_SIZE_MAX); assert_eq!(p.state(), ServerState::ReceivingRequest); let rbuf = FilledBuf::new(src.to_vec(), src.len()); let mut result = TestRequest::new(); assert_eq!(p.state(), ServerState::ReceivingRequest); let mut scratch = ParseScratch::::new(); let req = match p.recv_request_owned(rbuf, &mut scratch) { ParseStatus::Complete(req) => req, _ => panic!("recv_request_owned did not return complete"), }; let mut rbuf = io::Cursor::new(req.remaining_bytes()); let req = req.get(); result.method = String::from(req.method); result.uri = String::from(req.uri); for h in req.headers { let name = String::from(h.name); let value = Vec::from(h.value); result.headers.push((name, value)); } for _ in 0..LOOPS_MAX { if p.state() != ServerState::ReceivingBody { break; } let mut buf = [0; READ_SIZE_MAX]; let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let (size, trailing_headers) = p .recv_body(&mut rbuf, &mut buf[..read_size], &mut headers) .unwrap() .unwrap(); result.body.extend_from_slice(&buf[..size]); if let Some(trailing_headers) = trailing_headers { for h in trailing_headers { let name = String::from(h.name); let value = Vec::from(h.value); result.trailing_headers.push((name, value)); } } } result.persistent = p.is_persistent(); assert_eq!(p.state(), ServerState::AwaitingResponse); return result; } fn write_resp(p: &mut ServerProtocol, resp: TestResponse, write_size: usize) -> Vec { const WRITE_SIZE_MAX: usize = 1024; const LOOPS_MAX: u32 = 20; assert!(write_size <= WRITE_SIZE_MAX); assert_eq!(p.state(), ServerState::AwaitingResponse); let mut header_out = [0; 1024]; let mut wbuf = io::Cursor::new(&mut header_out[..]); let mut headers = Vec::new(); for h in resp.headers.iter() { headers.push(Header { name: &h.0, value: &h.1, }); } let body_size = if resp.chunked { BodySize::Unknown } else { BodySize::Known(resp.body.len()) }; p.send_response(&mut wbuf, resp.code, &resp.reason, &headers, body_size) .unwrap(); let size = wbuf.position() as usize; let header_out = &header_out[..size]; let mut body_out = MyBuffer::new(0, true); let mut sent = 0; for _ in 0..LOOPS_MAX { if p.state() != ServerState::SendingBody { break; } body_out.max += write_size; let trailing_headers = if !resp.trailing_headers.is_empty() { let mut buf = Vec::new(); for (name, value) in resp.trailing_headers.iter() { write!(buf, "{}: ", 
name).unwrap(); buf.write(value).unwrap(); write!(buf, "\r\n").unwrap(); } write!(buf, "\r\n").unwrap(); Some(buf) } else { None }; let trailing_headers: Option<&[u8]> = if let Some(trailing_headers) = &trailing_headers { Some(trailing_headers) } else { None }; let size = match p.send_body(&mut body_out, &[&resp.body[sent..]], true, trailing_headers) { Ok(size) => size, Err(Error::Io(e)) if e.kind() == io::ErrorKind::WriteZero => 0, Err(_) => panic!("send_body failed"), }; sent += size; } assert_eq!(p.state(), ServerState::Finished); let mut out = Vec::new(); out.extend_from_slice(header_out); out.append(&mut body_out.data); out } #[test] fn test_parse_as_int() { // invalid utf8 assert!(parse_as_int(b"\xa0\xa1").is_err()); // not an integer assert!(parse_as_int(b"bogus").is_err()); // not a non-negative integer assert!(parse_as_int(b"-123").is_err()); // success assert_eq!(parse_as_int(b"0").unwrap(), 0); assert_eq!(parse_as_int(b"123").unwrap(), 123); } #[test] fn test_header_contains_param() { // param invalid utf8 assert_eq!(header_contains_param(b"", b"\xa0\xa1", false), false); // skip invalid utf8 part assert_eq!(header_contains_param(b"\xa0\xa1,a", b"a", false), true); // not found assert_eq!(header_contains_param(b"", b"a", false), false); assert_eq!(header_contains_param(b"a", b"b", false), false); assert_eq!(header_contains_param(b"a,b", b"c", false), false); // success assert_eq!(header_contains_param(b"a", b"a", false), true); assert_eq!(header_contains_param(b"a,b", b"a", false), true); assert_eq!(header_contains_param(b"a,b", b"b", false), true); assert_eq!(header_contains_param(b" a ,b", b"a", false), true); assert_eq!(header_contains_param(b"a, b ", b"b", false), true); assert_eq!(header_contains_param(b"A", b"a", true), true); } #[test] fn test_write_chunk() { struct Test { name: &'static str, write_space: usize, data: &'static [&'static [u8]], footer: &'static str, chunk: Option, max_size: usize, result: Result, chunk_after: Option, written: &'static str, } let tests = [ Test { name: "new-partial", write_space: 2, data: &[b"hello"], footer: "\r\n", chunk: None, max_size: CHUNK_SIZE_MAX, result: Ok(0), chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 2, }), written: "5\r", }, Test { name: "resume-partial", write_space: 2, data: &[b"hello"], footer: "\r\n", chunk: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 2, }), max_size: CHUNK_SIZE_MAX, result: Ok(0), chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 4, }), written: "\nh", }, Test { name: "error", write_space: 0, data: &[b"hello"], footer: "\r\n", chunk: None, max_size: CHUNK_SIZE_MAX, result: Err(io::Error::from(io::ErrorKind::WriteZero)), chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 0, }), written: "", }, Test { name: "complete", write_space: 1024, data: &[b"hello"], footer: "\r\n", chunk: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 4, }), max_size: CHUNK_SIZE_MAX, result: Ok(5), chunk_after: None, written: "ello\r\n", }, Test { name: "partial-content", write_space: 1024, data: &[b"hel", b"lo world"], footer: "\r\n", chunk: Some(Chunk { header: [b'7', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 7, sent: 0, }), max_size: 7, result: Ok(7), chunk_after: None, written: "7\r\nhello w\r\n", }, ]; for test in tests.iter() { let mut w = MyBuffer::new(test.write_space, true); let mut chunk = test.chunk.clone(); 
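/*
   Illustrative sketch, not part of this module's API: the Chunk values in the table above
   store the "<hex size>\r\n" line in a small fixed array, with header_len marking how
   many bytes of it are real. A hypothetical way to build such a header with std only:

   fn demo_chunk_header() {
       use std::io::Write;

       let size = 5usize;
       let mut header = [0u8; 6];
       let mut cursor = std::io::Cursor::new(&mut header[..]);
       write!(cursor, "{:x}\r\n", size).unwrap();
       let header_len = cursor.position() as usize;

       assert_eq!(header_len, 3);
       assert_eq!(&header[..header_len], b"5\r\n");
   }
*/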
let r = write_chunk( test.data, test.footer.as_bytes(), &mut w, &mut chunk, test.max_size, ); match r { Ok(size) => { let expected = match &test.result { Ok(size) => size, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(size, *expected, "test={}", test.name); } Err(_) => { assert!(test.result.is_err(), "test={}", test.name); } } assert_eq!(chunk, test.chunk_after, "test={}", test.name); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_write_headers() { struct Test<'buf, 'headers> { name: &'static str, write_space: usize, headers: &'headers [Header<'buf>], err: bool, written: &'static str, } let tests = [ Test { name: "cant-write-header-name", write_space: 2, headers: &[ Header { name: "A", value: b"a", }, Header { name: "B", value: b"b", }, ], err: true, written: "A", }, Test { name: "cant-write-header-value", write_space: 3, headers: &[ Header { name: "A", value: b"a", }, Header { name: "B", value: b"b", }, ], err: true, written: "A: ", }, Test { name: "cant-write-header-eol", write_space: 4, headers: &[ Header { name: "A", value: b"a", }, Header { name: "B", value: b"b", }, ], err: true, written: "A: a", }, Test { name: "cant-write-eol", write_space: 13, headers: &[ Header { name: "A", value: b"a", }, Header { name: "B", value: b"b", }, ], err: true, written: "A: a\r\nB: b\r\n", }, Test { name: "success", write_space: 1024, headers: &[ Header { name: "A", value: b"a", }, Header { name: "B", value: b"b", }, ], err: false, written: "A: a\r\nB: b\r\n\r\n", }, ]; for test in tests.iter() { let mut w = MyBuffer::new(test.write_space, false); let r = write_headers(&mut w, test.headers); assert_eq!(r.is_err(), test.err, "test={}", test.name); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_recv_request_header() { struct Test<'buf, 'headers> { name: &'static str, data: &'buf str, result: Option, Error>>, state: ServerState, ver_min: u8, chunk_left: Option, persistent: bool, rbuf_position: u64, } let tests = [ Test { name: "partial", data: "G", result: None, state: ServerState::ReceivingRequest, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 0, }, Test { name: "parse-error", data: "G\n", result: Some(Err(Error::Parse(httparse::Error::Token))), state: ServerState::ReceivingRequest, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 0, }, Test { name: "invalid-content-length", data: "GET / HTTP/1.0\r\nContent-Length: a\r\n\r\n", result: Some(Err(Error::InvalidContentLength)), state: ServerState::ReceivingRequest, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 0, }, Test { name: "unsupported-transfer-encoding", data: "GET / HTTP/1.0\r\nTransfer-Encoding: bogus\r\n\r\n", result: Some(Err(Error::UnsupportedTransferEncoding)), state: ServerState::ReceivingRequest, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 0, }, Test { name: "no-body", data: "GET / HTTP/1.0\r\nFoo: Bar\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[httparse::Header { name: "Foo", value: b"Bar", }], body_size: BodySize::NoBody, expect_100: false, })), state: ServerState::AwaitingResponse, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 28, }, Test { name: "body-size-known", data: "GET / HTTP/1.0\r\nContent-Length: 42\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[httparse::Header { name: "Content-Length", value: b"42", }], body_size: BodySize::Known(42), expect_100: false, })), 
state: ServerState::ReceivingBody, ver_min: 0, chunk_left: Some(42), persistent: false, rbuf_position: 38, }, Test { name: "body-size-unknown", data: "GET / HTTP/1.0\r\nTransfer-Encoding: chunked\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[httparse::Header { name: "Transfer-Encoding", value: b"chunked", }], body_size: BodySize::Unknown, expect_100: false, })), state: ServerState::ReceivingBody, ver_min: 0, chunk_left: None, persistent: false, rbuf_position: 46, }, Test { name: "1.0-persistent", data: "GET / HTTP/1.0\r\nConnection: keep-alive\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[httparse::Header { name: "Connection", value: b"keep-alive", }], body_size: BodySize::NoBody, expect_100: false, })), state: ServerState::AwaitingResponse, ver_min: 0, chunk_left: None, persistent: true, rbuf_position: 42, }, Test { name: "1.1-persistent", data: "GET / HTTP/1.1\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[], body_size: BodySize::NoBody, expect_100: false, })), state: ServerState::AwaitingResponse, ver_min: 1, chunk_left: None, persistent: true, rbuf_position: 18, }, Test { name: "1.1-non-persistent", data: "GET / HTTP/1.1\r\nConnection: close\r\n\r\n", result: Some(Ok(Request { method: "GET", uri: "/", headers: &[httparse::Header { name: "Connection", value: b"close", }], body_size: BodySize::NoBody, expect_100: false, })), state: ServerState::AwaitingResponse, ver_min: 1, chunk_left: None, persistent: false, rbuf_position: 37, }, Test { name: "expect-100", data: "POST / HTTP/1.1\r\nContent-Length: 10\r\nExpect: 100-continue\r\n\r\n", result: Some(Ok(Request { method: "POST", uri: "/", headers: &[ httparse::Header { name: "Content-Length", value: b"10", }, httparse::Header { name: "Expect", value: b"100-continue", }, ], body_size: BodySize::Known(10), expect_100: true, })), state: ServerState::ReceivingBody, ver_min: 1, chunk_left: Some(10), persistent: true, rbuf_position: 61, }, ]; for test in tests.iter() { let mut p = ServerProtocol::new(); let mut c = io::Cursor::new(test.data.as_bytes()); let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let r = p.recv_request(&mut c, &mut headers); match r { None => { assert!(test.result.is_none(), "test={}", test.name); } Some(Ok(req)) => { let expected = match &test.result { Some(Ok(req)) => req, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(req, *expected, "test={}", test.name); } Some(Err(e)) => { let expected = match &test.result { Some(Err(e)) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); assert_eq!(p.ver_min, test.ver_min, "test={}", test.name); assert_eq!(p.chunk_left, test.chunk_left, "test={}", test.name); assert_eq!(p.is_persistent(), test.persistent, "test={}", test.name); assert_eq!(c.position(), test.rbuf_position, "test={}", test.name); } for test in tests.iter() { let mut p = ServerProtocol::new(); let src = test.data.as_bytes(); let rbuf = FilledBuf::new(src.to_vec(), src.len()); let mut scratch = ParseScratch::::new(); let mut rbuf_position = 0; let r = p.recv_request_owned(rbuf, &mut scratch); match r { ParseStatus::Complete(req) => { let expected = match &test.result { Some(Ok(req)) => req, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(req.get(), *expected, "test={}", test.name); rbuf_position = (src.len() - req.remaining_bytes().len()) as u64 
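/*
   Illustrative note, not part of this module's API: the rbuf_position values in the table
   above are simply the byte length of the request head that was consumed, for example:

   fn demo_head_lengths() {
       assert_eq!("GET / HTTP/1.0\r\nFoo: Bar\r\n\r\n".len(), 28);
       assert_eq!("GET / HTTP/1.1\r\n\r\n".len(), 18);
   }
*/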
} ParseStatus::Incomplete(_, _, _) => { assert!(test.result.is_none(), "test={}", test.name); } ParseStatus::Error(e, _, _) => { let expected = match &test.result { Some(Err(e)) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); assert_eq!(p.ver_min, test.ver_min, "test={}", test.name); assert_eq!(p.chunk_left, test.chunk_left, "test={}", test.name); assert_eq!(p.is_persistent(), test.persistent, "test={}", test.name); assert_eq!(rbuf_position, test.rbuf_position, "test={}", test.name); } } #[test] fn test_recv_request_body() { struct Test<'buf, 'headers> { name: &'static str, data: &'buf str, body_size: BodySize, chunk_left: Option, chunk_size: usize, result: Result]>)>, Error>, state: ServerState, chunk_left_after: Option, chunk_size_after: usize, rbuf_position: u64, dest_data: &'static str, } let tests = [ Test { name: "partial", data: "hel", body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, result: Ok(Some((3, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(2), chunk_size_after: 0, rbuf_position: 3, dest_data: "hel", }, Test { name: "complete", data: "hello", body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, result: Ok(Some((5, None))), state: ServerState::AwaitingResponse, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 5, dest_data: "hello", }, Test { name: "chunked-header-partial", data: "5", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(None), state: ServerState::ReceivingBody, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-header-parse-error", data: "z", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Err(Error::InvalidChunkSize), state: ServerState::ReceivingBody, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-too-large", data: "ffffffffff\r\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Err(Error::ChunkTooLarge), state: ServerState::ReceivingBody, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-header-ok", data: "5\r\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(None), state: ServerState::ReceivingBody, chunk_left_after: Some(5), chunk_size_after: 5, rbuf_position: 3, dest_data: "", }, Test { name: "chunked-content-partial", data: "5\r\nhel", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some((3, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(2), chunk_size_after: 5, rbuf_position: 6, dest_data: "hel", }, Test { name: "chunked-footer-partial-full-none", data: "5\r\nhello", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some((5, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-footer-partial-full-r", data: "5\r\nhello\r", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some((5, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-footer-partial-mid-r", data: "\r", body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, result: Ok(None), state: ServerState::ReceivingBody, 
chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-footer-parse-error", data: "XX", body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, result: Err(Error::InvalidChunkSuffix), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-complete-full", data: "5\r\nhello\r\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some((5, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-complete-mid", data: "lo\r\n", body_size: BodySize::Unknown, chunk_left: Some(2), chunk_size: 5, result: Ok(Some((2, None))), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 2, dest_data: "lo", }, Test { name: "chunked-complete-end", data: "\r\n", body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, result: Ok(Some((0, None))), state: ServerState::ReceivingBody, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 2, dest_data: "", }, Test { name: "chunked-empty", data: "0\r\n\r\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some((0, Some(&[])))), state: ServerState::AwaitingResponse, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 5, dest_data: "", }, Test { name: "trailing-headers-partial", data: "0\r\nhelloXX", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(None), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 3, dest_data: "", }, Test { name: "trailing-headers-parse-error", data: "0\r\nhelloXX\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Err(Error::Parse(httparse::Error::Token)), state: ServerState::ReceivingBody, chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 3, dest_data: "", }, Test { name: "trailing-headers-complete", data: "0\r\nFoo: Bar\r\n\r\n", body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, result: Ok(Some(( 0, Some(&[httparse::Header { name: "Foo", value: b"Bar", }]), ))), state: ServerState::AwaitingResponse, chunk_left_after: None, chunk_size_after: 0, rbuf_position: 15, dest_data: "", }, ]; for test in tests.iter() { let mut p = ServerProtocol { state: ServerState::ReceivingBody, ver_min: 0, body_size: test.body_size, chunk_left: test.chunk_left, chunk_size: test.chunk_size, persistent: false, chunked: test.body_size == BodySize::Unknown, sending_chunk: None, }; let mut c = io::Cursor::new(test.data.as_bytes()); let mut dest = [0; 1024]; let mut dest_size = 0; let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let r = p.recv_body(&mut c, &mut dest, &mut headers); match r { Ok(Some((size, headers))) => { let (expected_size, expected_headers) = match &test.result { Ok(Some((size, headers))) => (size, headers), _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(size, *expected_size, "test={}", test.name); assert_eq!(headers, *expected_headers, "test={}", test.name); dest_size = size; } Ok(None) => match &test.result { Ok(None) => {} _ => panic!("result mismatch: test={}", test.name), }, Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); 
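/*
   Illustrative note, not part of this module's API: for the frame "5\r\nhello\r\n" the
   reader consumes the size line and the payload first (3 + 5 = 8 bytes) and only a later
   call consumes the 2-byte "\r\n" suffix, which is why the "chunked-complete-full" case
   above stops at rbuf_position 8 with chunk_left Some(0).

   fn demo_chunk_frame_layout() {
       let frame = "5\r\nhello\r\n";
       assert_eq!(&frame[..3], "5\r\n");  // size line
       assert_eq!(&frame[3..8], "hello"); // payload
       assert_eq!(&frame[8..], "\r\n");   // suffix, consumed by the next recv_body call
   }
*/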
assert_eq!(p.chunk_left, test.chunk_left_after, "test={}", test.name); assert_eq!(p.chunk_size, test.chunk_size_after, "test={}", test.name); assert_eq!(c.position(), test.rbuf_position, "test={}", test.name); assert_eq!( str::from_utf8(&dest[..dest_size]).unwrap(), test.dest_data, "test={}", test.name ); } } #[test] fn test_send_response_header() { struct Test<'buf, 'headers> { name: &'static str, write_space: usize, code: u16, reason: &'static str, headers: &'headers [Header<'buf>], body_size: BodySize, ver_min: u8, persistent: bool, result: Result<(), Error>, state: ServerState, body_size_after: BodySize, chunked: bool, written: &'static str, } let tests = [ Test { name: "cant-write-1.1", write_space: 5, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 1, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "", }, Test { name: "cant-write-1.0", write_space: 5, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "", }, Test { name: "cant-write-status-line", write_space: 12, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200", }, Test { name: "cant-write-header-name", write_space: 20, code: 200, reason: "OK", headers: &[ Header { name: "Foo", value: b"Bar" }, ], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\nFoo", }, Test { name: "cant-write-header-value", write_space: 24, code: 200, reason: "OK", headers: &[ Header { name: "Foo", value: b"Bar" }, ], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\nFoo: ", }, Test { name: "cant-write-header-eol", write_space: 26, code: 200, reason: "OK", headers: &[ Header { name: "Foo", value: b"Bar" }, ], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\nFoo: Bar", }, Test { name: "cant-write-keep-alive", write_space: 26, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 0, persistent: true, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\n", }, Test { name: "cant-write-close", write_space: 26, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 1, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.1 200 OK\r\n", }, Test { name: 
"cant-write-transfer-encoding", write_space: 26, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 1, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.1 200 OK\r\n", }, Test { name: "cant-write-content-length", write_space: 26, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(0), ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\n", }, Test { name: "cant-write-te-chunked", write_space: 50, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 1, persistent: true, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\n", }, Test { name: "cant-write-eol", write_space: 18, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 0, persistent: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::AwaitingResponse, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\n", }, Test { name: "exclude-headers", write_space: 1024, code: 200, reason: "OK", headers: &[ Header { name: "Connection", value: b"X" }, Header { name: "Foo", value: b"Bar" }, Header { name: "Content-Length", value: b"X" }, Header { name: "Transfer-Encoding", value: b"X" }, ], body_size: BodySize::Unknown, ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: false, written: "HTTP/1.0 200 OK\r\nFoo: Bar\r\n\r\n", }, Test { name: "exclude-headers-101", write_space: 1024, code: 101, reason: "Switching Protocols", headers: &[ Header { name: "Connection", value: b"X" }, Header { name: "Foo", value: b"Bar" }, Header { name: "Content-Length", value: b"X" }, Header { name: "Transfer-Encoding", value: b"X" }, ], body_size: BodySize::NoBody, ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 101 Switching Protocols\r\nConnection: X\r\nFoo: Bar\r\n\r\n", }, Test { name: "1.0-no-body", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\n\r\n", }, Test { name: "1.0-len", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(42), ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::Known(42), chunked: false, written: "HTTP/1.0 200 OK\r\nContent-Length: 42\r\n\r\n", }, Test { name: "1.0-no-len", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: false, written: "HTTP/1.0 200 OK\r\n\r\n", }, Test { name: "1.1-no-body", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 1, persistent: true, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: 
false, written: "HTTP/1.1 200 OK\r\n\r\n", }, Test { name: "1.1-len", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::Known(42), ver_min: 1, persistent: true, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::Known(42), chunked: false, written: "HTTP/1.1 200 OK\r\nContent-Length: 42\r\n\r\n", }, Test { name: "1.1-no-len", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, ver_min: 1, persistent: true, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::Unknown, chunked: true, written: "HTTP/1.1 200 OK\r\nConnection: Transfer-Encoding\r\nTransfer-Encoding: chunked\r\n\r\n", }, Test { name: "1.0-persistent", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 0, persistent: true, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\nConnection: keep-alive\r\n\r\n", }, Test { name: "1.0-non-persistent", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 200 OK\r\n\r\n", }, Test { name: "1.1-persistent", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 1, persistent: true, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.1 200 OK\r\n\r\n", }, Test { name: "1.1-non-persistent", write_space: 1024, code: 200, reason: "OK", headers: &[], body_size: BodySize::NoBody, ver_min: 1, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n", }, Test { name: "force-no-body", write_space: 1024, code: 101, reason: "Switching Protocols", headers: &[], body_size: BodySize::Known(42), ver_min: 0, persistent: false, result: Ok(()), state: ServerState::SendingBody, body_size_after: BodySize::NoBody, chunked: false, written: "HTTP/1.0 101 Switching Protocols\r\n\r\n", }, ]; for test in tests.iter() { let mut p = ServerProtocol { state: ServerState::AwaitingResponse, ver_min: test.ver_min, body_size: BodySize::NoBody, chunk_left: None, chunk_size: 0, persistent: test.persistent, chunked: false, sending_chunk: None, }; let mut w = MyBuffer::new(test.write_space, false); let r = p.send_response(&mut w, test.code, test.reason, test.headers, test.body_size); match r { Ok(_) => { match &test.result { Ok(_) => {} _ => panic!("result mismatch: test={}", test.name), }; } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); assert_eq!(p.body_size, test.body_size_after, "test={}", test.name); assert_eq!(p.chunked, test.chunked, "test={}", test.name); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_send_response_body() { struct Test { name: &'static str, write_space: usize, src: &'static str, end: bool, headers: Option<&'static [u8]>, body_size: BodySize, chunked: bool, sending_chunk: Option, result: Result, state: ServerState, sending_chunk_after: Option, written: &'static str, } let tests = [ Test { name: "no-body", 
write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::NoBody, chunked: false, sending_chunk: None, result: Ok(5), state: ServerState::SendingBody, sending_chunk_after: None, written: "", }, Test { name: "no-body-end", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::NoBody, chunked: false, sending_chunk: None, result: Ok(0), state: ServerState::Finished, sending_chunk_after: None, written: "", }, Test { name: "non-chunked-partial", write_space: 3, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok(3), state: ServerState::SendingBody, sending_chunk_after: None, written: "hel", }, Test { name: "non-chunked-error", write_space: 0, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::SendingBody, sending_chunk_after: None, written: "", }, Test { name: "non-chunked", write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok(5), state: ServerState::SendingBody, sending_chunk_after: None, written: "hello", }, Test { name: "non-chunked-end", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok(0), state: ServerState::Finished, sending_chunk_after: None, written: "", }, Test { name: "chunked-partial", write_space: 2, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok(0), state: ServerState::SendingBody, sending_chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 2, }), written: "5\r", }, Test { name: "chunked-error", write_space: 0, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::SendingBody, sending_chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 0, }), written: "", }, Test { name: "chunked-complete", write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok(5), state: ServerState::SendingBody, sending_chunk_after: None, written: "5\r\nhello\r\n", }, Test { name: "end-partial", write_space: 2, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok(0), state: ServerState::SendingBody, sending_chunk_after: Some(Chunk { header: [b'0', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 0, sent: 2, }), written: "0\r", }, Test { name: "end-error", write_space: 0, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), state: ServerState::SendingBody, sending_chunk_after: Some(Chunk { header: [b'0', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 0, sent: 0, }), written: "", }, Test { name: "end-complete", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok(0), state: ServerState::Finished, sending_chunk_after: None, written: "0\r\n\r\n", }, Test { name: "end-headers", write_space: 1024, src: "", end: true, headers: Some(b"Foo: Bar\r\n\r\n"), body_size: BodySize::Unknown, chunked: 
true, sending_chunk: None, result: Ok(0), state: ServerState::Finished, sending_chunk_after: None, written: "0\r\nFoo: Bar\r\n\r\n", }, Test { name: "content-and-end", write_space: 1024, src: "hello", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok(5), state: ServerState::Finished, sending_chunk_after: None, written: "5\r\nhello\r\n0\r\n\r\n", }, ]; for test in tests.iter() { let mut p = ServerProtocol { state: ServerState::SendingBody, ver_min: 0, body_size: test.body_size, chunk_left: None, chunk_size: 0, persistent: false, chunked: test.chunked, sending_chunk: test.sending_chunk, }; let mut w = MyBuffer::new(test.write_space, true); let r = p.send_body(&mut w, &[test.src.as_bytes()], test.end, test.headers); match r { Ok(size) => { let expected_size = match &test.result { Ok(size) => size, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(size, *expected_size, "test={}", test.name); } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(p.state(), test.state, "test={}", test.name); assert_eq!( p.sending_chunk, test.sending_chunk_after, "test={}", test.name ); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_server_req() { let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n".as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "GET"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 1); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body.len(), 0); assert_eq!(req.trailing_headers.len(), 0); assert_eq!(req.persistent, true); let data = concat!( "POST /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "POST"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 2); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body, b"hello\n"); assert_eq!(req.trailing_headers.len(), 0); assert_eq!(req.persistent, true); let data = concat!( "POST /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\nhello\n\r\n", "0\r\n\r\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "POST"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 2); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body, b"hello\n"); assert_eq!(req.trailing_headers.len(), 0); assert_eq!(req.persistent, true); let data = concat!( "POST /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\nhello\n\r\n", "0\r\n", "Foo: bar\r\n", "\r\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.method, "POST"); assert_eq!(req.uri, "/foo"); assert_eq!(req.headers.len(), 2); assert_eq!(req.headers[0].0, "Host"); assert_eq!(req.headers[0].1, b"example.com"); assert_eq!(req.body, b"hello\n"); assert_eq!(req.trailing_headers.len(), 1); assert_eq!(req.trailing_headers[0].0, "Foo"); assert_eq!(req.trailing_headers[0].1, b"bar"); assert_eq!(req.persistent, true); } #[test] fn test_server_resp() { let 
data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"; let mut p = ServerProtocol::new(); read_req(&mut p, data.as_bytes(), 2); let mut resp = TestResponse::new(); resp.code = 200; resp.reason = String::from("OK"); resp.headers = vec![(String::from("Content-Type"), b"text/plain".to_vec())]; resp.body = b"hello\n".to_vec(); let out = write_resp(&mut p, resp, 2); let data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); assert_eq!(str::from_utf8(&out).unwrap(), data); let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"; let mut p = ServerProtocol::new(); read_req(&mut p, data.as_bytes(), 2); let mut resp = TestResponse::new(); resp.code = 200; resp.reason = String::from("OK"); resp.headers = vec![(String::from("Content-Type"), b"text/plain".to_vec())]; resp.body = b"hello\n".to_vec(); resp.chunked = true; let out = write_resp(&mut p, resp, 2); let data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\nhello\n\r\n0\r\n\r\n", ); assert_eq!(str::from_utf8(&out).unwrap(), data); let data = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"; let mut p = ServerProtocol::new(); read_req(&mut p, data.as_bytes(), 2); let mut resp = TestResponse::new(); resp.code = 200; resp.reason = String::from("OK"); resp.headers = vec![(String::from("Content-Type"), b"text/plain".to_vec())]; resp.body = b"hello\n".to_vec(); resp.chunked = true; resp.trailing_headers = vec![(String::from("Foo"), b"bar".to_vec())]; let out = write_resp(&mut p, resp, 2); let data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Connection: Transfer-Encoding\r\n", "Transfer-Encoding: chunked\r\n", "\r\n", "6\r\nhello\n\r\n", "0\r\n", "Foo: bar\r\n", "\r\n" ); assert_eq!(str::from_utf8(&out).unwrap(), data); } #[test] fn test_server_persistent() { // http 1.0 without keep alive let data = concat!("GET /foo HTTP/1.0\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.persistent, false); // http 1.0 with keep alive let data = concat!( "GET /foo HTTP/1.0\r\n", "Host: example.com\r\n", "Connection: keep-alive\r\n", "\r\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.persistent, true); // http 1.1 without keep alive let data = concat!( "GET /foo HTTP/1.1\r\n", "Host: example.com\r\n", "Connection: close\r\n", "\r\n" ) .as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.persistent, false); // http 1.1 with keep alive let data = concat!("GET /foo HTTP/1.1\r\n", "Host: example.com\r\n", "\r\n").as_bytes(); let mut p = ServerProtocol::new(); let req = read_req(&mut p, data, 2); assert_eq!(req.persistent, true); } #[test] fn test_send_request_header() { struct Test<'buf, 'headers> { name: &'static str, write_space: usize, method: &'static str, uri: &'static str, headers: &'headers [Header<'buf>], body_size: BodySize, websocket: bool, result: Result<(), Error>, body_size_after: BodySize, chunked: bool, written: &'static str, } let tests = [ Test { name: "cant-write", write_space: 2, method: "GET", uri: "/foo", headers: &[], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "", }, Test { name: "cant-write-request-line", write_space: 12, method: 
"GET", uri: "/foo", headers: &[], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "GET /foo", }, Test { name: "cant-write-header-name", write_space: 22, method: "GET", uri: "/foo", headers: &[Header { name: "Foo", value: b"Bar", }], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "GET /foo HTTP/1.1\r\nFoo", }, Test { name: "cant-write-header-value", write_space: 26, method: "GET", uri: "/foo", headers: &[Header { name: "Foo", value: b"Bar", }], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "GET /foo HTTP/1.1\r\nFoo: ", }, Test { name: "cant-write-header-eol", write_space: 28, method: "GET", uri: "/foo", headers: &[Header { name: "Foo", value: b"Bar", }], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "GET /foo HTTP/1.1\r\nFoo: Bar", }, Test { name: "cant-write-transfer-encoding", write_space: 27, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Unknown, websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Unknown, chunked: false, written: "POST /foo HTTP/1.1\r\n", }, Test { name: "cant-write-content-length", write_space: 27, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Known(0), websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Known(0), chunked: false, written: "POST /foo HTTP/1.1\r\n", }, Test { name: "cant-write-eol", write_space: 20, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Unknown, websocket: false, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), body_size_after: BodySize::Unknown, chunked: false, written: "POST /foo HTTP/1.1\r\n", }, Test { name: "exclude-headers", write_space: 1024, method: "POST", uri: "/foo", headers: &[ Header { name: "Connection", value: b"X", }, Header { name: "Foo", value: b"Bar", }, Header { name: "Content-Length", value: b"X", }, Header { name: "Transfer-Encoding", value: b"X", }, ], body_size: BodySize::Known(0), websocket: false, result: Ok(()), body_size_after: BodySize::Known(0), chunked: false, written: "POST /foo HTTP/1.1\r\nFoo: Bar\r\nContent-Length: 0\r\n\r\n", }, Test { name: "no-body", write_space: 1024, method: "GET", uri: "/foo", headers: &[], body_size: BodySize::NoBody, websocket: false, result: Ok(()), body_size_after: BodySize::NoBody, chunked: false, written: "GET /foo HTTP/1.1\r\n\r\n", }, Test { name: "len", write_space: 1024, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Known(42), websocket: false, result: Ok(()), body_size_after: BodySize::Known(42), chunked: false, written: "POST /foo HTTP/1.1\r\nContent-Length: 42\r\n\r\n", }, Test { name: "no-len", write_space: 1024, method: "POST", uri: "/foo", headers: &[], body_size: BodySize::Unknown, websocket: false, result: Ok(()), body_size_after: BodySize::Unknown, chunked: true, written: "POST /foo HTTP/1.1\r\nConnection: Transfer-Encoding\r\nTransfer-Encoding: chunked\r\n\r\n", }, Test { name: "force-no-body", write_space: 1024, method: "GET", uri: "/foo", headers: 
&[], body_size: BodySize::Known(42), websocket: true, result: Ok(()), body_size_after: BodySize::NoBody, chunked: false, written: "GET /foo HTTP/1.1\r\n\r\n", }, ]; for test in tests.iter() { let req = ClientRequest::new(); let mut w = MyBuffer::new(test.write_space, false); let r = req.send_header( &mut w, test.method, test.uri, test.headers, test.body_size, test.websocket, ); match r { Ok(req_body) => { match &test.result { Ok(_) => {} _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( req_body.state.body_size, test.body_size_after, "test={}", test.name ); assert_eq!(req_body.state.chunked, test.chunked, "test={}", test.name); } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_send_request_body() { struct Test { name: &'static str, write_space: usize, src: &'static str, end: bool, headers: Option<&'static [u8]>, body_size: BodySize, chunked: bool, sending_chunk: Option, result: Result<(bool, usize), Error>, sending_chunk_after: Option, written: &'static str, } let tests = [ Test { name: "no-body", write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::NoBody, chunked: false, sending_chunk: None, result: Ok((false, 5)), sending_chunk_after: None, written: "", }, Test { name: "no-body-end", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::NoBody, chunked: false, sending_chunk: None, result: Ok((true, 0)), sending_chunk_after: None, written: "", }, Test { name: "non-chunked-partial", write_space: 3, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok((false, 3)), sending_chunk_after: None, written: "hel", }, Test { name: "non-chunked-error", write_space: 0, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), sending_chunk_after: None, written: "", }, Test { name: "non-chunked", write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok((false, 5)), sending_chunk_after: None, written: "hello", }, Test { name: "non-chunked-end", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: false, sending_chunk: None, result: Ok((true, 0)), sending_chunk_after: None, written: "", }, Test { name: "chunked-partial", write_space: 2, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((false, 0)), sending_chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 2, }), written: "5\r", }, Test { name: "chunked-error", write_space: 0, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), sending_chunk_after: Some(Chunk { header: [b'5', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 5, sent: 0, }), written: "", }, Test { name: "chunked-complete", write_space: 1024, src: "hello", end: false, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((false, 5)), sending_chunk_after: None, written: "5\r\nhello\r\n", }, Test { name: 
"end-partial", write_space: 2, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((false, 0)), sending_chunk_after: Some(Chunk { header: [b'0', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 0, sent: 2, }), written: "0\r", }, Test { name: "end-error", write_space: 0, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Err(Error::Io(io::Error::from(io::ErrorKind::WriteZero))), sending_chunk_after: Some(Chunk { header: [b'0', b'\r', b'\n', 0, 0, 0], header_len: 3, size: 0, sent: 0, }), written: "", }, Test { name: "end-complete", write_space: 1024, src: "", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((true, 0)), sending_chunk_after: None, written: "0\r\n\r\n", }, Test { name: "end-headers", write_space: 1024, src: "", end: true, headers: Some(b"Foo: Bar\r\n\r\n"), body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((true, 0)), sending_chunk_after: None, written: "0\r\nFoo: Bar\r\n\r\n", }, Test { name: "content-and-end", write_space: 1024, src: "hello", end: true, headers: None, body_size: BodySize::Unknown, chunked: true, sending_chunk: None, result: Ok((true, 5)), sending_chunk_after: None, written: "5\r\nhello\r\n0\r\n\r\n", }, ]; for test in tests.iter() { let req_body = ClientRequestBody { state: ClientState { ver_min: 0, body_size: test.body_size, chunk_left: None, chunk_size: 0, persistent: false, chunked: test.chunked, sending_chunk: test.sending_chunk, }, }; let mut w = MyBuffer::new(test.write_space, true); let (state, r) = match req_body.send(&mut w, &[test.src.as_bytes()], test.end, test.headers) { SendStatus::Complete(resp, size) => (resp.state, Ok((true, size))), SendStatus::Partial(req_body, size) => (req_body.state, Ok((false, size))), SendStatus::Error(req_body, e) => (req_body.state, Err(e)), }; match r { Ok((done, size)) => { let (expected_done, expected_size) = match &test.result { Ok(v) => v, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(done, *expected_done, "test={}", test.name); assert_eq!(size, *expected_size, "test={}", test.name); } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!( state.sending_chunk, test.sending_chunk_after, "test={}", test.name ); assert_eq!( str::from_utf8(&w.data).unwrap(), test.written, "test={}", test.name ); } } #[test] fn test_recv_response_header() { struct Test<'buf, 'headers> { name: &'static str, data: &'buf str, result: Option, Error>>, ver_min: u8, chunk_left: Option, persistent: bool, chunked: bool, rbuf_position: u64, } let tests = [ Test { name: "partial", data: "H", result: None, ver_min: 0, chunk_left: None, persistent: false, chunked: false, rbuf_position: 0, }, Test { name: "parse-error", data: "H\n", result: Some(Err(Error::Parse(httparse::Error::Token))), ver_min: 0, chunk_left: None, persistent: false, chunked: false, rbuf_position: 0, }, Test { name: "invalid-content-length", data: "HTTP/1.0 200 OK\r\nContent-Length: a\r\n\r\n", result: Some(Err(Error::InvalidContentLength)), ver_min: 0, chunk_left: None, persistent: false, chunked: false, rbuf_position: 0, }, Test { name: "unsupported-transfer-encoding", data: "HTTP/1.0 200 OK\r\nTransfer-Encoding: bogus\r\n\r\n", result: Some(Err(Error::UnsupportedTransferEncoding)), ver_min: 0, 
chunk_left: None, persistent: false, chunked: false, rbuf_position: 0, }, Test { name: "no-body", data: "HTTP/1.0 204 No Content\r\nFoo: Bar\r\n\r\n", result: Some(Ok(Response { code: 204, reason: "No Content", headers: &[httparse::Header { name: "Foo", value: b"Bar", }], body_size: BodySize::NoBody, })), ver_min: 0, chunk_left: None, persistent: false, chunked: false, rbuf_position: 37, }, Test { name: "body-size-known", data: "HTTP/1.0 200 OK\r\nContent-Length: 42\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[httparse::Header { name: "Content-Length", value: b"42", }], body_size: BodySize::Known(42), })), ver_min: 0, chunk_left: Some(42), persistent: false, chunked: false, rbuf_position: 39, }, Test { name: "body-size-unknown", data: "HTTP/1.0 200 OK\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[], body_size: BodySize::Unknown, })), ver_min: 0, chunk_left: None, persistent: false, chunked: false, rbuf_position: 19, }, Test { name: "body-size-unknown-chunked", data: "HTTP/1.0 200 OK\r\nTransfer-Encoding: chunked\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[httparse::Header { name: "Transfer-Encoding", value: b"chunked", }], body_size: BodySize::Unknown, })), ver_min: 0, chunk_left: None, persistent: false, chunked: true, rbuf_position: 47, }, Test { name: "1.0-persistent", data: "HTTP/1.0 200 OK\r\nContent-Length: 5\r\nConnection: keep-alive\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[ httparse::Header { name: "Content-Length", value: b"5", }, httparse::Header { name: "Connection", value: b"keep-alive", }, ], body_size: BodySize::Known(5), })), ver_min: 0, chunk_left: Some(5), persistent: true, chunked: false, rbuf_position: 62, }, Test { name: "1.1-persistent", data: "HTTP/1.1 200 OK\r\nContent-Length: 5\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[httparse::Header { name: "Content-Length", value: b"5", }], body_size: BodySize::Known(5), })), ver_min: 1, chunk_left: Some(5), persistent: true, chunked: false, rbuf_position: 38, }, Test { name: "1.1-non-persistent", data: "HTTP/1.1 200 OK\r\nConnection: close\r\n\r\n", result: Some(Ok(Response { code: 200, reason: "OK", headers: &[httparse::Header { name: "Connection", value: b"close", }], body_size: BodySize::Unknown, })), ver_min: 1, chunk_left: None, persistent: false, chunked: false, rbuf_position: 38, }, ]; for test in tests.iter() { let resp = ClientResponse { state: ClientState::new(), }; let src = test.data.as_bytes(); let rbuf = FilledBuf::new(src.to_vec(), src.len()); let mut scratch = ParseScratch::::new(); let mut rbuf_position = 0; let r = resp.recv_header(rbuf, &mut scratch); match r { ParseStatus::Complete((resp, resp_body)) => { let expected = match &test.result { Some(Ok(req)) => req, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(resp.get(), *expected, "test={}", test.name); assert_eq!(resp_body.state.ver_min, test.ver_min, "test={}", test.name); assert_eq!( resp_body.state.chunk_left, test.chunk_left, "test={}", test.name ); assert_eq!( resp_body.state.persistent, test.persistent, "test={}", test.name ); assert_eq!(resp_body.state.chunked, test.chunked, "test={}", test.name); rbuf_position = (src.len() - resp.remaining_bytes().len()) as u64 } ParseStatus::Incomplete(_, _, _) => { assert!(test.result.is_none(), "test={}", test.name); } ParseStatus::Error(e, _, _) => { let expected = match &test.result { Some(Err(e)) => e, _ => panic!("result mismatch: test={}", 
test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } assert_eq!(rbuf_position, test.rbuf_position, "test={}", test.name); } } #[test] fn test_recv_response_body() { enum Status { NeedBytes, Read, Complete, } struct Test<'buf, 'headers> { name: &'static str, data: &'buf str, end: bool, body_size: BodySize, chunk_left: Option, chunk_size: usize, chunked: bool, result: Result<(Status, usize, Option<&'headers [httparse::Header<'buf>]>), Error>, chunk_left_after: Option, chunk_size_after: usize, rbuf_position: u64, dest_data: &'static str, } let tests = [ Test { name: "known-partial", data: "hel", end: false, body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, chunked: false, result: Ok((Status::Read, 3, None)), chunk_left_after: Some(2), chunk_size_after: 0, rbuf_position: 3, dest_data: "hel", }, Test { name: "known-partial-no-data", data: "", end: false, body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, chunked: false, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: Some(5), chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "known-partial-end", data: "hel", end: true, body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, chunked: false, result: Ok((Status::Read, 3, None)), chunk_left_after: Some(2), chunk_size_after: 0, rbuf_position: 3, dest_data: "hel", }, Test { name: "known-partial-end-no-data", data: "", end: true, body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, chunked: false, result: Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()), chunk_left_after: Some(5), chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "known-complete", data: "hello", end: false, body_size: BodySize::Known(5), chunk_left: Some(5), chunk_size: 0, chunked: false, result: Ok((Status::Complete, 5, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 5, dest_data: "hello", }, Test { name: "unknown-partial", data: "hel", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: false, result: Ok((Status::Read, 3, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 3, dest_data: "hel", }, Test { name: "unknown-partial-no-data", data: "", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: false, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "unknown-complete", data: "hello", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: false, result: Ok((Status::Complete, 5, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 5, dest_data: "hello", }, Test { name: "chunked-header-partial", data: "5", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-header-partial-end", data: "5", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-header-parse-error", data: "z", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(Error::InvalidChunkSize), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 
0, dest_data: "", }, Test { name: "chunked-too-large", data: "ffffffffff\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(Error::ChunkTooLarge), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-header-ok", data: "5\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 0, None)), chunk_left_after: Some(5), chunk_size_after: 5, rbuf_position: 3, dest_data: "", }, Test { name: "chunked-content-partial", data: "5\r\nhel", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 3, None)), chunk_left_after: Some(2), chunk_size_after: 5, rbuf_position: 6, dest_data: "hel", }, Test { name: "chunked-content-partial-no-data", data: "5\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 0, None)), chunk_left_after: Some(5), chunk_size_after: 5, rbuf_position: 3, dest_data: "", }, Test { name: "chunked-content-partial-end", data: "5\r\nhel", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 3, None)), chunk_left_after: Some(2), chunk_size_after: 5, rbuf_position: 6, dest_data: "hel", }, Test { name: "chunked-content-partial-end-no-data", data: "5\r\n", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()), chunk_left_after: Some(5), chunk_size_after: 5, rbuf_position: 3, dest_data: "", }, Test { name: "chunked-content-mid-no-data", data: "", end: false, body_size: BodySize::Unknown, chunk_left: Some(5), chunk_size: 5, chunked: true, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: Some(5), chunk_size_after: 5, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-footer-partial-full-none", data: "5\r\nhello", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 5, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-footer-partial-full-none-end", data: "5\r\nhello", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 5, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-footer-partial-full-r", data: "5\r\nhello\r", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 5, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-footer-partial-mid-r", data: "\r", end: false, body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, chunked: true, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-footer-partial-mid-r-last", data: "\r", end: false, body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 0, chunked: true, result: Ok((Status::NeedBytes, 0, None)), chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-footer-parse-error", data: "XX", end: false, body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, chunked: true, result: Err(Error::InvalidChunkSuffix), 
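// descriptive note: after a chunk's payload the parser expects the CRLF suffix; any other bytes
// ("XX" here) are expected to be reported as Error::InvalidChunkSuffix.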
chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 0, dest_data: "", }, Test { name: "chunked-complete-full", data: "5\r\nhello\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 5, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 8, dest_data: "hello", }, Test { name: "chunked-complete-mid", data: "lo\r\n", end: false, body_size: BodySize::Unknown, chunk_left: Some(2), chunk_size: 5, chunked: true, result: Ok((Status::Read, 2, None)), chunk_left_after: Some(0), chunk_size_after: 5, rbuf_position: 2, dest_data: "lo", }, Test { name: "chunked-complete-end", data: "\r\n", end: false, body_size: BodySize::Unknown, chunk_left: Some(0), chunk_size: 5, chunked: true, result: Ok((Status::Read, 0, None)), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 2, dest_data: "", }, Test { name: "chunked-empty", data: "0\r\n\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Complete, 0, Some(&[]))), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 5, dest_data: "", }, Test { name: "trailing-headers-partial", data: "0\r\nhelloXX", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok((Status::Read, 0, None)), chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 3, dest_data: "", }, Test { name: "trailing-headers-partial-end", data: "0\r\nhelloXX", end: true, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(io::Error::from(io::ErrorKind::UnexpectedEof).into()), chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 3, dest_data: "", }, Test { name: "trailing-headers-parse-error", data: "0\r\nhelloXX\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Err(Error::Parse(httparse::Error::Token)), chunk_left_after: Some(0), chunk_size_after: 0, rbuf_position: 3, dest_data: "", }, Test { name: "trailing-headers-complete", data: "0\r\nFoo: Bar\r\n\r\n", end: false, body_size: BodySize::Unknown, chunk_left: None, chunk_size: 0, chunked: true, result: Ok(( Status::Complete, 0, Some(&[httparse::Header { name: "Foo", value: b"Bar", }]), )), chunk_left_after: None, chunk_size_after: 0, rbuf_position: 15, dest_data: "", }, ]; for test in tests.iter() { let resp_body = ClientResponseBody { state: ClientState { ver_min: 0, body_size: test.body_size, chunk_left: test.chunk_left, chunk_size: test.chunk_size, persistent: false, chunked: test.chunked, sending_chunk: None, }, }; let mut dest = [0; 1024]; let mut scratch = mem::MaybeUninit::<[httparse::Header; HEADERS_MAX]>::uninit(); let r = resp_body.recv(test.data.as_bytes(), &mut dest, test.end, &mut scratch); let (r, headers) = match r { Ok(RecvStatus::Complete(finished, read, written)) => { if let Some((start, end)) = finished._headers_range { let headers_data = &test.data.as_bytes()[start..end]; let scratch = unsafe { scratch.assume_init_mut() }; match httparse::parse_headers(headers_data, scratch) { Ok(httparse::Status::Complete((pos, headers))) => { assert_eq!(pos, end - start); ( Ok(RecvStatus::Complete(finished, read, written)), Some(headers), ) } Ok(httparse::Status::Partial) => panic!("unexpected partial parse"), Err(e) => (Err(e.into()), None), } } else { (Ok(RecvStatus::Complete(finished, read, written)), None) } } r => (r, None), }; match r { Ok(RecvStatus::NeedBytes(resp_body)) => { match &test.result { 
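// a NeedBytes result is only acceptable when the test case itself expects Status::NeedBytes;
// any other expectation is treated as a mismatch below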
Ok((Status::NeedBytes, _, _)) => {} _ => panic!("result mismatch: test={}", test.name), } let state = resp_body.state; assert_eq!( state.chunk_left, test.chunk_left_after, "test={}", test.name ); assert_eq!( state.chunk_size, test.chunk_size_after, "test={}", test.name ); } Ok(RecvStatus::Complete(_, read, written)) => { let (expected_size, expected_headers) = match &test.result { Ok((Status::Complete, size, headers)) => (size, headers), _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(written, *expected_size, "test={}", test.name); assert_eq!(read as u64, test.rbuf_position, "test={}", test.name); assert_eq!(headers, *expected_headers, "test={}", test.name); } Ok(RecvStatus::Read(resp_body, read, written)) => { let expected_size = match &test.result { Ok((Status::Read, size, _)) => size, _ => panic!("result mismatch: test={}", test.name), }; let state = resp_body.state; assert_eq!(written, *expected_size, "test={}", test.name); assert_eq!( state.chunk_left, test.chunk_left_after, "test={}", test.name ); assert_eq!( state.chunk_size, test.chunk_size_after, "test={}", test.name ); assert_eq!(read as u64, test.rbuf_position, "test={}", test.name); assert_eq!( str::from_utf8(&dest[..written]).unwrap(), test.dest_data, "test={}", test.name ); } Err(e) => { let expected = match &test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!( mem::discriminant(&e), mem::discriminant(expected), "test={}", test.name ); } } } } #[test] fn test_client_flow() { let req = ClientRequest::new(); let mut out = MyBuffer::new(1024, true); let req_body = req .send_header( &mut out, "GET", "/foo", &[Header { name: "Host", value: b"example.com", }], BodySize::NoBody, false, ) .unwrap(); let expected = "GET /foo HTTP/1.1\r\nHost: example.com\r\n\r\n"; assert_eq!(str::from_utf8(&out.data).unwrap(), expected); out.data.clear(); let resp = match req_body.send(&mut out, &[], true, None) { SendStatus::Complete(resp, read) => { assert_eq!(read, 0); resp } _ => panic!("unexpected status"), }; assert_eq!(out.data.is_empty(), true); let mut buf = Vec::new(); let data = concat!( "HTTP/1.1 200 OK\r\n", "Content-Type: text/plain\r\n", "Content-Length: 6\r\n", "\r\n", "hello\n", ); let size = buf.write(data.as_bytes()).unwrap(); buf.resize(1024, 0); let buf = FilledBuf::new(buf, size); let mut scratch = ParseScratch::::new(); let (resp, resp_body) = match resp.recv_header(buf, &mut scratch) { ParseStatus::Complete(ret) => ret, _ => panic!("unexpected status"), }; { let resp = resp.get(); assert_eq!(resp.code, 200); assert_eq!(resp.headers.len(), 2); assert_eq!(resp.headers[0].name, "Content-Type"); assert_eq!(resp.headers[0].value, b"text/plain"); assert_eq!(resp.headers[1].name, "Content-Length"); assert_eq!(resp.headers[1].value, b"6"); assert_eq!(resp_body.size(), BodySize::Known(6)); } let mut out = [0; 1024]; let mut scratch = mem::MaybeUninit::<[httparse::Header; HEADERS_MAX]>::uninit(); let finished = match resp_body .recv(resp.remaining_bytes(), &mut out, false, &mut scratch) .unwrap() { RecvStatus::Complete(finished, read, written) => { assert_eq!(read, 6); assert_eq!(written, 6); finished } _ => panic!("unexpected status"), }; assert_eq!(finished._headers_range, None); assert_eq!(finished.persistent, true); } fn collect_values<'a>( s: &'a [u8], ) -> Result)>, io::Error> { let mut out = Vec::new(); for part in parse_header_value(s) { let (name, params_iter) = part?; let mut params = Vec::new(); for p in params_iter { let (k, v) = p?; params.push((k, v)); } 
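// pair the value name with its collected (key, value) parameters so the test can
// compare the fully parsed structure against the expected result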
out.push((name, params)); } Ok(out) } #[test] fn test_parse_header_value() { struct Test { name: &'static str, value: &'static str, result: Result)>, io::Error>, } let tests = [ Test { name: "empty", value: "", result: Err(io::Error::from(io::ErrorKind::InvalidData)), }, Test { name: "value", value: "apple", result: Ok(vec![("apple", vec![])]), }, Test { name: "incomplete-value", value: "apple,", result: Err(io::Error::from(io::ErrorKind::InvalidData)), }, Test { name: "incomplete-param", value: "apple;", result: Err(io::Error::from(io::ErrorKind::InvalidData)), }, Test { name: "incomplete-second-param", value: "apple; type=gala;", result: Err(io::Error::from(io::ErrorKind::InvalidData)), }, Test { name: "value-with-param", value: "apple; type=gala", result: Ok(vec![("apple", vec![("type", "gala")])]), }, Test { name: "value-with-params", value: "apple; type=\"granny smith\"; color=green", result: Ok(vec![( "apple", vec![("type", "granny smith"), ("color", "green")], )]), }, Test { name: "values", value: "apple, banana, cherry", result: Ok(vec![ ("apple", vec![]), ("banana", vec![]), ("cherry", vec![]), ]), }, Test { name: "values-and-params", value: "apple, banana; color=yellow; ripe=true, cherry", result: Ok(vec![ ("apple", vec![]), ("banana", vec![("color", "yellow"), ("ripe", "true")]), ("cherry", vec![]), ]), }, Test { name: "spacing", value: "apple ,banana ;color= yellow ; ripe= \"true\" , cherry", result: Ok(vec![ ("apple", vec![]), ("banana", vec![("color", "yellow"), ("ripe", "true")]), ("cherry", vec![]), ]), }, ]; for test in tests { match collect_values(test.value.as_bytes()) { Ok(values) => { let expected = match test.result { Ok(v) => v, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(values, expected, "test={}", test.name); } Err(e) => { let expected = match test.result { Err(e) => e, _ => panic!("result mismatch: test={}", test.name), }; assert_eq!(e.kind(), expected.kind(), "test={}", test.name); } } } } } pushpin-1.41.0/src/core/http1/server.rs000066400000000000000000001147031504671364300177350ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::core::buffer::{Buffer, ContiguousBuffer, VecRingBuffer, VECTORED_MAX}; use crate::core::http1::error::Error; use crate::core::http1::protocol::{self, BodySize, Header, ParseScratch, ParseStatus}; use crate::core::http1::util::*; use crate::core::io::{AsyncRead, AsyncWrite, AsyncWriteExt, ReadHalf, StdWriteWrapper, WriteHalf}; use crate::core::select::{select_2, Select2}; use std::cell::{Cell, RefCell}; use std::io::{self, Write}; use std::pin::pin; use std::pin::Pin; use std::str; struct RequestInner<'a, R: AsyncRead, W: AsyncWrite> { r: ReadHalf<'a, R>, w: WriteHalf<'a, W>, rbuf: &'a mut VecRingBuffer, wbuf: &'a mut VecRingBuffer, protocol: protocol::ServerProtocol, send_is_dirty: bool, } pub struct Request; impl Request { pub fn new<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( stream: (ReadHalf<'a, R>, WriteHalf<'a, W>), buf1: &'a mut VecRingBuffer, buf2: &'a mut VecRingBuffer, ) -> (Self, Response<'a, R, W>) { ( Self, Response { inner: Some(RequestInner { r: stream.0, w: stream.1, rbuf: buf1, wbuf: buf2, protocol: protocol::ServerProtocol::new(), send_is_dirty: false, }), }, ) } pub fn recv_header<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite>( self, resp: &'b mut Response<'a, R, W>, ) -> RequestHeader<'a, 'b, R, W> { RequestHeader { inner: resp.inner.as_mut().unwrap(), } } } pub struct RequestHeader<'a, 'b, R: AsyncRead, W: AsyncWrite> { inner: &'b mut RequestInner<'a, R, W>, } impl<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite> RequestHeader<'a, 'b, R, W> { // read from stream into buf, and parse buf as a request header pub async fn recv<'c, const N: usize>( self, mut scratch: &'c mut ParseScratch, ) -> Result< ( protocol::OwnedRequest<'c, N>, RequestBodyKeepHeader<'a, 'b, R, W>, ), Error, > { assert_eq!( self.inner.protocol.state(), protocol::ServerState::ReceivingRequest ); let size_limit = self.inner.rbuf.remaining_capacity(); let req = loop { { let buf = self.inner.rbuf.take_inner(); match self.inner.protocol.recv_request_owned(buf, scratch) { ParseStatus::Complete(req) => break req, ParseStatus::Incomplete((), buf, ret_scratch) => { // NOTE: after polonius it may not be necessary for // scratch to be returned scratch = ret_scratch; self.inner.rbuf.set_inner(buf); } ParseStatus::Error(e, buf, _) => { self.inner.rbuf.set_inner(buf); return Err(e.into()); } } } // take_inner aligns assert!(self.inner.rbuf.is_readable_contiguous()); if let Err(e) = recv_nonzero(&mut self.inner.r, self.inner.rbuf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::RequestTooLarge(size_limit)); } return Err(e.into()); } }; assert!([ protocol::ServerState::ReceivingBody, protocol::ServerState::AwaitingResponse ] .contains(&self.inner.protocol.state())); // at this point, req has taken rbuf's inner buffer, such that // rbuf has no inner buffer // put remaining readable bytes in wbuf self.inner.wbuf.write_all(req.remaining_bytes())?; // swap inner buffers, such that rbuf now contains the remaining // readable bytes, and wbuf is now the one with no inner buffer self.inner.rbuf.swap_inner(self.inner.wbuf); let need_send_100 = req.get().expect_100; Ok(( req, RequestBodyKeepHeader { inner: Some(RequestBodyKeepHeaderInner { inner: RequestBody { inner: RefCell::new(Some(RequestBodyInner { r: &mut self.inner.r, w: &mut self.inner.w, rbuf: self.inner.rbuf, protocol: &mut self.inner.protocol, need_send_100, send_is_dirty: &mut self.inner.send_is_dirty, })), }, wbuf: self.inner.wbuf, }), }, )) } } struct RequestBodyInner<'a, 'b, R: AsyncRead, W: AsyncWrite> { r: &'b mut ReadHalf<'a, R>, w: &'b mut 
WriteHalf<'a, W>, rbuf: &'b mut VecRingBuffer, protocol: &'b mut protocol::ServerProtocol, need_send_100: bool, send_is_dirty: &'b mut bool, } pub struct RequestBody<'a, 'b, R: AsyncRead, W: AsyncWrite> { inner: RefCell>>, } impl<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite> RequestBody<'a, 'b, R, W> { // on EOF and any subsequent calls, return success #[allow(clippy::await_holding_refcell_ref)] pub async fn add_to_buffer(&self) -> Result<(), Error> { let mut b_inner = self.inner.borrow_mut(); if let Some(mut inner) = b_inner.take() { Self::handle_expect(&mut inner).await?; if let Err(e) = recv_nonzero(inner.r, inner.rbuf).await { if e.kind() == io::ErrorKind::WriteZero { return Err(Error::BufferExceeded); } return Err(e.into()); } *b_inner = Some(inner); Ok(()) } else { Err(Error::Unusable) } } pub fn try_recv(&self, dest: &mut [u8]) -> Result, Error> { loop { let mut b_inner = self.inner.borrow_mut(); if let Some(inner) = b_inner.take() { let (read, written, done, need_bytes) = if inner.protocol.state() == protocol::ServerState::ReceivingBody { let mut buf = io::Cursor::new(Buffer::read_buf(inner.rbuf)); let mut headers = [httparse::EMPTY_HEADER; HEADERS_MAX]; let (written, need_bytes) = match inner.protocol.recv_body(&mut buf, dest, &mut headers) { Ok(Some((written, _))) => (written, false), Ok(None) => (0, true), Err(e) => return Err(e.into()), }; let read = buf.position() as usize; ( read, written, inner.protocol.state() == protocol::ServerState::AwaitingResponse, need_bytes, ) } else { (0, 0, true, false) }; if done { inner.rbuf.read_commit(read); assert_eq!( inner.protocol.state(), protocol::ServerState::AwaitingResponse ); *b_inner = None; return Ok(RecvStatus::Complete((), written)); } else { *b_inner = Some(RequestBodyInner { r: inner.r, w: inner.w, rbuf: inner.rbuf, protocol: inner.protocol, need_send_100: inner.need_send_100, send_is_dirty: inner.send_is_dirty, }); let inner = b_inner.as_mut().unwrap(); if need_bytes { if read == 0 && !inner.rbuf.is_readable_contiguous() { inner.rbuf.align(); continue; } return Ok(RecvStatus::NeedBytes(())); } inner.rbuf.read_commit(read); if read > 0 && written == 0 { // input consumed but no output produced, retry continue; } // written is only zero here if read is also zero assert!(written > 0 || read == 0); return Ok(RecvStatus::Read((), written)); } } else { return Err(Error::Unusable); } } } async fn handle_expect(inner: &mut RequestBodyInner<'a, 'b, R, W>) -> Result<(), Error> { if !inner.need_send_100 { return Ok(()); } let mut cont = [0; 32]; let cont = { let mut c = io::Cursor::new(&mut cont[..]); inner.protocol.send_100_continue(&mut c).unwrap(); let size = c.position() as usize; &cont[..size] }; let mut left = cont.len(); while left > 0 { let pos = cont.len() - left; let size = match inner.w.write(&cont[pos..]).await { Ok(size) => size, Err(e) => return Err(e.into()), }; *inner.send_is_dirty = true; left -= size; } inner.need_send_100 = false; *inner.send_is_dirty = false; Ok(()) } } struct RequestBodyKeepHeaderInner<'a, 'b, R: AsyncRead, W: AsyncWrite> { inner: RequestBody<'a, 'b, R, W>, wbuf: &'b mut VecRingBuffer, } pub struct RequestBodyKeepHeader<'a, 'b, R: AsyncRead, W: AsyncWrite> { inner: Option>, } impl<'a: 'b, 'b, R: AsyncRead, W: AsyncWrite> RequestBodyKeepHeader<'a, 'b, R, W> { pub fn discard_header( mut self, req: protocol::OwnedRequest, ) -> RequestBody<'a, 'b, R, W> { let inner = self.inner.take().unwrap(); inner.wbuf.set_inner(req.into_buf()); inner.wbuf.clear(); inner.inner } pub async fn add_to_buffer(&self) -> 
Result<(), Error> { let inner = self.inner.as_ref().unwrap(); inner.inner.add_to_buffer().await } pub fn try_recv(&self, dest: &mut [u8]) -> Result, Error> { let inner = self.inner.as_ref().unwrap(); match inner.inner.try_recv(dest)? { RecvStatus::Complete((), written) => Ok(RecvStatus::Complete((), written)), RecvStatus::Read((), written) => Ok(RecvStatus::Read((), written)), RecvStatus::NeedBytes(()) => Ok(RecvStatus::NeedBytes(())), } } } impl Drop for RequestBodyKeepHeader<'_, '_, R, W> { fn drop(&mut self) { if self.inner.is_some() { panic!("RequestBodyKeepHeader must be consumed by discard_header() instead of dropped"); } } } pub struct Response<'a, R: AsyncRead, W: AsyncWrite> { inner: Option>, } impl<'a, R: AsyncRead, W: AsyncWrite> Response<'a, R, W> { pub async fn fill_recv_buffer(&mut self) -> Error { if let Some(inner) = &mut self.inner { loop { if let Err(e) = recv_nonzero(&mut inner.r, inner.rbuf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever std::future::pending::<()>().await; } return e.into(); } } } else { Error::Unusable } } #[allow(clippy::type_complexity)] pub fn prepare_header<'b>( &mut self, code: u16, reason: &str, headers: &[Header<'_>], body_size: BodySize, state: &'b mut ResponseState<'a, R, W>, ) -> Result< ( ResponseHeader<'a, 'b, R, W>, ResponsePrepareBody<'a, 'b, R, W>, ), Error, > { let inner = match &mut self.inner { Some(inner) => inner, None => return Err(Error::Unusable), }; if inner.protocol.state() == protocol::ServerState::ReceivingRequest { inner.protocol.skip_recv_request(); } inner.wbuf.clear(); let size_limit = inner.wbuf.capacity(); let header_size = { let mut buf = io::Cursor::new(inner.wbuf.write_buf()); if inner .protocol .send_response(&mut buf, code, reason, headers, body_size) .is_err() { // enable prepare_header to be called again inner.wbuf.clear(); return Err(Error::ResponseTooLarge(size_limit)); } buf.position() as usize }; inner.wbuf.write_commit(header_size); let inner = self.inner.take().unwrap(); *state.inner.borrow_mut() = Some(ResponseStateInner { r: inner.r, w: RefCell::new(inner.w), rbuf: inner.rbuf, wbuf: RefCell::new(LimitedRingBuffer { inner: inner.wbuf, limit: header_size, }), protocol: inner.protocol, overflow: RefCell::new(None), end: Cell::new(false), }); let state = &state.inner; Ok((ResponseHeader { state }, ResponsePrepareBody { state })) } } struct ResponseStateInner<'a, R: AsyncRead, W: AsyncWrite> { r: ReadHalf<'a, R>, w: RefCell>, rbuf: &'a mut VecRingBuffer, wbuf: RefCell>, protocol: protocol::ServerProtocol, overflow: RefCell>, end: Cell, } pub struct ResponseState<'a, R: AsyncRead, W: AsyncWrite> { inner: RefCell>>, } impl Default for ResponseState<'_, R, W> { fn default() -> Self { Self { inner: RefCell::new(None), } } } pub struct ResponseHeader<'a, 'b, R: AsyncRead, W: AsyncWrite> { state: &'b RefCell>>, } impl<'a, 'b, R: AsyncRead, W: AsyncWrite> ResponseHeader<'a, 'b, R, W> { #[allow(clippy::await_holding_refcell_ref)] pub async fn send(self) -> Result, Error> { // ok to hold across await as self.state is only ever immutably borrowed let state = self.state.borrow(); let state = state.as_ref().unwrap(); while state.wbuf.borrow().limit > 0 { // ok to hold across await as this is the only place state.w is borrowed let mut w = state.w.borrow_mut(); // TODO: vectored write let size = w.write_shared(&state.wbuf).await?; let mut wbuf = state.wbuf.borrow_mut(); wbuf.inner.read_commit(size); wbuf.limit -= size; } let mut overflow = state.overflow.borrow_mut(); if let 
Some(overflow_ref) = &mut *overflow { // overflow is guaranteed to fit let mut wbuf = state.wbuf.borrow_mut(); wbuf.inner.write_all(overflow_ref.read_buf()).unwrap(); *overflow = None; } Ok(ResponseHeaderSent { state: self.state }) } } pub struct ResponsePrepareBody<'a, 'b, R: AsyncRead, W: AsyncWrite> { state: &'b RefCell>>, } impl ResponsePrepareBody<'_, '_, R, W> { // only returns an error on invalid input pub fn prepare(&mut self, src: &[u8], end: bool) -> Result<(usize, usize), Error> { let state = self.state.borrow(); let state = state.as_ref().unwrap(); // call not allowed if the end has already been indicated if state.end.get() { return Err(Error::FurtherInputNotAllowed); } let wbuf = &mut *state.wbuf.borrow_mut(); let overflow = &mut *state.overflow.borrow_mut(); // workaround for rust 1.77 #[allow(clippy::unused_io_amount)] let accepted = if overflow.is_none() { match wbuf.inner.write(src) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, Err(e) => panic!("infallible buffer write failed: {}", e), } } else { 0 }; let (size, overflowed) = if accepted < src.len() { // only allow overflowing as much as there are header bytes left let overflow = overflow.get_or_insert_with(|| ContiguousBuffer::new(wbuf.limit)); let remaining = &src[accepted..]; let overflowed = match overflow.write(remaining) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, Err(e) => panic!("infallible buffer write failed: {}", e), }; (accepted + overflowed, overflowed) } else { (accepted, 0) }; assert!(size <= src.len()); if size == src.len() && end { state.end.set(true); } Ok((size, overflowed)) } } pub struct ResponseHeaderSent<'a, 'b, R: AsyncRead, W: AsyncWrite> { state: &'b RefCell>>, } impl<'a, 'b, R: AsyncRead, W: AsyncWrite> ResponseHeaderSent<'a, 'b, R, W> { pub fn start_body( self, _prepare_body: ResponsePrepareBody<'a, 'b, R, W>, ) -> ResponseBody<'a, R, W> { let state = self.state.take().unwrap(); let wbuf = state.wbuf.into_inner(); let block_size = wbuf.inner.capacity(); ResponseBody { inner: RefCell::new(Some(ResponseBodyInner { r: RefCell::new(ResponseBodyRead { stream: state.r, buf: state.rbuf, }), w: RefCell::new(ResponseBodyWrite { stream: state.w.into_inner(), buf: wbuf.inner, protocol: state.protocol, end: state.end.get(), block_size, }), })), } } } struct ResponseBodyRead<'a, R: AsyncRead> { stream: ReadHalf<'a, R>, buf: &'a mut VecRingBuffer, } struct ResponseBodyWrite<'a, W: AsyncWrite> { stream: WriteHalf<'a, W>, buf: &'a mut VecRingBuffer, protocol: protocol::ServerProtocol, end: bool, block_size: usize, } struct ResponseBodyInner<'a, R: AsyncRead, W: AsyncWrite> { r: RefCell>, w: RefCell>, } pub struct ResponseBody<'a, R: AsyncRead, W: AsyncWrite> { inner: RefCell>>, } impl ResponseBody<'_, R, W> { pub fn prepare(&self, src: &[u8], end: bool) -> Result { if let Some(inner) = &*self.inner.borrow() { let w = &mut *inner.w.borrow_mut(); // call not allowed if the end has already been indicated if w.end { return Err(Error::FurtherInputNotAllowed); } let size = match w.buf.write(src) { Ok(size) => size, Err(e) if e.kind() == io::ErrorKind::WriteZero => 0, Err(e) => panic!("infallible buffer write failed: {}", e), }; assert!(size <= src.len()); if size == src.len() && end { w.end = true; } Ok(size) } else { Err(Error::Unusable) } } pub fn expand_write_buffer(&self, blocks_max: usize, reserve: F) -> Result where F: FnMut() -> bool, { if let Some(inner) = &*self.inner.borrow() { let w = &mut *inner.w.borrow_mut(); Ok(resize_write_buffer_if_full( w.buf, 
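// descriptive note: resize_write_buffer_if_full (util.rs) grows the buffer one block at a time,
// only when it is currently full, and caps growth at blocks_max - 1 writable blocks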
w.block_size, blocks_max, reserve, )) } else { Err(Error::Unusable) } } pub fn can_send(&self) -> bool { if let Some(inner) = &*self.inner.borrow() { let w = &*inner.w.borrow(); w.buf.len() > 0 || w.end } else { false } } pub async fn send(&self) -> SendStatus { if self.inner.borrow().is_none() { return SendStatus::Error((), Error::Unusable); } let size = loop { match self.process().await { Some(Ok(size)) => break size, Some(Err(e)) => return SendStatus::Error((), e), None => {} // received data } }; let mut inner = self.inner.borrow_mut(); assert!(inner.is_some()); let done = { let inner = inner.as_ref().unwrap(); let mut w = inner.w.borrow_mut(); w.buf.read_commit(size); w.protocol.state() == protocol::ServerState::Finished }; if done { let inner = inner.take().unwrap(); let w = inner.w.into_inner(); assert_eq!(w.buf.len(), 0); SendStatus::Complete(Finished { protocol: w.protocol, }) } else { SendStatus::Partial((), size) } } #[allow(clippy::await_holding_refcell_ref)] pub async fn fill_recv_buffer(&self) -> Error { if let Some(inner) = &*self.inner.borrow() { let r = &mut *inner.r.borrow_mut(); loop { if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever std::future::pending::<()>().await; } return e.into(); } } } else { Error::Unusable } } // assumes self.inner is Some #[allow(clippy::await_holding_refcell_ref)] async fn process(&self) -> Option> { let inner = self.inner.borrow(); let inner = inner.as_ref().unwrap(); let mut r = inner.r.borrow_mut(); let result = select_2( AsyncOperation::new( |cx| { let w = &mut *inner.w.borrow_mut(); if !w.stream.is_writable() { return None; } assert_eq!(w.protocol.state(), protocol::ServerState::SendingBody); if w.buf.len() == 0 && !w.end { return Some(Ok(0)); } // protocol.send_body() expects the input to leave room // for at least two more buffers in case chunked encoding // is used (for chunked header and footer) let mut buf_arr = [&b""[..]; VECTORED_MAX - 2]; let bufs = w.buf.read_bufs(&mut buf_arr); match w.protocol.send_body( &mut StdWriteWrapper::new(Pin::new(&mut w.stream), cx), bufs, w.end, None, ) { Ok(size) => Some(Ok(size)), Err(protocol::Error::Io(e)) if e.kind() == io::ErrorKind::WouldBlock => { None } Err(e) => Some(Err(e.into())), } }, || inner.w.borrow_mut().stream.cancel(), ), pin!(async { let r = &mut *r; if let Err(e) = recv_nonzero(&mut r.stream, r.buf).await { if e.kind() == io::ErrorKind::WriteZero { // if there's no more space, suspend forever std::future::pending::<()>().await; } return Err(Error::from(e)); } Ok(()) }), ) .await; match result { Select2::R1(ret) => Some(ret), Select2::R2(ret) => match ret { Ok(()) => None, // received data Err(e) => Some(Err(e)), // error while receiving data }, } } } pub struct Finished { protocol: protocol::ServerProtocol, } impl Finished { pub fn is_persistent(&self) -> bool { self.protocol.is_persistent() } } #[cfg(test)] mod tests { use super::*; use crate::core::buffer::TmpBuffer; use crate::core::io::io_split; use std::cmp; use std::future::Future; use std::io::Read; use std::panic; use std::rc::Rc; use std::sync::Arc; use std::task::{Context, Poll, Wake}; struct FakeStream { in_data: Vec, out_data: Vec, } impl FakeStream { fn new() -> Self { Self { in_data: Vec::new(), out_data: Vec::new(), } } } impl AsyncRead for FakeStream { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context, buf: &mut [u8], ) -> Poll> { let size = cmp::min(buf.len(), self.in_data.len()); if size == 0 { return 
Poll::Pending; } let left = self.in_data.split_off(size); (&mut buf[..size]).copy_from_slice(&self.in_data); self.in_data = left; Poll::Ready(Ok(size)) } fn cancel(&mut self) {} } impl AsyncWrite for FakeStream { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8], ) -> Poll> { let size = self.out_data.write(buf).unwrap(); Poll::Ready(Ok(size)) } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn is_writable(&self) -> bool { true } fn cancel(&mut self) {} } struct NoopWaker; impl Wake for NoopWaker { fn wake(self: Arc) {} } #[test] fn request_response() { let mut fut = pin!(async { let mut stream = FakeStream::new(); stream .in_data .write_all("POST /path HTTP/1.1\r\nContent-Length: 6\r\n\r\nhello\n".as_bytes()) .unwrap(); { let stream = RefCell::new(&mut stream); let tmp = Rc::new(TmpBuffer::new(1024)); let mut buf1 = VecRingBuffer::new(1024, &tmp); let mut buf2 = VecRingBuffer::new(1024, &tmp); let (req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); let header = req.recv_header(&mut resp); let mut scratch = ParseScratch::::new(); let (req_header, req_body) = header.recv(&mut scratch).await.unwrap(); // catch to avoid running req_body destructor let result = panic::catch_unwind(|| { let req_ref = req_header.get(); assert_eq!(req_ref.method, "POST"); assert_eq!(req_ref.uri, "/path"); assert_eq!(req_ref.body_size, BodySize::Known(6)); }); let req_body = req_body.discard_header(req_header); assert!(result.is_ok()); let mut buf = [0; 64]; let size = match req_body.try_recv(&mut buf).unwrap() { RecvStatus::Complete((), size) => size, _ => unreachable!(), }; drop(req_body); let buf = &buf[..size]; assert_eq!(str::from_utf8(buf).unwrap(), "hello\n"); let mut state = ResponseState::default(); let (header, prepare_body) = match resp.prepare_header(200, "OK", &[], BodySize::Known(6), &mut state) { Ok(ret) => ret, Err(_) => unreachable!(), }; let sent = header.send().await.unwrap(); let resp_body = sent.start_body(prepare_body); assert_eq!(resp_body.prepare(b"world\n", true).unwrap(), 6); let finished = match resp_body.send().await { SendStatus::Complete(finished) => finished, _ => unreachable!(), }; assert!(finished.is_persistent()); } let expected = "HTTP/1.1 200 OK\r\nContent-Length: 6\r\n\r\nworld\n"; assert_eq!(str::from_utf8(&stream.out_data).unwrap(), expected); }); let waker = Arc::new(NoopWaker).into(); let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } #[test] fn request_noncontiguous() { let mut fut = pin!(async { let mut stream = FakeStream::new(); stream .in_data .write_all("OST /path HTTP/1.1\r\nContent-Length: 6\r\n\r\nhello\n".as_bytes()) .unwrap(); { let stream = RefCell::new(&mut stream); let tmp = Rc::new(TmpBuffer::new(64)); let mut buf1 = VecRingBuffer::new(64, &tmp); let mut buf2 = VecRingBuffer::new(64, &tmp); // shift the write cursor and leave a "P" towards the end buf1.write_all(&[b'a'; 40]).unwrap(); buf1.write_all(b"P").unwrap(); assert_eq!(buf1.read(&mut [0; 40]).unwrap(), 40); let (req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); let header = req.recv_header(&mut resp); let mut scratch = ParseScratch::::new(); let (req_header, req_body) = header.recv(&mut scratch).await.unwrap(); // catch to avoid running req_body destructor let result = panic::catch_unwind(|| { let req_ref = req_header.get(); assert_eq!(req_ref.method, "POST"); assert_eq!(req_ref.uri, "/path"); assert_eq!(req_ref.body_size, BodySize::Known(6)); }); 
req_body.discard_header(req_header); assert!(result.is_ok()); } }); let waker = Arc::new(NoopWaker).into(); let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } #[test] fn response_during_header() { let mut fut = pin!(async { let mut stream = FakeStream::new(); stream .in_data .write_all("POST /path HTTP/1.1\r\nContent-Length: 6\r\n\r\nhello\n".as_bytes()) .unwrap(); { let stream = RefCell::new(&mut stream); let tmp = Rc::new(TmpBuffer::new(1024)); let mut buf1 = VecRingBuffer::new(1024, &tmp); let mut buf2 = VecRingBuffer::new(1024, &tmp); let (_req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); let mut state = ResponseState::default(); let (header, prepare_body) = match resp.prepare_header(200, "OK", &[], BodySize::Known(6), &mut state) { Ok(ret) => ret, Err(_) => unreachable!(), }; let sent = header.send().await.unwrap(); let resp_body = sent.start_body(prepare_body); assert_eq!(resp_body.prepare(b"world\n", true).unwrap(), 6); let finished = match resp_body.send().await { SendStatus::Complete(finished) => finished, _ => unreachable!(), }; assert!(!finished.is_persistent()); } let expected = "HTTP/1.0 200 OK\r\nContent-Length: 6\r\n\r\nworld\n"; assert_eq!(str::from_utf8(&stream.out_data).unwrap(), expected); }); let waker = Arc::new(NoopWaker).into(); let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } #[test] fn response_during_body() { let mut fut = pin!(async { let mut stream = FakeStream::new(); stream .in_data .write_all("POST /path HTTP/1.1\r\nContent-Length: 6\r\n\r\nhello\n".as_bytes()) .unwrap(); { let stream = RefCell::new(&mut stream); let tmp = Rc::new(TmpBuffer::new(1024)); let mut buf1 = VecRingBuffer::new(1024, &tmp); let mut buf2 = VecRingBuffer::new(1024, &tmp); let (req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); let header = req.recv_header(&mut resp); let mut scratch = ParseScratch::::new(); let (req_header, req_body) = header.recv(&mut scratch).await.unwrap(); // catch to avoid running req_body destructor let result = panic::catch_unwind(|| { let req_ref = req_header.get(); assert_eq!(req_ref.method, "POST"); assert_eq!(req_ref.uri, "/path"); assert_eq!(req_ref.body_size, BodySize::Known(6)); }); let req_body = req_body.discard_header(req_header); drop(req_body); assert!(result.is_ok()); let mut state = ResponseState::default(); let (header, prepare_body) = match resp.prepare_header(200, "OK", &[], BodySize::Known(6), &mut state) { Ok(ret) => ret, Err(_) => unreachable!(), }; let sent = header.send().await.unwrap(); let resp_body = sent.start_body(prepare_body); assert_eq!(resp_body.prepare(b"world\n", true).unwrap(), 6); let finished = match resp_body.send().await { SendStatus::Complete(finished) => finished, _ => unreachable!(), }; assert!(!finished.is_persistent()); } let expected = "HTTP/1.1 200 OK\r\nConnection: close\r\nContent-Length: 6\r\n\r\nworld\n"; assert_eq!(str::from_utf8(&stream.out_data).unwrap(), expected); }); let waker = Arc::new(NoopWaker).into(); let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } #[test] fn response_overflow() { let mut fut = pin!(async { let mut stream = FakeStream::new(); stream .in_data .write_all("GET /path HTTP/1.1\r\n\r\n".as_bytes()) .unwrap(); let mut body = [0; 100]; for i in 0..body.len() { body[i] = b'a' + ((i as u8) % 26); } let attempted_body = str::from_utf8(&body).unwrap(); let expected_body = &attempted_body[..64]; { let stream = RefCell::new(&mut 
stream); let tmp = Rc::new(TmpBuffer::new(64)); let mut buf1 = VecRingBuffer::new(64, &tmp); let mut buf2 = VecRingBuffer::new(64, &tmp); let (req, mut resp) = Request::new(io_split(&stream), &mut buf1, &mut buf2); let header = req.recv_header(&mut resp); let mut scratch = ParseScratch::::new(); let (req_header, req_body) = header.recv(&mut scratch).await.unwrap(); let req_body = req_body.discard_header(req_header); drop(req_body); let mut state = ResponseState::default(); // this will serialize to 39 bytes, leaving 25 bytes left let (header, mut prepare_body) = match resp.prepare_header(200, "OK", &[], BodySize::Known(64), &mut state) { Ok(ret) => ret, Err(_) => unreachable!(), }; // only the first 64 bytes will fit assert_eq!( prepare_body .prepare(attempted_body.as_bytes(), true) .unwrap(), (64, 39) ); // end is ignored if input doesn't fit, so set end again assert_eq!(prepare_body.prepare(&[], true).unwrap(), (0, 0)); let sent = header.send().await.unwrap(); let resp_body = sent.start_body(prepare_body); let size = match resp_body.send().await { SendStatus::Partial(_, size) => size, _ => unreachable!(), }; assert_eq!(size, 25); let finished = match resp_body.send().await { SendStatus::Complete(finished) => finished, _ => unreachable!(), }; assert!(finished.is_persistent()); } let expected = "HTTP/1.1 200 OK\r\nContent-Length: 64\r\n\r\n".to_string() + expected_body; assert_eq!(str::from_utf8(&stream.out_data).unwrap(), expected); }); let waker = Arc::new(NoopWaker).into(); let mut cx = Context::from_waker(&waker); assert!(fut.as_mut().poll(&mut cx).is_ready()); } } pushpin-1.41.0/src/core/http1/util.rs000066400000000000000000000062261504671364300174040ustar00rootroot00000000000000/* * Copyright (C) 2023-2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
*/ use crate::core::buffer::{Buffer, VecRingBuffer}; use crate::core::io::{AsyncRead, AsyncReadExt}; use std::cmp; use std::future::Future; use std::io; use std::pin::Pin; use std::task::{Context, Poll}; // some reasonable number pub const HEADERS_MAX: usize = 64; // return the capacity increase pub fn resize_write_buffer_if_full( buf: &mut VecRingBuffer, block_size: usize, blocks_max: usize, mut reserve: F, ) -> usize where F: FnMut() -> bool, { assert!(blocks_max >= 2); // all but one block can be used for writing let allowed = blocks_max - 1; if buf.remaining_capacity() == 0 && buf.capacity() < block_size.checked_mul(allowed).unwrap() && reserve() { buf.resize(buf.capacity() + block_size); block_size } else { 0 } } pub async fn recv_nonzero( r: &mut R, buf: &mut VecRingBuffer, ) -> Result<(), io::Error> { if buf.remaining_capacity() == 0 { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let size = match r.read(buf.write_buf()).await { Ok(size) => size, Err(e) => return Err(e), }; buf.write_commit(size); if size == 0 { return Err(io::Error::from(io::ErrorKind::UnexpectedEof)); } Ok(()) } pub struct LimitedRingBuffer<'a> { pub inner: &'a mut VecRingBuffer, pub limit: usize, } impl AsRef<[u8]> for LimitedRingBuffer<'_> { fn as_ref(&self) -> &[u8] { let buf = Buffer::read_buf(self.inner); let limit = cmp::min(buf.len(), self.limit); &buf[..limit] } } pub struct AsyncOperation where C: FnMut(), { op_fn: O, cancel_fn: C, } impl AsyncOperation where O: FnMut(&mut Context) -> Option, C: FnMut(), { pub fn new(op_fn: O, cancel_fn: C) -> Self { Self { op_fn, cancel_fn } } } impl Future for AsyncOperation where O: FnMut(&mut Context) -> Option + Unpin, C: FnMut() + Unpin, { type Output = R; fn poll(self: Pin<&mut Self>, cx: &mut Context) -> Poll { let s = Pin::into_inner(self); match (s.op_fn)(cx) { Some(ret) => Poll::Ready(ret), None => Poll::Pending, } } } impl Drop for AsyncOperation where C: FnMut(), { fn drop(&mut self) { (self.cancel_fn)(); } } pub enum SendStatus { Complete(T), EarlyResponse(T), Partial(P, usize), Error(P, E), } pub enum RecvStatus { NeedBytes(T), Read(T, usize), Complete(C, usize), } pushpin-1.41.0/src/core/httpheaders.cpp000066400000000000000000000145061504671364300200400ustar00rootroot00000000000000/* * Copyright (C) 2012-2017 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include "httpheaders.h" // return position, end of string if not found, -1 on error static int findNonQuoted(const QByteArray &in, char c, int offset = 0) { bool inQuote = false; for(int n = offset; n < in.size(); ++n) { char i = in[n]; if(inQuote) { if(i == '\\') { ++n; // no character after the escape if(n >= in.size()) { return -1; } } else if(i == '\"') inQuote = false; } else { if(i == '\"') { inQuote = true; } else if(i == c) { return n; } } } // unterminated quote if(inQuote) { return -1; } return in.size(); } // search for one of many chars static int findNext(const QByteArray &in, const char *charList, int offset = 0) { int len = qstrlen(charList); for(int n = offset; n < in.size(); ++n) { char c = in[n]; for(int i = 0; i < len; ++i) { if(c == charList[i]) return n; } } return -1; } static QList headerSplit(const QByteArray &in) { QList parts; int pos = 0; while(pos < in.size()) { int end = findNonQuoted(in, ',', pos); if(end != -1) { parts += in.mid(pos, end - pos).trimmed(); if(end < in.size()) pos = end + 1; else pos = in.size(); } else { parts += in.mid(pos).trimmed(); pos = in.size(); } } return parts; } bool HttpHeaderParameters::contains(const QByteArray &key) const { for(int n = 0; n < count(); ++n) { if(qstricmp(at(n).first.data(), key.data()) == 0) return true; } return false; } QByteArray HttpHeaderParameters::get(const QByteArray &key) const { for(int n = 0; n < count(); ++n) { const HttpHeaderParameter &h = at(n); if(qstricmp(h.first.data(), key.data()) == 0) return h.second; } return QByteArray(); } bool HttpHeaders::contains(const QByteArray &key) const { for(int n = 0; n < count(); ++n) { if(qstricmp(at(n).first.data(), key.data()) == 0) return true; } return false; } QByteArray HttpHeaders::get(const QByteArray &key) const { for(int n = 0; n < count(); ++n) { const HttpHeader &h = at(n); if(qstricmp(h.first.data(), key.data()) == 0) return h.second; } return QByteArray(); } HttpHeaderParameters HttpHeaders::getAsParameters(const QByteArray &key, ParseMode mode) const { QByteArray h = get(key); if(h.isEmpty()) return HttpHeaderParameters(); return parseParameters(h, mode); } QByteArray HttpHeaders::getAsFirstParameter(const QByteArray &key) const { HttpHeaderParameters p = getAsParameters(key); if(p.isEmpty()) return QByteArray(); return p[0].first; } QList HttpHeaders::getAll(const QByteArray &key, bool split) const { QList out; for(int n = 0; n < count(); ++n) { const HttpHeader &h = at(n); if(qstricmp(h.first.data(), key.data()) == 0) { if(split) out += headerSplit(h.second); else out += h.second; } } return out; } QList HttpHeaders::getAllAsParameters(const QByteArray &key, ParseMode mode, bool split) const { QList out; foreach(const QByteArray &h, getAll(key, split)) { bool ok; HttpHeaderParameters params = parseParameters(h, mode, &ok); if(ok) out += params; } return out; } QList HttpHeaders::takeAll(const QByteArray &key, bool split) { QList out; for(int n = 0; n < count(); ++n) { const HttpHeader &h = at(n); if(qstricmp(h.first.data(), key.data()) == 0) { if(split) out += headerSplit(h.second); else out += h.second; removeAt(n); --n; // adjust position } } return out; } void HttpHeaders::removeAll(const QByteArray &key) { for(int n = 0; n < count(); ++n) { if(qstricmp(at(n).first.data(), key.data()) == 0) { removeAt(n); --n; // adjust position } } } QByteArray HttpHeaders::join(const QList &values) { QByteArray out; bool first = true; foreach(const QByteArray &val, values) { if(!first) out += ", "; out += val; first = false; } 
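// e.g. join({"apple", "banana", "cherry"}) yields "apple, banana, cherry"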
return out; } HttpHeaderParameters HttpHeaders::parseParameters(const QByteArray &in, ParseMode mode, bool *ok) { HttpHeaderParameters out; int start = 0; if(mode == NoParseFirstParameter) { int at = in.indexOf(';'); if(at != -1) { out += HttpHeaderParameter(in.mid(0, at).trimmed(), QByteArray()); start = at + 1; } else { out += HttpHeaderParameter(in.trimmed(), QByteArray()); start = in.size(); } } while(start < in.size()) { QByteArray var; QByteArray val; int at = findNext(in, "=;", start); if(at != -1) { var = in.mid(start, at - start).trimmed(); if(in[at] == '=') { ++at; if(at < in.size() && in[at] == '\"') { ++at; bool complete = false; for(int n = at; n < in.size(); ++n) { if(in[n] == '\\') { if(n + 1 >= in.size()) { if(ok) *ok = false; return HttpHeaderParameters(); } ++n; val += in[n]; } else if(in[n] == '\"') { complete = true; at = n + 1; break; } else val += in[n]; } if(!complete) { if(ok) *ok = false; return HttpHeaderParameters(); } at = in.indexOf(';', at); if(at != -1) start = at + 1; else start = in.size(); } else { int vstart = at; at = in.indexOf(';', vstart); if(at != -1) { val = in.mid(vstart, at - vstart).trimmed(); start = at + 1; } else { val = in.mid(vstart).trimmed(); start = in.size(); } } } else start = at + 1; } else { var = in.mid(start).trimmed(); start = in.size(); } out.append(HttpHeaderParameter(var, val)); } if(ok) *ok = true; return out; } pushpin-1.41.0/src/core/httpheaders.h000066400000000000000000000036371504671364300175100ustar00rootroot00000000000000/* * Copyright (C) 2012-2013 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef HTTPHEADERS_H #define HTTPHEADERS_H #include #include #include typedef QPair HttpHeaderParameter; class HttpHeaderParameters : public QList { public: bool contains(const QByteArray &key) const; QByteArray get(const QByteArray &key) const; }; typedef QPair HttpHeader; class HttpHeaders : public QList { public: enum ParseMode { NoParseFirstParameter, ParseAllParameters }; bool contains(const QByteArray &key) const; QByteArray get(const QByteArray &key) const; HttpHeaderParameters getAsParameters(const QByteArray &key, ParseMode mode = NoParseFirstParameter) const; QByteArray getAsFirstParameter(const QByteArray &key) const; QList getAll(const QByteArray &key, bool split = true) const; QList getAllAsParameters(const QByteArray &key, ParseMode mode = NoParseFirstParameter, bool split = true) const; QList takeAll(const QByteArray &key, bool split = true); void removeAll(const QByteArray &key); static QByteArray join(const QList &values); static HttpHeaderParameters parseParameters(const QByteArray &in, ParseMode mode = NoParseFirstParameter, bool *ok = 0); }; #endif pushpin-1.41.0/src/core/httpheaderstest.cpp000066400000000000000000000054211504671364300207340ustar00rootroot00000000000000/* * Copyright (C) 2017 Fanout, Inc. * Copyright (C) 2025 Fastly, Inc. 
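For orientation, here is a small usage sketch of the parameter parsing above. It is illustrative rather than taken from the repository, and the header value is made up.

// Illustrative only: reading a Content-Type style header value
HttpHeaders headers;
headers += HttpHeader("Content-Type", "text/html; charset=\"utf-8\"");

HttpHeaderParameters params = headers.getAsParameters("Content-Type");
// params[0].first == "text/html"   (NoParseFirstParameter keeps it whole)
// params.get("charset") == "utf-8" (surrounding quotes are stripped)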
* * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ * */ #include "test.h" #include "httpheaders.h" static void parseParameters() { HttpHeaders h; h += HttpHeader("Fruit", "apple"); h += HttpHeader("Fruit", "banana"); h += HttpHeader("Fruit", "cherry"); QList params = h.getAllAsParameters("Fruit"); TEST_ASSERT_EQ(params.count(), 3); TEST_ASSERT_EQ(params[0][0].first, QByteArray("apple")); TEST_ASSERT_EQ(params[1][0].first, QByteArray("banana")); TEST_ASSERT_EQ(params[2][0].first, QByteArray("cherry")); h.clear(); h += HttpHeader("Fruit", "apple, banana, cherry"); params = h.getAllAsParameters("Fruit"); TEST_ASSERT_EQ(params.count(), 3); TEST_ASSERT_EQ(params[0][0].first, QByteArray("apple")); TEST_ASSERT_EQ(params[1][0].first, QByteArray("banana")); TEST_ASSERT_EQ(params[2][0].first, QByteArray("cherry")); h.clear(); h += HttpHeader("Fruit", "apple; type=\"granny, smith\", banana; type=\"\\\"yellow\\\"\""); params = h.getAllAsParameters("Fruit"); TEST_ASSERT_EQ(params.count(), 2); TEST_ASSERT_EQ(params[0][0].first, QByteArray("apple")); TEST_ASSERT_EQ(params[0][1].first, QByteArray("type")); TEST_ASSERT_EQ(params[0][1].second, QByteArray("granny, smith")); TEST_ASSERT_EQ(params[1][0].first, QByteArray("banana")); TEST_ASSERT_EQ(params[1][1].first, QByteArray("type")); TEST_ASSERT_EQ(params[1][1].second, QByteArray("\"yellow\"")); h.clear(); h += HttpHeader("Fruit", "\"apple"); QList l = h.getAll("Fruit"); TEST_ASSERT_EQ(l.count(), 1); TEST_ASSERT_EQ(l[0], QByteArray("\"apple")); h.clear(); h += HttpHeader("Fruit", "\"apple\\"); l = h.getAll("Fruit"); TEST_ASSERT_EQ(l.count(), 1); TEST_ASSERT_EQ(l[0], QByteArray("\"apple\\")); h.clear(); h += HttpHeader("Fruit", "apple; type=gala, banana; type=\"yellow, cherry"); params = h.getAllAsParameters("Fruit"); TEST_ASSERT_EQ(params.count(), 1); TEST_ASSERT_EQ(params[0][0].first, QByteArray("apple")); TEST_ASSERT_EQ(params[0][1].first, QByteArray("type")); TEST_ASSERT_EQ(params[0][1].second, QByteArray("gala")); } extern "C" int httpheaders_test(ffi::TestException *out_ex) { TEST_CATCH(parseParameters()); return 0; } pushpin-1.41.0/src/core/httprequest.h000066400000000000000000000053151504671364300175600ustar00rootroot00000000000000/* * Copyright (C) 2012-2016 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #ifndef HTTPREQUEST_H #define HTTPREQUEST_H #include #include #include "httpheaders.h" #include using Signal = boost::signals2::signal; using SignalInt = boost::signals2::signal; class HttpRequest { public: enum ErrorCondition { ErrorGeneric, ErrorPolicy, ErrorConnect, ErrorConnectTimeout, ErrorTls, ErrorLengthRequired, ErrorDisconnected, ErrorTimeout, ErrorUnavailable, ErrorRequestTooLarge }; virtual ~HttpRequest() = default; virtual QHostAddress peerAddress() const = 0; virtual void setConnectHost(const QString &host) = 0; virtual void setConnectPort(int port) = 0; virtual void setIgnorePolicies(bool on) = 0; virtual void setTrustConnectHost(bool on) = 0; virtual void setIgnoreTlsErrors(bool on) = 0; virtual void setTimeout(int msecs) = 0; virtual void start(const QString &method, const QUrl &uri, const HttpHeaders &headers) = 0; virtual void beginResponse(int code, const QByteArray &reason, const HttpHeaders &headers) = 0; // may call this multiple times virtual void writeBody(const QByteArray &body) = 0; virtual void endBody() = 0; virtual int bytesAvailable() const = 0; virtual int writeBytesAvailable() const = 0; virtual bool isFinished() const = 0; virtual bool isInputFinished() const = 0; virtual bool isOutputFinished() const = 0; virtual bool isErrored() const = 0; virtual ErrorCondition errorCondition() const = 0; virtual QString requestMethod() const = 0; virtual QUrl requestUri() const = 0; virtual HttpHeaders requestHeaders() const = 0; virtual int responseCode() const = 0; virtual QByteArray responseReason() const = 0; virtual HttpHeaders responseHeaders() const = 0; virtual QByteArray readBody(int size = -1) = 0; // takes from the buffer // indicates input data and/or input finished Signal readyRead; // indicates output data written and/or output finished SignalInt bytesWritten; Signal writeBytesChanged; Signal paused; Signal error; }; #endif pushpin-1.41.0/src/core/inspectdata.h000066400000000000000000000017361504671364300174720ustar00rootroot00000000000000/* * Copyright (C) 2012-2013 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef INSPECTDATA_H #define INSPECTDATA_H #include #include class InspectData { public: bool doProxy; QByteArray sharingKey; QByteArray sid; QHash lastIds; QVariant userData; InspectData() : doProxy(false) { } }; #endif pushpin-1.41.0/src/core/io.rs000066400000000000000000000334331504671364300157760ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
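The interface above reports progress through boost::signals2 signals rather than virtual callbacks. A hypothetical consumer might wire it up as sketched below; the lambda bodies and the req pointer are illustrative, and the zero-argument connect assumes Signal is a void() signal (its template arguments are not legible in this dump).

// Illustrative only: reacting to readyRead/error on some HttpRequest subclass
boost::signals2::scoped_connection readConn = req->readyRead.connect([&]() {
    QByteArray chunk = req->readBody();
    // ... consume chunk, check req->isInputFinished() ...
});

boost::signals2::scoped_connection errConn = req->error.connect([&]() {
    // req->errorCondition() describes what went wrong
});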
* See the License for the specific language governing permissions and * limitations under the License. */ use std::cell::RefCell; use std::future::Future; use std::io::{self, Write}; use std::pin::Pin; use std::task::{Context, Poll}; pub trait AsyncRead: Unpin { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll>; fn cancel(&mut self); } pub trait AsyncWrite: Unpin { fn poll_write( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll>; fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { for b in bufs { if !b.is_empty() { return self.poll_write(cx, b); } } self.poll_write(cx, &[]) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll>; // for use with std Write fn is_writable(&self) -> bool; fn cancel(&mut self); } impl AsyncRead for &mut T { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { Pin::new(&mut **self).poll_read(cx, buf) } fn cancel(&mut self) { AsyncRead::cancel(&mut **self) } } impl AsyncWrite for &mut T { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { Pin::new(&mut **self).poll_write(cx, buf) } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { Pin::new(&mut **self).poll_write_vectored(cx, bufs) } fn poll_close(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { Pin::new(&mut **self).poll_close(cx) } fn is_writable(&self) -> bool { AsyncWrite::is_writable(&**self) } fn cancel(&mut self) { AsyncWrite::cancel(&mut **self) } } pub trait AsyncReadExt: AsyncRead { fn read<'a>(&'a mut self, buf: &'a mut [u8]) -> ReadFuture<'a, Self> { ReadFuture { r: self, buf } } } pub trait AsyncWriteExt: AsyncWrite { fn write<'a>(&'a mut self, buf: &'a [u8]) -> WriteFuture<'a, Self> { WriteFuture { w: self, buf, pos: 0, } } fn write_vectored<'a>( &'a mut self, bufs: &'a [io::IoSlice<'a>], ) -> WriteVectoredFuture<'a, Self> { WriteVectoredFuture { w: self, bufs, pos: 0, } } fn close(&mut self) -> CloseFuture<'_, Self> { CloseFuture { w: self } } fn write_shared<'a, B>(&'a mut self, buf: &'a RefCell) -> WriteSharedFuture<'a, Self, B> where B: AsRef<[u8]>, { WriteSharedFuture { w: self, buf } } } impl AsyncReadExt for R {} impl AsyncWriteExt for W {} pub struct StdWriteWrapper<'a, 'b, W> { w: Pin<&'a mut W>, cx: &'a mut Context<'b>, } impl<'a, 'b, W: AsyncWrite> StdWriteWrapper<'a, 'b, W> { pub fn new(w: Pin<&'a mut W>, cx: &'a mut Context<'b>) -> Self { StdWriteWrapper { w, cx } } } impl Write for StdWriteWrapper<'_, '_, W> { fn write(&mut self, buf: &[u8]) -> Result { match self.w.as_mut().poll_write(self.cx, buf) { Poll::Ready(ret) => ret, Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } fn write_vectored(&mut self, bufs: &[io::IoSlice]) -> Result { match self.w.as_mut().poll_write_vectored(self.cx, bufs) { Poll::Ready(ret) => ret, Poll::Pending => Err(io::Error::from(io::ErrorKind::WouldBlock)), } } fn flush(&mut self) -> Result<(), io::Error> { Ok(()) } } pub struct ReadHalf<'a, T: AsyncRead> { handle: &'a RefCell, } impl AsyncRead for ReadHalf<'_, T> { fn poll_read( self: Pin<&mut Self>, cx: &mut Context<'_>, buf: &mut [u8], ) -> Poll> { let mut handle = self.handle.borrow_mut(); Pin::new(&mut *handle).poll_read(cx, buf) } fn cancel(&mut self) { self.handle.borrow_mut().cancel(); } } pub struct WriteHalf<'a, T: AsyncWrite> { handle: &'a RefCell, } impl AsyncWrite for WriteHalf<'_, T> { fn poll_write( self: 
Pin<&mut Self>, cx: &mut Context<'_>, buf: &[u8], ) -> Poll> { let mut handle = self.handle.borrow_mut(); Pin::new(&mut *handle).poll_write(cx, buf) } fn poll_write_vectored( self: Pin<&mut Self>, cx: &mut Context<'_>, bufs: &[io::IoSlice<'_>], ) -> Poll> { let mut handle = self.handle.borrow_mut(); Pin::new(&mut *handle).poll_write_vectored(cx, bufs) } fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll> { let mut handle = self.handle.borrow_mut(); Pin::new(&mut *handle).poll_close(cx) } fn is_writable(&self) -> bool { self.handle.borrow().is_writable() } fn cancel(&mut self) { self.handle.borrow_mut().cancel(); } } pub fn io_split( handle: &RefCell, ) -> (ReadHalf<'_, T>, WriteHalf<'_, T>) { (ReadHalf { handle }, WriteHalf { handle }) } pub struct ReadFuture<'a, R: AsyncRead + ?Sized> { r: &'a mut R, buf: &'a mut [u8], } impl Future for ReadFuture<'_, R> { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let r: Pin<&mut R> = Pin::new(f.r); r.poll_read(cx, f.buf) } } impl Drop for ReadFuture<'_, R> { fn drop(&mut self) { self.r.cancel(); } } pub struct WriteFuture<'a, W: AsyncWrite + ?Sized + Unpin> { w: &'a mut W, buf: &'a [u8], pos: usize, } impl Future for WriteFuture<'_, W> { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let mut w: Pin<&mut W> = Pin::new(f.w); // try to write all the data before producing a result while f.pos < f.buf.len() { match w.as_mut().poll_write(cx, &f.buf[f.pos..]) { Poll::Ready(result) => match result { Ok(size) => f.pos += size, Err(e) => return Poll::Ready(Err(e)), }, Poll::Pending => return Poll::Pending, } } Poll::Ready(Ok(f.buf.len())) } } impl Drop for WriteFuture<'_, W> { fn drop(&mut self) { self.w.cancel(); } } pub struct CloseFuture<'a, W: AsyncWrite + ?Sized> { w: &'a mut W, } impl Future for CloseFuture<'_, W> { type Output = Result<(), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let w: Pin<&mut W> = Pin::new(f.w); w.poll_close(cx) } } impl Drop for CloseFuture<'_, W> { fn drop(&mut self) { self.w.cancel(); } } fn get_start_offset(bufs: &[io::IoSlice], pos: usize) -> (usize, usize) { let mut start = 0; let mut offset = pos; for buf in bufs { if offset < buf.len() { break; } start += 1; offset -= buf.len(); } (start, offset) } pub struct WriteVectoredFuture<'a, W: AsyncWrite + ?Sized + Unpin> { w: &'a mut W, bufs: &'a [io::IoSlice<'a>], pos: usize, } impl Future for WriteVectoredFuture<'_, W> { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let mut w: Pin<&mut W> = Pin::new(f.w); // try to write all the data before producing a result loop { let (start, offset) = get_start_offset(f.bufs, f.pos); if start >= f.bufs.len() { break; } if offset == 0 { match w.as_mut().poll_write_vectored(cx, &f.bufs[start..]) { Poll::Ready(result) => match result { Ok(size) => f.pos += size, Err(e) => return Poll::Ready(Err(e)), }, Poll::Pending => return Poll::Pending, } } else { match w.as_mut().poll_write(cx, &f.bufs[start][offset..]) { Poll::Ready(result) => match result { Ok(size) => f.pos += size, Err(e) => return Poll::Ready(Err(e)), }, Poll::Pending => return Poll::Pending, } } } Poll::Ready(Ok(f.pos)) } } impl Drop for WriteVectoredFuture<'_, W> { fn drop(&mut self) { self.w.cancel(); } } pub struct WriteSharedFuture<'a, W: AsyncWrite + ?Sized + Unpin, B: AsRef<[u8]>> { w: &'a mut W, buf: &'a RefCell, } impl> Future for 
WriteSharedFuture<'_, W, B> { type Output = Result; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; let w: Pin<&mut W> = Pin::new(f.w); w.poll_write(cx, f.buf.borrow().as_ref()) } } impl> Drop for WriteSharedFuture<'_, W, B> { fn drop(&mut self) { self.w.cancel(); } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use std::cmp; use std::task::Context; struct TestBuffer { data: Vec, } impl TestBuffer { fn new() -> Self { Self { data: Vec::new() } } } impl AsyncRead for TestBuffer { fn poll_read( mut self: Pin<&mut Self>, _cx: &mut Context, buf: &mut [u8], ) -> Poll> { let size = cmp::min(buf.len(), self.data.len()); let left = self.data.split_off(size); (&mut buf[..size]).copy_from_slice(&self.data); self.data = left; Poll::Ready(Ok(size)) } fn cancel(&mut self) {} } impl AsyncWrite for TestBuffer { fn poll_write( mut self: Pin<&mut Self>, _cx: &mut Context, buf: &[u8], ) -> Poll> { let size = self.data.write(buf).unwrap(); Poll::Ready(Ok(size)) } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn is_writable(&self) -> bool { true } fn cancel(&mut self) {} } #[test] fn async_read_write() { let executor = Executor::new(1); executor .spawn(async { let mut buf = TestBuffer::new(); let mut data = [0; 16]; assert_eq!(buf.read(&mut data).await.unwrap(), 0); assert_eq!(buf.write(b"hello").await.unwrap(), 5); assert_eq!(buf.read(&mut data).await.unwrap(), 5); assert_eq!(&data[..5], b"hello"); }) .unwrap(); executor.run(|_| Ok(())).unwrap(); } #[test] fn async_read_write_concurrent() { let executor = Executor::new(1); executor .spawn(async { let buf = RefCell::new(TestBuffer::new()); let (mut r, mut w) = io_split(&buf); let mut data = [0; 16]; let write_fut = w.write(b"hello"); let read_fut = r.read(&mut data); assert_eq!(write_fut.await.unwrap(), 5); assert_eq!(read_fut.await.unwrap(), 5); assert_eq!(&data[..5], b"hello"); }) .unwrap(); executor.run(|_| Ok(())).unwrap(); } #[test] fn async_write_vectored() { let executor = Executor::new(1); executor .spawn(async { let mut buf = TestBuffer::new(); let mut data = [0; 16]; assert_eq!(buf.read(&mut data).await.unwrap(), 0); assert_eq!( buf.write_vectored(&[ io::IoSlice::new(b"he"), io::IoSlice::new(b"l"), io::IoSlice::new(b"lo") ]) .await .unwrap(), 5 ); assert_eq!(buf.read(&mut data).await.unwrap(), 5); assert_eq!(&data[..5], b"hello"); }) .unwrap(); executor.run(|_| Ok(())).unwrap(); } } pushpin-1.41.0/src/core/jwt.cpp000066400000000000000000000124211504671364300163230ustar00rootroot00000000000000/* * Copyright (C) 2012-2022 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
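One part of the I/O module above that its unit tests do not exercise is StdWriteWrapper, which adapts an AsyncWrite to std::io::Write inside a poll function. The sketch below shows the intended shape of that pattern; the function name and payload are invented, and it assumes std::io::Write and the types above are in scope.

// Illustrative only: driving a std::io::Write consumer from a poll fn.
// A real caller would also handle partial writes.
fn poll_send_ping<W: AsyncWrite>(
    w: Pin<&mut W>,
    cx: &mut Context<'_>,
) -> Poll<Result<(), io::Error>> {
    let mut out = StdWriteWrapper::new(w, cx);

    match out.write(b"ping") {
        Ok(_) => Poll::Ready(Ok(())),
        // StdWriteWrapper surfaces Poll::Pending as WouldBlock
        Err(e) if e.kind() == io::ErrorKind::WouldBlock => Poll::Pending,
        Err(e) => Poll::Ready(Err(e)),
    }
}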
* * $FANOUT_END_LICENSE$ */ #include "jwt.h" #include #include #include #include #include namespace Jwt { EncodingKey::Private::Private() : type((KeyType)-1), raw(0) { } EncodingKey::Private::Private(ffi::JwtEncodingKey key) : type((KeyType)key.type), raw(key.key) { } EncodingKey::Private::~Private() { ffi::jwt_encoding_key_destroy(raw); } EncodingKey EncodingKey::fromSecret(const QByteArray &key) { EncodingKey k; k.d = new Private(ffi::jwt_encoding_key_from_secret((const quint8 *)key.data(), key.size())); return k; } EncodingKey EncodingKey::fromPem(const QByteArray &key) { EncodingKey k; k.d = new Private(ffi::jwt_encoding_key_from_pem((const quint8 *)key.data(), key.size())); return k; } EncodingKey EncodingKey::fromFile(const QString &fileName) { QFile f(fileName); if(!f.open(QFile::ReadOnly)) { return EncodingKey(); } QByteArray data = f.readAll(); if(data.startsWith("-----BEGIN")) return fromPem(data); else return fromSecret(QByteArray::fromHex(data.trimmed())); } EncodingKey EncodingKey::fromConfigString(const QString &s, const QDir &baseDir) { if(s.startsWith("file:")) { QString keyFile = s.mid(5); QFileInfo fi(keyFile); if(fi.isRelative()) keyFile = QFileInfo(baseDir, keyFile).filePath(); return fromFile(keyFile); } else { QByteArray secret; if(s.startsWith("base64:")) secret = QByteArray::fromBase64(s.mid(7).toUtf8()); else secret = s.toUtf8(); return fromSecret(secret); } } DecodingKey::Private::Private() : type((KeyType)-1), raw(0) { } DecodingKey::Private::Private(ffi::JwtDecodingKey key) : type((KeyType)key.type), raw(key.key) { } DecodingKey::Private::~Private() { ffi::jwt_decoding_key_destroy(raw); } DecodingKey DecodingKey::fromSecret(const QByteArray &key) { DecodingKey k; k.d = new Private(ffi::jwt_decoding_key_from_secret((const quint8 *)key.data(), key.size())); return k; } DecodingKey DecodingKey::fromPem(const QByteArray &key) { DecodingKey k; k.d = new Private(ffi::jwt_decoding_key_from_pem((const quint8 *)key.data(), key.size())); return k; } DecodingKey DecodingKey::fromFile(const QString &fileName) { QFile f(fileName); if(!f.open(QFile::ReadOnly)) { return DecodingKey(); } QByteArray data = f.readAll(); if(data.startsWith("-----BEGIN")) return fromPem(data); else return fromSecret(QByteArray::fromHex(data.trimmed())); } DecodingKey DecodingKey::fromConfigString(const QString &s, const QDir &baseDir) { if(s.startsWith("file:")) { QString keyFile = s.mid(5); QFileInfo fi(keyFile); if(fi.isRelative()) keyFile = QFileInfo(baseDir, keyFile).filePath(); return fromFile(keyFile); } else { QByteArray secret; if(s.startsWith("base64:")) secret = QByteArray::fromBase64(s.mid(7).toUtf8()); else secret = s.toUtf8(); return fromSecret(secret); } } QByteArray encodeWithAlgorithm(Algorithm alg, const QByteArray &claim, const EncodingKey &key) { char *token; if(ffi::jwt_encode((int)alg, (const char *)claim.data(), key.raw(), &token) != 0) { // error return QByteArray(); } QByteArray out = QByteArray(token); ffi::jwt_str_destroy(token); return out; } QByteArray decodeWithAlgorithm(Algorithm alg, const QByteArray &token, const DecodingKey &key) { char *claim; if(ffi::jwt_decode((int)alg, (const char *)token.data(), key.raw(), &claim) != 0) { // error return QByteArray(); } QByteArray out = QByteArray(claim); ffi::jwt_str_destroy(claim); return out; } QByteArray encode(const QVariant &claim, const EncodingKey &key) { Algorithm alg; switch(key.type()) { case Jwt::KeyType::Secret: alg = Jwt::HS256; break; case Jwt::KeyType::Ec: alg = Jwt::ES256; break; case Jwt::KeyType::Rsa: alg = 
Jwt::RS256; break; default: return QByteArray(); } QByteArray claimJson = QJsonDocument(QJsonObject::fromVariantMap(claim.toMap())).toJson(QJsonDocument::Compact); if(claimJson.isNull()) return QByteArray(); return encodeWithAlgorithm(alg, claimJson, key); } QVariant decode(const QByteArray &token, const DecodingKey &key) { Algorithm alg; switch(key.type()) { case Jwt::KeyType::Secret: alg = Jwt::HS256; break; case Jwt::KeyType::Ec: alg = Jwt::ES256; break; case Jwt::KeyType::Rsa: alg = Jwt::RS256; break; default: return QVariant(); } QByteArray claimJson = decodeWithAlgorithm(alg, token, key); if(claimJson.isEmpty()) return QVariant(); QJsonParseError error; QJsonDocument doc = QJsonDocument::fromJson(claimJson, &error); if(error.error != QJsonParseError::NoError || !doc.isObject()) return QVariant(); return doc.object().toVariantMap(); } } pushpin-1.41.0/src/core/jwt.h000066400000000000000000000055451504671364300160010ustar00rootroot00000000000000/* * Copyright (C) 2012-2022 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef JWT_H #define JWT_H #include #include #include #include #include "rust/bindings.h" class QString; namespace Jwt { enum KeyType { Secret = ffi::JWT_KEYTYPE_SECRET, Ec = ffi::JWT_KEYTYPE_EC, Rsa = ffi::JWT_KEYTYPE_RSA, }; enum Algorithm { HS256 = ffi::JWT_ALGORITHM_HS256, ES256 = ffi::JWT_ALGORITHM_ES256, RS256 = ffi::JWT_ALGORITHM_RS256, }; class EncodingKey { public: bool isNull() const { return !d; } KeyType type() const { if(d) { return d->type; } else { return (KeyType)-1; } } const ffi::EncodingKey *raw() const { if(d) { return d->raw; } else { return 0; } } static EncodingKey fromSecret(const QByteArray &key); static EncodingKey fromPem(const QByteArray &key); static EncodingKey fromFile(const QString &fileName); static EncodingKey fromConfigString(const QString &s, const QDir &baseDir = QDir()); private: class Private : public QSharedData { public: KeyType type; ffi::EncodingKey *raw; Private(); Private(ffi::JwtEncodingKey key); ~Private(); }; QSharedDataPointer d; }; class DecodingKey { public: bool isNull() const { return !d; } KeyType type() const { if(d) { return d->type; } else { return (KeyType)-1; } } const ffi::DecodingKey *raw() const { if(d) { return d->raw; } else { return 0; } } static DecodingKey fromSecret(const QByteArray &key); static DecodingKey fromPem(const QByteArray &key); static DecodingKey fromFile(const QString &fileName); static DecodingKey fromConfigString(const QString &s, const QDir &baseDir = QDir()); private: class Private : public QSharedData { public: KeyType type; ffi::DecodingKey *raw; Private(); Private(ffi::JwtDecodingKey key); ~Private(); }; QSharedDataPointer d; }; // returns token, null on error QByteArray encodeWithAlgorithm(Algorithm alg, const QByteArray &claim, const EncodingKey &key); // returns claim, null on error QByteArray decodeWithAlgorithm(Algorithm alg, const QByteArray &token, 
const DecodingKey &key); QByteArray encode(const QVariant &claim, const EncodingKey &key); QVariant decode(const QByteArray &token, const DecodingKey &key); } #endif pushpin-1.41.0/src/core/jwt.rs000066400000000000000000000247001504671364300161700ustar00rootroot00000000000000/* * Copyright (C) 2022 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ use jsonwebtoken::{DecodingKey, EncodingKey, Header, TokenData, Validation}; pub fn encode( header: &Header, claim: &str, key: &EncodingKey, ) -> Result { // claim is already serialized, but the jsonwebtoken crate requires a // serializable object. so we'll deserialize to a generic value first let claim: serde_json::Value = serde_json::from_str(claim)?; jsonwebtoken::encode(header, &claim, key) } pub fn decode( token: &str, key: &DecodingKey, validation: &Validation, ) -> Result { let token_data: TokenData = jsonwebtoken::decode(token, key, validation)?; Ok(serde_json::to_string(&token_data.claims)?) } mod ffi { use super::*; use std::collections::HashSet; use std::ffi::{CStr, CString}; use std::os::raw::c_char; use std::ptr; use std::slice; pub const JWT_KEYTYPE_SECRET: libc::c_int = 0; pub const JWT_KEYTYPE_EC: libc::c_int = 1; pub const JWT_KEYTYPE_RSA: libc::c_int = 2; pub const JWT_ALGORITHM_HS256: libc::c_int = 0; pub const JWT_ALGORITHM_ES256: libc::c_int = 1; pub const JWT_ALGORITHM_RS256: libc::c_int = 2; #[repr(C)] pub struct JwtEncodingKey { r#type: libc::c_int, key: *mut jsonwebtoken::EncodingKey, } #[repr(C)] pub struct JwtDecodingKey { r#type: libc::c_int, key: *mut jsonwebtoken::DecodingKey, } type EncodingKeyFromPemFn = fn(&[u8]) -> Result; type DecodingKeyFromPemFn = fn(&[u8]) -> Result; fn load_encoding_key_pem( key: &[u8], ) -> Result<(libc::c_int, jsonwebtoken::EncodingKey), jsonwebtoken::errors::Error> { // pem data includes the key type, however the jsonwebtoken crate // requires specifying the expected type when decoding. we'll just try // the data against multiple possible types let decoders: [(libc::c_int, EncodingKeyFromPemFn); 2] = [ (JWT_KEYTYPE_EC, jsonwebtoken::EncodingKey::from_ec_pem), (JWT_KEYTYPE_RSA, jsonwebtoken::EncodingKey::from_rsa_pem), ]; let mut last_err = None; for (ktype, f) in decoders { match f(key) { Ok(key) => return Ok((ktype, key)), Err(e) => last_err = Some(e), } } Err(last_err.unwrap()) } fn load_decoding_key_pem( key: &[u8], ) -> Result<(libc::c_int, jsonwebtoken::DecodingKey), jsonwebtoken::errors::Error> { // pem data includes the key type, however the jsonwebtoken crate // requires specifying the expected type when decoding. 
we'll just try // the data against multiple possible types let decoders: [(libc::c_int, DecodingKeyFromPemFn); 2] = [ (JWT_KEYTYPE_EC, jsonwebtoken::DecodingKey::from_ec_pem), (JWT_KEYTYPE_RSA, jsonwebtoken::DecodingKey::from_rsa_pem), ]; let mut last_err = None; for (ktype, f) in decoders { match f(key) { Ok(key) => return Ok((ktype, key)), Err(e) => last_err = Some(e), } } Err(last_err.unwrap()) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_encoding_key_from_secret( data: *const u8, len: libc::size_t, ) -> JwtEncodingKey { let key = jsonwebtoken::EncodingKey::from_secret(slice::from_raw_parts(data, len)); JwtEncodingKey { r#type: JWT_KEYTYPE_SECRET, key: Box::into_raw(Box::new(key)), } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_encoding_key_from_pem( data: *const u8, len: libc::size_t, ) -> JwtEncodingKey { match load_encoding_key_pem(slice::from_raw_parts(data, len)) { Ok((ktype, key)) => JwtEncodingKey { r#type: ktype, key: Box::into_raw(Box::new(key)), }, Err(_) => JwtEncodingKey { r#type: -1, key: ptr::null_mut(), }, } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_encoding_key_destroy(key: *mut jsonwebtoken::EncodingKey) { if !key.is_null() { drop(Box::from_raw(key)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_decoding_key_from_secret( data: *const u8, len: libc::size_t, ) -> JwtDecodingKey { let key = jsonwebtoken::DecodingKey::from_secret(slice::from_raw_parts(data, len)); JwtDecodingKey { r#type: JWT_KEYTYPE_SECRET, key: Box::into_raw(Box::new(key)), } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_decoding_key_from_pem( data: *const u8, len: libc::size_t, ) -> JwtDecodingKey { match load_decoding_key_pem(slice::from_raw_parts(data, len)) { Ok((ktype, key)) => JwtDecodingKey { r#type: ktype, key: Box::into_raw(Box::new(key)), }, Err(_) => JwtDecodingKey { r#type: -1, key: ptr::null_mut(), }, } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_decoding_key_destroy(key: *mut jsonwebtoken::DecodingKey) { if !key.is_null() { drop(Box::from_raw(key)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_str_destroy(s: *mut c_char) { if !s.is_null() { drop(CString::from_raw(s)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_encode( alg: libc::c_int, claim: *const c_char, key: *const jsonwebtoken::EncodingKey, out_token: *mut *mut c_char, ) -> libc::c_int { if claim.is_null() || out_token.is_null() { return 1; // null pointers } let key = match key.as_ref() { Some(r) => r, None => return 1, // null pointer }; let header = match alg { JWT_ALGORITHM_HS256 => jsonwebtoken::Header::new(jsonwebtoken::Algorithm::HS256), JWT_ALGORITHM_ES256 => jsonwebtoken::Header::new(jsonwebtoken::Algorithm::ES256), JWT_ALGORITHM_RS256 => jsonwebtoken::Header::new(jsonwebtoken::Algorithm::RS256), _ => return 1, // unsupported algorithm }; let claim = match CStr::from_ptr(claim).to_str() { Ok(s) => s, Err(_) => return 1, // claim is a JSON string which will be valid UTF-8 }; let token = match encode(&header, claim, key) { Ok(token) => token, Err(_) => return 1, // failed to sign }; let token = match CString::new(token) { Ok(s) => s, Err(_) => return 1, // unexpected token string format }; *out_token = token.into_raw(); 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn jwt_decode( alg: libc::c_int, token: 
*const c_char, key: *const jsonwebtoken::DecodingKey, out_claim: *mut *mut c_char, ) -> libc::c_int { if token.is_null() || out_claim.is_null() { return 1; // null pointers } let key = match key.as_ref() { Some(r) => r, None => return 1, // null pointer }; let mut validation = match alg { JWT_ALGORITHM_HS256 => jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::HS256), JWT_ALGORITHM_ES256 => jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::ES256), JWT_ALGORITHM_RS256 => jsonwebtoken::Validation::new(jsonwebtoken::Algorithm::RS256), _ => return 1, // unsupported algorithm }; // don't check exp or anything. that's left to the caller validation.required_spec_claims = HashSet::new(); let token = match CStr::from_ptr(token).to_str() { Ok(s) => s, Err(_) => return 1, // token string will be valid UTF-8 }; let claim = match decode(token, key, &validation) { Ok(claim) => claim, Err(_) => return 1, // failed to validate }; let claim = match CString::new(claim) { Ok(s) => s, Err(_) => return 1, // unexpected claim string format }; *out_claim = claim.into_raw(); 0 } } #[cfg(test)] mod tests { use super::*; use jsonwebtoken::Algorithm; use serde::Deserialize; use serde::Serialize; use std::time::{SystemTime, UNIX_EPOCH}; #[test] fn encode_decode() { #[derive(Debug, Serialize, Deserialize)] struct Claim { iss: String, exp: u64, } let claim = Claim { iss: "nobody".to_string(), exp: SystemTime::now() .duration_since(UNIX_EPOCH) .unwrap() .as_secs(), }; let claim = serde_json::to_string(&claim).unwrap(); let token = encode( &Header::new(Algorithm::HS256), &claim, &EncodingKey::from_secret(b"secret"), ) .unwrap(); let claim = decode( &token, &DecodingKey::from_secret(b"secret"), &Validation::new(Algorithm::HS256), ) .unwrap(); let claim: Claim = serde_json::from_str(&claim).unwrap(); assert_eq!(claim.iss, "nobody"); } } pushpin-1.41.0/src/core/jwttest.cpp000066400000000000000000000161471504671364300172340ustar00rootroot00000000000000/* * Copyright (C) 2013-2022 Fanout, Inc. * Copyright (C) 2024-2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
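On the C++ side, the QVariant-based helpers shown earlier select the algorithm from the key type (HS256 for shared secrets, ES256 for EC keys, RS256 for RSA keys). A minimal HS256 round trip might look like the sketch below; the claim contents are made up.

// Illustrative only: HS256 round trip via the QVariant convenience API
QVariantMap claim;
claim["iss"] = "nobody";

Jwt::EncodingKey ekey = Jwt::EncodingKey::fromSecret("secret");
QByteArray token = Jwt::encode(claim, ekey); // null QByteArray on error

QVariant out = Jwt::decode(token, Jwt::DecodingKey::fromSecret("secret"));
// out.toMap().value("iss") == "nobody" on success; null QVariant on failure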
* * $FANOUT_END_LICENSE$ */ #include "test.h" #include #include #include "qtcompat.h" #include "jwt.h" static const char *test_ec_private_key_pem = "-----BEGIN PRIVATE KEY-----\n" "MIGHAgEAMBMGByqGSM49AgEGCCqGSM49AwEHBG0wawIBAQQgFcZQVV16cpGC4QUQ\n" "8O8H85totFiAB54WBTxKQQElI7KhRANCAAQA3D4/QkBACQuC99MFqZllTOaamPAJ\n" "3+Z3JkPsrd/z651PYmlywcdEGVWRiD2PNhvdzM7Nckxx1ZofDLlkvoxH\n" "-----END PRIVATE KEY-----\n"; static const char *test_ec_public_key_pem = "-----BEGIN PUBLIC KEY-----\n" "MFkwEwYHKoZIzj0CAQYIKoZIzj0DAQcDQgAEANw+P0JAQAkLgvfTBamZZUzmmpjw\n" "Cd/mdyZD7K3f8+udT2JpcsHHRBlVkYg9jzYb3czOzXJMcdWaHwy5ZL6MRw==\n" "-----END PUBLIC KEY-----\n"; static const char *test_rsa_private_key_pem = "-----BEGIN PRIVATE KEY-----\n" "MIIEvgIBADANBgkqhkiG9w0BAQEFAASCBKgwggSkAgEAAoIBAQDOgE+exziD5kFF\n" "4x3F76G64XAccp1KqWMgrTYM3c4C/7hxwuu7kMdGXlXL+xQOHe6vX6EM/H9tWaIf\n" "CyQ+KfdyBBDO05MXZcxEl3496ShN/UN1TghJk12gg3yPm3+V2mfh+NQi7jEFt1uv\n" "beco5T1ve5yhtu58PrCC87TuWINW+iFrUg41MEHcXWL/7COBR/azFOZqedZPCdnL\n" "5SoY1H5WAazZUftD6W7PvCYmQN+uCSr1SjbGg5g+9OQ6i7viHXRg0U9mIZVII56V\n" "g2sD0w6ClO4Tq+mQs94frKakD960drvg2QNCvW0cUiRBLkadOSqZkIp0It4r5ivi\n" "S2gJQO8XAgMBAAECggEBAJs4W6DwAw0yULIlq8WTALCmsEzR4mWyuW5ghJZbS3V5\n" "nrz0VZmhlAjS9A7l5gdOfJGagkZuraIWlARdrZqElRlA8Rlmc9RMkqSkcyI6Vi95\n" "RfGw/A3CFciHzWNs8RRFHX0AOwUeof63+tT8+ZsF5Y4dDnmINe9yd9+XLNNT+TWw\n" "aCFJ+RQ8j7xGtZb2N/AOI0prTCka/SNRYxNommdS1x9qCaTVKd1fXM/ZhRjIlsEo\n" "OzmcoG0Kdfq6pu2OgJ8DzSigXyWbCEy/amSWgPX80kubG1Xjc8MSFlQcg493Gve1\n" "JagUZEbKIQFNCxN42cAzuuf3hKV9vIT+L8yApuacwQECgYEA+Lgp8UtANVFOBSuE\n" "5HHP+dB3Ot8HdbK2FIQEQ+xwVUHgLnnWpQHhw8COpZgAoMGMPl37KGrTuPW/C/4o\n" "yGj/hK+df1ksLR8ViXFVpB5GbzfdsvMgPo1GCYVFVGJlVHO/oFxV6YQtydhiAMp+\n" "dcgQO3paKrzEoFSJdomNtoqMdUECgYEA1IvGiaiwk5yPafs2mbsoMM7K6NpzlO3x\n" "pPlTqgGgVgIM+Lg6FWEm3kWN6A/hELyfCIosHP5pdkPKxgkzs6OqVFxKa2anHSRT\n" "1lLUhU0kOrkYyfq1oMXumPb5Kc4zzbOnxScF7lCIzMo9y82OJSjHDbjAgmzNyJbm\n" "CEhOgf2RllcCgYAfyqKJ1j2R0x+u534oGSglXXEwFDwG3l4Jx0ooSHufWjlGl4pJ\n" "MzFhbSaOohxKcBL2Eds9slH3zWmrJcSewVUP58aw9XwBFH0TQWpZ/QixxKlQ62TO\n" "ug4ev2s6Ow2KuvTekY7lt2CG8WKtiTSa54SzpZMK7XAQsl2TykdT8ue7QQKBgGrG\n" "KR/gkYwmG1m3bK9/+OnECOU/UM8hVcJ1ylTeakiq0Q9lpTA2VQtWT7qjt4Hr78yf\n" "dRe/qwVRex1PZBy7fIbSskQQFqWqKT/C7qZkoW2qrMxS2UmCBaHseDFLOHT+6qo9\n" "N1qINKEEfFTU17LNMGoxROyAckRxoe/JOz9MPgYTAoGBAJKreX73d6s1s9oVB3u/\n" "DS1YXRmek+OkXQhFxekKXB3KxG8obx2uveeg18PtNf0RoYq9LF0hKcTqSCusfF9m\n" "lM+s5xc1mQfXI55AEOjT+8AssmhebHbFkpjr1/DSUUsCssO+1znkeZwAOApm/4kR\n" "pGokHI67k9CxNFZW3Z0U9EeW\n" "-----END PRIVATE KEY-----\n"; static const char *test_rsa_public_key_pem = "-----BEGIN PUBLIC KEY-----\n" "MIIBIjANBgkqhkiG9w0BAQEFAAOCAQ8AMIIBCgKCAQEAzoBPnsc4g+ZBReMdxe+h\n" "uuFwHHKdSqljIK02DN3OAv+4ccLru5DHRl5Vy/sUDh3ur1+hDPx/bVmiHwskPin3\n" "cgQQztOTF2XMRJd+PekoTf1DdU4ISZNdoIN8j5t/ldpn4fjUIu4xBbdbr23nKOU9\n" "b3ucobbufD6wgvO07liDVvoha1IONTBB3F1i/+wjgUf2sxTmannWTwnZy+UqGNR+\n" "VgGs2VH7Q+luz7wmJkDfrgkq9Uo2xoOYPvTkOou74h10YNFPZiGVSCOelYNrA9MO\n" "gpTuE6vpkLPeH6ympA/etHa74NkDQr1tHFIkQS5GnTkqmZCKdCLeK+Yr4ktoCUDv\n" "FwIDAQAB\n" "-----END PUBLIC KEY-----\n"; static void validToken() { QVariant vclaim = Jwt::decode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJmb28iOiAiYmFyIn0.oBia0Fph39FwQWv0TS7Disg4qa0aFa8qpMaYDrIXZqs", Jwt::DecodingKey::fromSecret("secret")); TEST_ASSERT(typeId(vclaim) == QMetaType::QVariantMap); QVariantMap claim = vclaim.toMap(); TEST_ASSERT(claim.value("foo") == "bar"); } static void validTokenBinaryKey() { QByteArray key; key += 0x01; key += 0x61; key += 0x80; key += 0xfe; QVariant vclaim = 
Jwt::decode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJmb28iOiAiYmFyIn0.-eLxyGEITnd6IP4WvGJx9CmIOt--Qcs3LB6wblJ7KXI", Jwt::DecodingKey::fromSecret(key)); TEST_ASSERT(typeId(vclaim) == QMetaType::QVariantMap); QVariantMap claim = vclaim.toMap(); TEST_ASSERT(claim.value("foo") == "bar"); } static void invalidKey() { QVariant vclaim = Jwt::decode("eyJhbGciOiAiSFMyNTYiLCAidHlwIjogIkpXVCJ9.eyJmb28iOiAiYmFyIn0.oBia0Fph39FwQWv0TS7Disg4qa0aFa8qpMaYDrIXZqs", Jwt::DecodingKey::fromSecret("wrong")); TEST_ASSERT(vclaim.isNull()); } static void es256EncodeDecode() { Jwt::EncodingKey privateKey = Jwt::EncodingKey::fromPem(QByteArray(test_ec_private_key_pem)); TEST_ASSERT(!privateKey.isNull()); TEST_ASSERT_EQ(privateKey.type(), Jwt::KeyType::Ec); Jwt::DecodingKey publicKey = Jwt::DecodingKey::fromPem(QByteArray(test_ec_public_key_pem)); TEST_ASSERT(!publicKey.isNull()); TEST_ASSERT_EQ(publicKey.type(), Jwt::KeyType::Ec); QVariantMap claim; claim["iss"] = "nobody"; QByteArray claimJson = QJsonDocument(QJsonObject::fromVariantMap(claim)).toJson(QJsonDocument::Compact); TEST_ASSERT(!claimJson.isNull()); QByteArray token = Jwt::encodeWithAlgorithm(Jwt::ES256, claimJson, privateKey); TEST_ASSERT(!token.isNull()); QByteArray resultJson = Jwt::decodeWithAlgorithm(Jwt::ES256, token, publicKey); TEST_ASSERT(!resultJson.isNull()); QJsonParseError error; QJsonDocument doc = QJsonDocument::fromJson(resultJson, &error); TEST_ASSERT(error.error == QJsonParseError::NoError); TEST_ASSERT(doc.isObject()); QVariantMap result = doc.object().toVariantMap(); TEST_ASSERT_EQ(result["iss"], "nobody"); } static void rs256EncodeDecode() { Jwt::EncodingKey privateKey = Jwt::EncodingKey::fromPem(QByteArray(test_rsa_private_key_pem)); TEST_ASSERT(!privateKey.isNull()); TEST_ASSERT_EQ(privateKey.type(), Jwt::KeyType::Rsa); Jwt::DecodingKey publicKey = Jwt::DecodingKey::fromPem(QByteArray(test_rsa_public_key_pem)); TEST_ASSERT(!publicKey.isNull()); TEST_ASSERT_EQ(publicKey.type(), Jwt::KeyType::Rsa); QVariantMap claim; claim["iss"] = "nobody"; QByteArray claimJson = QJsonDocument(QJsonObject::fromVariantMap(claim)).toJson(QJsonDocument::Compact); TEST_ASSERT(!claimJson.isNull()); QByteArray token = Jwt::encodeWithAlgorithm(Jwt::RS256, claimJson, privateKey); TEST_ASSERT(!token.isNull()); QByteArray resultJson = Jwt::decodeWithAlgorithm(Jwt::RS256, token, publicKey); TEST_ASSERT(!resultJson.isNull()); QJsonParseError error; QJsonDocument doc = QJsonDocument::fromJson(resultJson, &error); TEST_ASSERT(error.error == QJsonParseError::NoError); TEST_ASSERT(doc.isObject()); QVariantMap result = doc.object().toVariantMap(); TEST_ASSERT_EQ(result["iss"], "nobody"); } extern "C" int jwt_test(ffi::TestException *out_ex) { TEST_CATCH(validToken()); TEST_CATCH(validTokenBinaryKey()); TEST_CATCH(invalidKey()); TEST_CATCH(es256EncodeDecode()); TEST_CATCH(rs256EncodeDecode()); return 0; } pushpin-1.41.0/src/core/layertracker.cpp000066400000000000000000000026171504671364300202150ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
* See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "layertracker.h" #include LayerTracker::LayerTracker() : plain_(0) { } void LayerTracker::reset() { plain_ = 0; items_.clear(); } void LayerTracker::addPlain(int plain) { plain_ += plain; } void LayerTracker::specifyEncoded(int encoded, int plain) { // can't specify more bytes than we have assert(plain <= plain_); plain_ -= plain; Item i; i.plain = plain; i.encoded = encoded; items_ += i; } int LayerTracker::finished(int encoded) { int plain = 0; for(QList::Iterator it = items_.begin(); it != items_.end();) { Item &i = *it; // not enough? if(encoded < i.encoded) { i.encoded -= encoded; break; } encoded -= i.encoded; plain += i.plain; it = items_.erase(it); } return plain; } pushpin-1.41.0/src/core/layertracker.h000066400000000000000000000017561504671364300176650ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef LAYERTRACKER_H #define LAYERTRACKER_H #include class LayerTracker { public: LayerTracker(); void reset(); void addPlain(int plain); void specifyEncoded(int encoded, int plain); int finished(int encoded); private: class Item { public: int plain; int encoded; }; int plain_; QList items_; }; #endif pushpin-1.41.0/src/core/list.rs000066400000000000000000000231761504671364300163450ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
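LayerTracker above maintains a running mapping from encoded (on-the-wire) byte counts back to the plain byte counts they represent, so a caller can credit application-level progress as transport writes complete. A small hypothetical walk-through:

// Illustrative only: 100 plain bytes become 112 bytes after framing
LayerTracker tracker;
tracker.addPlain(100);
tracker.specifyEncoded(112, 100);

// the transport reports 64 encoded bytes written, then the remaining 48
int done1 = tracker.finished(64); // 0: the 112-byte unit is not complete yet
int done2 = tracker.finished(48); // 100: now it is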
*/ use std::ops::IndexMut; pub struct Node { pub prev: Option, pub next: Option, pub value: T, } impl Node { pub fn new(value: T) -> Self { Self { prev: None, next: None, value, } } } #[derive(Default, Clone, Copy)] pub struct List { pub head: Option, pub tail: Option, } impl List { pub fn is_empty(&self) -> bool { self.head.is_none() } pub fn insert(&mut self, nodes: &mut S, after: Option, key: usize) where S: IndexMut>, { let next = if let Some(pkey) = after { let pn = &mut nodes[pkey]; let next = pn.next; pn.next = Some(key); let n = &mut nodes[key]; n.prev = Some(pkey); next } else { let next = self.head; self.head = Some(key); let n = &mut nodes[key]; n.prev = None; next }; let n = &mut nodes[key]; n.next = next; if let Some(nkey) = next { let nn = &mut nodes[nkey]; nn.prev = Some(key); } else { self.tail = Some(key); } } pub fn remove(&mut self, nodes: &mut S, key: usize) where S: IndexMut>, { let n = &mut nodes[key]; let prev = n.prev.take(); let next = n.next.take(); if let Some(pkey) = prev { let pn = &mut nodes[pkey]; pn.next = next; } if let Some(nkey) = next { let nn = &mut nodes[nkey]; nn.prev = prev; } if let Some(hkey) = self.head { if hkey == key { self.head = next; } } if let Some(tkey) = self.tail { if tkey == key { self.tail = prev; } } } pub fn pop_front(&mut self, nodes: &mut S) -> Option where S: IndexMut>, { match self.head { Some(key) => { self.remove(nodes, key); Some(key) } None => None, } } pub fn push_back(&mut self, nodes: &mut S, key: usize) where S: IndexMut>, { self.insert(nodes, self.tail, key); } pub fn concat(&mut self, nodes: &mut S, other: &mut Self) where S: IndexMut>, { if other.is_empty() { // nothing to do return; } // other is non-empty so this is guaranteed to succeed let hkey = other.head.unwrap(); let next = nodes[hkey].next; // since we're inserting after the tail, this will set next=None self.insert(nodes, self.tail, hkey); // restore the node's next key nodes[hkey].next = next; self.tail = other.tail; other.head = None; other.tail = None; } pub fn iter<'a, T, S>(&self, nodes: &'a S) -> ListIterator<'a, S> where S: IndexMut>, { ListIterator { nodes, next: self.head, } } } pub struct ListIterator<'a, S> { nodes: &'a S, next: Option, } impl<'a, T, S> Iterator for ListIterator<'a, S> where T: 'a, S: IndexMut>, { type Item = (usize, &'a T); fn next(&mut self) -> Option { if let Some(nkey) = self.next.take() { let n = &self.nodes[nkey]; self.next = n.next; Some((nkey, &n.value)) } else { None } } } #[cfg(test)] mod tests { use super::*; use slab::Slab; #[test] fn test_list_push_pop() { let mut nodes = Slab::new(); let n1 = nodes.insert(Node::new("n1")); let n2 = nodes.insert(Node::new("n2")); let n3 = nodes.insert(Node::new("n3")); // prevent unused warning on data field assert_eq!(nodes[n1].value, "n1"); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); assert_eq!(nodes[n2].prev, None); assert_eq!(nodes[n2].next, None); assert_eq!(nodes[n3].prev, None); assert_eq!(nodes[n3].next, None); let mut l = List::default(); assert_eq!(l.is_empty(), true); assert_eq!(l.head, None); assert_eq!(l.tail, None); assert_eq!(l.pop_front(&mut nodes), None); l.push_back(&mut nodes, n1); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n1)); assert_eq!(l.tail, Some(n1)); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); l.push_back(&mut nodes, n2); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n1)); assert_eq!(l.tail, Some(n2)); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, Some(n2)); 
assert_eq!(nodes[n2].prev, Some(n1)); assert_eq!(nodes[n2].next, None); l.push_back(&mut nodes, n3); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n1)); assert_eq!(l.tail, Some(n3)); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, Some(n2)); assert_eq!(nodes[n2].prev, Some(n1)); assert_eq!(nodes[n2].next, Some(n3)); assert_eq!(nodes[n3].prev, Some(n2)); assert_eq!(nodes[n3].next, None); let key = l.pop_front(&mut nodes); assert_eq!(key, Some(n1)); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n2)); assert_eq!(l.tail, Some(n3)); assert_eq!(nodes[n2].prev, None); assert_eq!(nodes[n2].next, Some(n3)); assert_eq!(nodes[n3].prev, Some(n2)); assert_eq!(nodes[n3].next, None); let key = l.pop_front(&mut nodes); assert_eq!(key, Some(n2)); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n3)); assert_eq!(l.tail, Some(n3)); assert_eq!(nodes[n3].prev, None); assert_eq!(nodes[n3].next, None); let key = l.pop_front(&mut nodes); assert_eq!(key, Some(n3)); assert_eq!(l.is_empty(), true); assert_eq!(l.head, None); assert_eq!(l.tail, None); assert_eq!(l.pop_front(&mut nodes), None); } #[test] fn test_remove() { let mut nodes = Slab::new(); let n1 = nodes.insert(Node::new("n1")); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); let mut l = List::default(); assert_eq!(l.is_empty(), true); assert_eq!(l.head, None); assert_eq!(l.tail, None); l.push_back(&mut nodes, n1); assert_eq!(l.is_empty(), false); assert_eq!(l.head, Some(n1)); assert_eq!(l.tail, Some(n1)); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); l.remove(&mut nodes, n1); assert_eq!(l.is_empty(), true); assert_eq!(l.head, None); assert_eq!(l.tail, None); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); // already removed l.remove(&mut nodes, n1); assert_eq!(l.is_empty(), true); assert_eq!(l.head, None); assert_eq!(l.tail, None); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, None); } #[test] fn test_list_concat() { let mut nodes = Slab::new(); let n1 = nodes.insert(Node::new("n1")); let n2 = nodes.insert(Node::new("n2")); let mut a = List::default(); let mut b = List::default(); a.concat(&mut nodes, &mut b); assert_eq!(a.is_empty(), true); assert_eq!(a.head, None); assert_eq!(a.tail, None); assert_eq!(b.is_empty(), true); assert_eq!(b.head, None); assert_eq!(b.tail, None); a.push_back(&mut nodes, n1); b.push_back(&mut nodes, n2); a.concat(&mut nodes, &mut b); assert_eq!(a.is_empty(), false); assert_eq!(a.head, Some(n1)); assert_eq!(a.tail, Some(n2)); assert_eq!(b.is_empty(), true); assert_eq!(b.head, None); assert_eq!(b.tail, None); assert_eq!(nodes[n1].prev, None); assert_eq!(nodes[n1].next, Some(n2)); assert_eq!(nodes[n2].prev, Some(n1)); assert_eq!(nodes[n2].next, None); } #[test] fn test_list_iter() { let mut nodes = Slab::new(); let n1 = nodes.insert(Node::new("n1")); let n2 = nodes.insert(Node::new("n2")); let n3 = nodes.insert(Node::new("n3")); let mut l = List::default(); l.push_back(&mut nodes, n1); l.push_back(&mut nodes, n2); l.push_back(&mut nodes, n3); let mut it = l.iter(&nodes); assert_eq!(it.next(), Some((n1, &"n1"))); assert_eq!(it.next(), Some((n2, &"n2"))); assert_eq!(it.next(), Some((n3, &"n3"))); assert_eq!(it.next(), None); } } pushpin-1.41.0/src/core/log.cpp000066400000000000000000000066141504671364300163070ustar00rootroot00000000000000/* * Copyright (C) 2012-2022 Fanout, Inc. 
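Beyond the unit tests above, the point of this index-based list is to keep an ordering over entries stored in a slab, with O(1) removal by key and no raw pointers. A hedged sketch of that pattern (the connection names are invented, and slab::Slab plus the List/Node types above are assumed to be in scope):

// Illustrative only: keep slab-allocated entries in insertion order
// and drop one by key without traversal.
let mut nodes: Slab<Node<&str>> = Slab::new();
let mut order = List::default();

let a = nodes.insert(Node::new("conn-a"));
let b = nodes.insert(Node::new("conn-b"));
order.push_back(&mut nodes, a);
order.push_back(&mut nodes, b);

order.remove(&mut nodes, a); // O(1), no scan
assert_eq!(order.iter(&nodes).next(), Some((b, &"conn-b")));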
* * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "log.h" #include #include #include #include #include #include Q_GLOBAL_STATIC(QMutex, g_mutex) static int g_level = LOG_LEVEL_DEBUG; static QElapsedTimer g_time; static QString *g_filename; static FILE *g_file; static void log(const char *s) { FILE *out; if(g_file) out = g_file; else out = stdout; fprintf(out, "%s\n", s); fflush(out); } static void log(int level, const char *fmt, va_list ap) { g_mutex()->lock(); int current_level = g_level; int elapsed; if(g_time.isValid()) elapsed = g_time.elapsed(); else elapsed = -1; g_mutex()->unlock(); if(level <= current_level) { QString str = QString::vasprintf(fmt, ap); const char *lstr; switch(level) { case LOG_LEVEL_ERROR: lstr = "ERR"; break; case LOG_LEVEL_WARNING: lstr = "WARN"; break; case LOG_LEVEL_INFO: lstr = "INFO"; break; case LOG_LEVEL_DEBUG: default: lstr = "DEBUG"; break; } QString tstr; if(elapsed != -1) { QTime t(0, 0); t = t.addMSecs(elapsed); tstr = t.toString("HH:mm:ss.zzz"); } else { tstr = QDateTime::currentDateTime().toString("yyyy-MM-dd HH:mm:ss.zzz"); } FILE *out; if(g_file) out = g_file; else out = stdout; fprintf(out, "[%s] %s %s\n", lstr, qPrintable(tstr), qPrintable(str)); fflush(out); } } void log_startClock() { QMutexLocker locker(g_mutex()); g_time.start(); } int log_outputLevel() { QMutexLocker locker(g_mutex()); return g_level; } void log_setOutputLevel(int level) { QMutexLocker locker(g_mutex()); g_level = level; } bool log_setFile(const QString &fname) { QMutexLocker locker(g_mutex()); if(g_file) { fclose(g_file); delete g_filename; g_filename = 0; g_file = 0; } if(fname.isEmpty()) return true; FILE *f = fopen(fname.toLocal8Bit().data(), "a"); if(!f) return false; setbuf(f, NULL); g_filename = new QString(fname); g_file = f; return true; } bool log_rotate() { QMutexLocker locker(g_mutex()); if(!g_file) return true; if(!freopen(g_filename->toLocal8Bit().data(), "a", g_file)) return false; setbuf(g_file, NULL); return true; } void log(int level, const char *fmt, ...) { va_list ap; va_start(ap, fmt); log(level, fmt, ap); va_end(ap); } void log_error(const char *fmt, ...) { va_list ap; va_start(ap, fmt); log(LOG_LEVEL_ERROR, fmt, ap); va_end(ap); } void log_warning(const char *fmt, ...) { va_list ap; va_start(ap, fmt); log(LOG_LEVEL_WARNING, fmt, ap); va_end(ap); } void log_info(const char *fmt, ...) { va_list ap; va_start(ap, fmt); log(LOG_LEVEL_INFO, fmt, ap); va_end(ap); } void log_debug(const char *fmt, ...) { va_list ap; va_start(ap, fmt); log(LOG_LEVEL_DEBUG, fmt, ap); va_end(ap); } void log_raw(const char *s) { log(s); } pushpin-1.41.0/src/core/log.h000066400000000000000000000024401504671364300157450ustar00rootroot00000000000000/* * Copyright (C) 2012-2016 Fanout, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
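The logger above writes to stdout unless log_setFile() points it at a file, and filters out messages more verbose than the level set with log_setOutputLevel(). Typical usage might look like the sketch below; the port number and file path are invented.

// Illustrative only: relative timestamps, INFO threshold, printf-style args
log_startClock();
log_setOutputLevel(LOG_LEVEL_INFO);

log_info("listening on port %d", 7999);
log_debug("suppressed at the INFO level");

if(!log_setFile("pushpin.log")) // hypothetical path
    log_error("failed to open log file");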
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef LOG_H #define LOG_H #include // really simply logging stuff #define LOG_LEVEL_ERROR 0 #define LOG_LEVEL_WARNING 1 #define LOG_LEVEL_INFO 2 #define LOG_LEVEL_DEBUG 3 void log_startClock(); int log_outputLevel(); void log_setOutputLevel(int level); bool log_setFile(const QString &fname); bool log_rotate(); void log(int level, const char *fmt, ...); void log_error(const char *fmt, ...); void log_warning(const char *fmt, ...); void log_info(const char *fmt, ...); void log_debug(const char *fmt, ...); // log without prefixing or anything. useful for forwarding log data void log_raw(const char *line); #endif pushpin-1.41.0/src/core/log.rs000066400000000000000000000115541504671364300161500ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
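The SimpleLogger defined in this file implements the log crate's Log trait and is exposed through get_simple_logger(), so it can serve as the process-wide logger. A minimal wiring sketch, assuming the log crate's standard set_logger/set_max_level API and the ensure_init_simple_logger/get_simple_logger helpers defined below:

// initialize once (stdout output, normal mode), then install globally
ensure_init_simple_logger(None, false);
log::set_logger(get_simple_logger()).expect("a logger was already installed");
log::set_max_level(log::LevelFilter::Debug);

log::info!("logger installed");
local_offset_check(); // warns if the local UTC offset could not be determined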
*/ use log::{Level, Log, Metadata, Record}; use std::fs::File; use std::io::{self, Write}; use std::str; use std::sync::{Mutex, OnceLock}; use time::macros::format_description; use time::{OffsetDateTime, UtcOffset}; enum SharedOutput<'a> { Stdout(io::Stdout), File(&'a Mutex), } impl Write for SharedOutput<'_> { fn write(&mut self, buf: &[u8]) -> Result { match self { Self::Stdout(g) => g.write(buf), Self::File(g) => (*g).lock().unwrap().write(buf), } } fn flush(&mut self) -> Result<(), io::Error> { match self { Self::Stdout(g) => g.flush(), Self::File(g) => (*g).lock().unwrap().flush(), } } } pub struct SimpleLogger { local_offset: Option, output_file: Option>, runner_mode: bool, } impl Log for SimpleLogger { fn enabled(&self, metadata: &Metadata) -> bool { metadata.level() <= Level::Trace } fn log(&self, record: &Record) { if !self.enabled(record.metadata()) { return; } let mut output = match &self.output_file { Some(f) => SharedOutput::File(f), None => SharedOutput::Stdout(io::stdout()), }; if self.runner_mode { writeln!(&mut output, "{}", record.args()).expect("failed to write log output"); return; } let now = OffsetDateTime::now_utc().to_offset(self.local_offset.unwrap_or(UtcOffset::UTC)); let format = format_description!( "[year]-[month]-[day] [hour]:[minute]:[second].[subsecond digits:3]" ); let mut ts = [0u8; 64]; let size = { let mut ts = io::Cursor::new(&mut ts[..]); now.format_into(&mut ts, &format) .expect("failed to write timestamp"); ts.position() as usize }; let ts = str::from_utf8(&ts[..size]).expect("timestamp is not utf-8"); let lname = match record.level() { log::Level::Error => "ERR", log::Level::Warn => "WARN", log::Level::Info => "INFO", log::Level::Debug => "DEBUG", log::Level::Trace => "TRACE", }; if record.level() <= log::Level::Info { writeln!(&mut output, "[{}] {} {}", lname, ts, record.args()) .expect("failed to write log output"); } else { writeln!( &mut output, "[{}] {} [{}] {}", lname, ts, record.target(), record.args() ) .expect("failed to write log output"); } } fn flush(&self) {} } // SAFETY: this method is unsound on platforms where another thread may // modify environment vars unsafe fn get_offset() -> Option { time::util::local_offset::set_soundness(time::util::local_offset::Soundness::Unsound); let offset = UtcOffset::current_local_offset().ok(); time::util::local_offset::set_soundness(time::util::local_offset::Soundness::Sound); offset } static LOGGER: OnceLock = OnceLock::new(); pub fn ensure_init_simple_logger(output_file: Option, runner_mode: bool) { LOGGER.get_or_init(|| { // SAFETY: we accept that this call is unsound. on some platforms it // is the only way to know the time zone, with a chance of UB if // another thread modifies environment vars during the call. the risk // is low, as this call will happen very early in the program, and // only once. we would rather accept this low risk and know the time // zone than not know the time zone let local_offset = unsafe { get_offset() }; SimpleLogger { local_offset, output_file: output_file.map(Mutex::new), runner_mode, } }); } pub fn get_simple_logger() -> &'static SimpleLogger { ensure_init_simple_logger(None, false); // logger is guaranteed to have been initialized LOGGER.get().expect("logger should be initialized") } pub fn local_offset_check() { if get_simple_logger().local_offset.is_none() { log::warn!("Failed to determine local time offset. 
Log timestamps will be in UTC."); } } pushpin-1.41.0/src/core/logutil.cpp000066400000000000000000000125051504671364300172010ustar00rootroot00000000000000/* * Copyright (C) 2017-2022 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "logutil.h" #include #include #include "qtcompat.h" #include "tnetstring.h" #include "log.h" #define MAX_DATA_LENGTH 1000 #define MAX_CONTENT_LENGTH 1000 namespace LogUtil { static QString trim(const QString &in, int max) { if(in.length() > max && max >= 7) return in.mid(0, max / 2) + "..." + in.mid(in.length() - (max / 2) + 3); else return in; } static QByteArray trim(const QByteArray &in, int max) { if(in.size() > max && max >= 7) return in.mid(0, max / 2) + "..." + in.mid(in.size() - (max / 2) + 3); else return in; } static QString makeLastIdsStr(const HttpHeaders &headers) { QString out; bool first = true; foreach(const HttpHeaderParameters ¶ms, headers.getAllAsParameters("Grip-Last")) { if(!first) out += ' '; out += QString("#%1=%2").arg(QString::fromUtf8(params[0].first), QString::fromUtf8(params.get("last-id"))); first = false; } return out; } static void logPacket(int level, const QString &message, const QVariant &data = QVariant(), int dataMax = -1, const QByteArray &content = QByteArray(), int contentMax = -1) { QString out = message; if(data.isValid()) { out += ' ' + trim(TnetString::variantToString(data, -1), dataMax); } if(!content.isNull()) { out += ' ' + QString::number(content.size()) + ' '; QByteArray buf = trim(content, contentMax); out += TnetString::variantToString(QVariant(buf), -1); } log(level, "%s", qPrintable(out)); } static void logPacket(int level, const QVariant &data, const char *fmt, va_list ap) { logPacket(level, QString::vasprintf(fmt, ap), data, MAX_DATA_LENGTH); } static void logPacket(int level, const QByteArray &content, const char *fmt, va_list ap) { logPacket(level, QString::vasprintf(fmt, ap), QVariant(), -1, content, MAX_CONTENT_LENGTH); } static void logPacket(int level, const QVariant &data, const QString &contentField, const char *fmt, va_list ap) { QVariant meta; QByteArray content; if(typeId(data) == QMetaType::QVariantHash) { // extract content. meta is the remaining data QVariantHash hdata = data.toHash(); content = hdata.value(contentField).toByteArray(); hdata.remove(contentField); meta = hdata; } else { // if data isn't a hash, then we can't extract content, so // the meta part will be the entire data meta = data; } logPacket(level, QString::vasprintf(fmt, ap), meta, MAX_DATA_LENGTH, content, MAX_CONTENT_LENGTH); } void logVariant(int level, const QVariant &data, const char *fmt, ...) { va_list ap; va_start(ap, fmt); logPacket(level, data, fmt, ap); va_end(ap); } void logByteArray(int level, const QByteArray &content, const char *fmt, ...) 
{ va_list ap; va_start(ap, fmt); logPacket(level, content, fmt, ap); va_end(ap); } void logVariantWithContent(int level, const QVariant &data, const QString &contentField, const char *fmt, ...) { va_list ap; va_start(ap, fmt); logPacket(level, data, contentField, fmt, ap); va_end(ap); } void logRequest(int level, const RequestData &data, const Config &config) { QString msg = QString("%1 %2").arg(data.requestData.method, data.requestData.uri.toString(QUrl::FullyEncoded)); if(!data.targetStr.isEmpty()) msg += QString(" -> %1").arg(data.targetStr); if(data.requestData.uri.scheme() != "http" && data.requestData.uri.scheme() != "https" && data.targetOverHttp) msg += "[http]"; if(config.fromAddress && !data.fromAddress.isNull()) msg += QString(" from=%1").arg(data.fromAddress.toString()); QUrl ref = QUrl(QString::fromUtf8(data.requestData.headers.get("Referer"))); if(!ref.isEmpty()) msg += QString(" ref=%1").arg(ref.toString(QUrl::FullyEncoded)); if(!data.routeId.isEmpty()) msg += QString(" route=%1").arg(data.routeId); if(data.status == LogUtil::Response) { msg += QString(" code=%1 %2").arg(QString::number(data.responseData.code), QString::number(data.responseBodySize)); } else if(data.status == LogUtil::Accept) { msg += " accept"; } else { msg += " error"; } if(data.retry) msg += " retry"; if(data.sharedBy) msg += QString::asprintf(" shared=%p", data.sharedBy); if(config.userAgent) { QString userAgent = data.requestData.headers.get("User-Agent"); if(!userAgent.isEmpty()) msg += QString(" ua=%1").arg(userAgent); } QString lastIdsStr = makeLastIdsStr(data.requestData.headers); if(!lastIdsStr.isEmpty()) msg += ' ' + lastIdsStr; log(level, "%s", qPrintable(msg)); } void logForRoute(const RouteInfo &routeInfo, const char *fmt, ...) { va_list ap; va_start(ap, fmt); QString msg = QString::vasprintf(fmt, ap); if(!routeInfo.id.isEmpty()) msg += QString(" route=%1").arg(routeInfo.id); logPacket(routeInfo.logLevel, msg); va_end(ap); } } pushpin-1.41.0/src/core/logutil.h000066400000000000000000000040401504671364300166410ustar00rootroot00000000000000/* * Copyright (C) 2017 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
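For reference, LogUtil::logRequest() above assembles a single space-separated access line from whichever fields are set: the method and encoded URI first, then an optional " -> target" part, followed by from=, ref= and route= tokens, a status part ("code=<code> <body size>", "accept" or "error"), and optional retry, shared=, ua= and Grip-Last tokens. A typical (hypothetical) line: GET https://example.com/stream route=default code=200 1042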
* * $FANOUT_END_LICENSE$ */ #ifndef LOGUTIL_H #define LOGUTIL_H #include #include "log.h" #include "packet/httprequestdata.h" #include "packet/httpresponsedata.h" namespace LogUtil { enum RequestStatus { Response, Accept, Error }; class RequestData { public: QString routeId; RequestStatus status; HttpRequestData requestData; HttpResponseData responseData; int responseBodySize; QString targetStr; bool targetOverHttp; bool retry; void *sharedBy; QHostAddress fromAddress; RequestData() : status(Response), responseBodySize(-1), targetOverHttp(false), retry(false), sharedBy(0) { } }; class Config { public: bool fromAddress; bool userAgent; Config() : fromAddress(false), userAgent(false) { } }; class RouteInfo { public: QString id; int logLevel; RouteInfo(const QString &initId = QString(), int initLogLevel = LOG_LEVEL_DEBUG) : id(initId), logLevel(initLogLevel) { } }; void logVariant(int level, const QVariant &data, const char *fmt, ...); void logByteArray(int level, const QByteArray &content, const char *fmt, ...); void logVariantWithContent(int level, const QVariant &data, const QString &contentField, const char *fmt, ...); void logRequest(int level, const RequestData &data, const Config &config = Config()); void logForRoute(const RouteInfo &routeInfo, const char *fmt, ...); } #endif pushpin-1.41.0/src/core/mod.rs000066400000000000000000000103441504671364300161420ustar00rootroot00000000000000/* * Copyright (C) 2024-2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ pub mod arena; pub mod buffer; pub mod channel; pub mod config; pub mod defer; pub mod event; pub mod eventloop; pub mod executor; pub mod fs; pub mod http1; pub mod io; pub mod jwt; pub mod list; pub mod log; pub mod net; pub mod reactor; pub mod select; pub mod shuffle; pub mod task; pub mod test; pub mod time; pub mod timer; pub mod tnetstring; pub mod waker; pub mod zmq; use std::env; use std::ffi::{CString, OsStr}; use std::os::unix::ffi::OsStrExt; #[cfg(test)] use std::path::{Path, PathBuf}; pub fn is_debug_build() -> bool { cfg!(debug_assertions) } pub fn version() -> &'static str { env!("APP_VERSION") } /// # Safety /// /// * `main_fn` must be safe to call. 
pub unsafe fn call_c_main( main_fn: unsafe extern "C" fn(libc::c_int, *const *const libc::c_char) -> libc::c_int, args: I, ) -> u8 where I: IntoIterator, S: AsRef, { let args: Vec = args .into_iter() .map(|s| CString::new(s.as_ref().as_bytes()).unwrap()) .collect(); let args: Vec<*const libc::c_char> = args.iter().map(|s| s.as_ptr()).collect(); main_fn(args.len() as libc::c_int, args.as_ptr()) as u8 } #[cfg(test)] pub fn test_dir() -> PathBuf { // "cargo test" ensures this is present let out_dir = PathBuf::from(env::var("OUT_DIR").unwrap()); out_dir.join("test-work") } #[cfg(test)] pub fn ensure_example_config(dest: &Path) { use std::fs; use std::sync::Once; static INIT: Once = Once::new(); INIT.call_once(|| { let src_dir = Path::new("examples").join("config"); let dest_dir = dest.join("examples").join("config"); fs::create_dir_all(&dest_dir).unwrap(); fs::copy(src_dir.join("pushpin.conf"), dest_dir.join("pushpin.conf")).unwrap(); fs::copy(src_dir.join("routes"), dest_dir.join("routes")).unwrap(); }); } mod ffi { use std::os::raw::c_int; #[no_mangle] pub extern "C" fn is_debug_build() -> c_int { if super::is_debug_build() { 1 } else { 0 } } } #[cfg(test)] mod tests { use crate::core::test::{run_cpp, TestException}; use crate::ffi; fn httpheaders_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::httpheaders_test(out_ex) == 0 } } fn jwt_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::jwt_test(out_ex) == 0 } } fn timer_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::timer_test(out_ex) == 0 } } fn defercall_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::defercall_test(out_ex) == 0 } } fn tcpstream_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::tcpstream_test(out_ex) == 0 } } fn unixstream_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::unixstream_test(out_ex) == 0 } } fn eventloop_test(out_ex: &mut TestException) -> bool { // SAFETY: safe to call unsafe { ffi::eventloop_test(out_ex) == 0 } } #[test] fn httpheaders() { run_cpp(httpheaders_test); } #[test] fn jwt() { run_cpp(jwt_test); } #[test] fn timer() { run_cpp(timer_test); } #[test] fn defercall() { run_cpp(defercall_test); } #[test] fn tcpstream() { run_cpp(tcpstream_test); } #[test] fn unixstream() { run_cpp(unixstream_test); } #[test] fn eventloop() { run_cpp(eventloop_test); } } pushpin-1.41.0/src/core/net.rs000066400000000000000000001104151504671364300161510ustar00rootroot00000000000000/* * Copyright (C) 2022 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
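The call_c_main helper above converts an iterator of OsStr-like arguments into C strings and passes the resulting argc/argv to a C-style entry point. A usage sketch, where legacy_main stands in for a hypothetical extern "C" fn(c_int, *const *const c_char) -> c_int exported elsewhere:

fn run_legacy() -> u8 {
    // SAFETY: legacy_main (hypothetical) is assumed to be safe to call with the
    // argc/argv pair that call_c_main builds from the process arguments
    unsafe { call_c_main(legacy_main, std::env::args_os()) }
}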
*/ use crate::core::event::ReadinessExt; use crate::core::io::{AsyncRead, AsyncWrite}; use crate::core::reactor::IoEvented; use crate::core::task::get_reactor; use log::error; use mio::net::{TcpListener, TcpStream, UnixListener, UnixStream}; use socket2::Socket; use std::fmt; use std::future::Future; use std::io::{self, Read, Write}; use std::os::unix::io::{FromRawFd, IntoRawFd}; use std::path::Path; use std::pin::Pin; use std::ptr; use std::task::{Context, Poll}; pub fn set_socket_opts(stream: &mut TcpStream) { if let Err(e) = stream.set_nodelay(true) { error!("set nodelay failed: {:?}", e); } // safety: we move the value out of stream and replace it at the end let ret = unsafe { let s = ptr::read(stream); let socket = Socket::from_raw_fd(s.into_raw_fd()); let ret = socket.set_keepalive(true); ptr::write(stream, TcpStream::from_raw_fd(socket.into_raw_fd())); ret }; if let Err(e) = ret { error!("set keepalive failed: {:?}", e); } } #[derive(Debug)] pub enum SocketAddr { Ip(std::net::SocketAddr), Unix(std::os::unix::net::SocketAddr), } impl fmt::Display for SocketAddr { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { match self { Self::Ip(a) => write!(f, "{}", a), Self::Unix(a) => write!(f, "{:?}", a), } } } #[derive(Debug)] pub enum NetListener { Tcp(TcpListener), Unix(UnixListener), } #[derive(Debug)] pub enum NetStream { Tcp(TcpStream), Unix(UnixStream), } pub struct AsyncTcpListener { evented: IoEvented, } impl AsyncTcpListener { pub fn new(l: TcpListener) -> Self { let evented = IoEvented::new(l, mio::Interest::READABLE, &get_reactor()).unwrap(); evented.registration().set_ready(true); Self { evented } } pub fn bind(addr: std::net::SocketAddr) -> Result { let listener = TcpListener::bind(addr)?; Ok(Self::new(listener)) } pub fn local_addr(&self) -> Result { self.evented.io().local_addr() } pub fn accept(&self) -> AcceptFuture<'_> { AcceptFuture { l: self } } pub fn into_inner(self) -> TcpListener { self.evented.into_inner() } } pub struct AsyncUnixListener { evented: IoEvented, } impl AsyncUnixListener { pub fn new(l: UnixListener) -> Self { let evented = IoEvented::new(l, mio::Interest::READABLE, &get_reactor()).unwrap(); evented.registration().set_ready(true); Self { evented } } pub fn bind>(path: P) -> Result { let listener = UnixListener::bind(path)?; Ok(Self::new(listener)) } pub fn local_addr(&self) -> Result { self.evented.io().local_addr() } pub fn accept(&self) -> UnixAcceptFuture<'_> { UnixAcceptFuture { l: self } } } pub enum AsyncNetListener { Tcp(AsyncTcpListener), Unix(AsyncUnixListener), } impl AsyncNetListener { pub fn new(l: NetListener) -> Self { match l { NetListener::Tcp(l) => Self::Tcp(AsyncTcpListener::new(l)), NetListener::Unix(l) => Self::Unix(AsyncUnixListener::new(l)), } } pub fn accept(&self) -> NetAcceptFuture<'_> { match self { Self::Tcp(l) => NetAcceptFuture::Tcp(l.accept()), Self::Unix(l) => NetAcceptFuture::Unix(l.accept()), } } } pub struct AsyncTcpStream { evented: IoEvented, } impl AsyncTcpStream { pub fn new(s: TcpStream) -> Self { let evented = IoEvented::new( s, mio::Interest::READABLE | mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); // when constructing via new(), assume I/O operations are ready to be // attempted evented .registration() .set_readiness(Some(mio::Interest::READABLE | mio::Interest::WRITABLE)); Self { evented } } pub async fn connect(addrs: &[std::net::SocketAddr]) -> Result { let mut last_err = None; for addr in addrs { let stream = match TcpStream::connect(*addr) { Ok(stream) => stream, Err(e) => { last_err = 
Some(e); continue; } }; let mut stream = Self::new(stream); // when constructing via connect(), the ready state should start out // false because we need to wait for a writability indication stream.evented.registration().set_ready(false); let fut = TcpConnectFuture { s: &mut stream }; if let Err(e) = fut.await { last_err = Some(e); continue; } return Ok(stream); } Err(last_err.unwrap_or_else(|| io::Error::from(io::ErrorKind::InvalidInput))) } pub fn peer_addr(&self) -> Result { self.evented.io().peer_addr() } pub fn into_inner(self) -> TcpStream { self.evented.into_inner() } pub fn into_std(self) -> std::net::TcpStream { unsafe { std::net::TcpStream::from_raw_fd(self.evented.into_inner().into_raw_fd()) } } // assumes stream is in non-blocking mode pub fn from_std(stream: std::net::TcpStream) -> Self { Self::new(TcpStream::from_std(stream)) } pub fn into_evented(self) -> IoEvented { self.evented } } pub struct AsyncUnixStream { evented: IoEvented, } impl AsyncUnixStream { pub fn new(s: UnixStream) -> Self { let evented = IoEvented::new( s, mio::Interest::READABLE | mio::Interest::WRITABLE, &get_reactor(), ) .unwrap(); // when constructing via new(), assume I/O operations are ready to be // attempted evented .registration() .set_readiness(Some(mio::Interest::READABLE | mio::Interest::WRITABLE)); Self { evented } } pub async fn connect>(path: P) -> Result { let stream = UnixStream::connect(path)?; let mut stream = Self::new(stream); // when constructing via connect(), the ready state should start out // false because we need to wait for a writability indication stream.evented.registration().set_ready(false); let fut = UnixConnectFuture { s: &mut stream }; fut.await?; Ok(stream) } } pub struct AcceptFuture<'a> { l: &'a AsyncTcpListener, } impl Future for AcceptFuture<'_> { type Output = Result<(TcpStream, std::net::SocketAddr), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.l.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f.l.evented.registration().is_ready() { return Poll::Pending; } if !f.l.evented.registration().pull_from_budget() { return Poll::Pending; } match f.l.evented.io().accept() { Ok((stream, peer_addr)) => Poll::Ready(Ok((stream, peer_addr))), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.l.evented.registration().set_ready(false); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } } impl Drop for AcceptFuture<'_> { fn drop(&mut self) { self.l.evented.registration().clear_waker(); } } pub struct UnixAcceptFuture<'a> { l: &'a AsyncUnixListener, } impl Future for UnixAcceptFuture<'_> { type Output = Result<(UnixStream, std::os::unix::net::SocketAddr), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.l.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f.l.evented.registration().is_ready() { return Poll::Pending; } if !f.l.evented.registration().pull_from_budget() { return Poll::Pending; } match f.l.evented.io().accept() { Ok((stream, peer_addr)) => Poll::Ready(Ok((stream, peer_addr))), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.l.evented.registration().set_ready(false); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } } impl Drop for UnixAcceptFuture<'_> { fn drop(&mut self) { self.l.evented.registration().clear_waker(); } } pub enum NetAcceptFuture<'a> { Tcp(AcceptFuture<'a>), Unix(UnixAcceptFuture<'a>), } impl Future for NetAcceptFuture<'_> { type Output = Result<(NetStream, SocketAddr), io::Error>; fn 
poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let ret = match &mut *self { Self::Tcp(fut) => match Pin::new(fut).poll(cx) { Poll::Ready(ret) => match ret { Ok((stream, peer_addr)) => { Ok((NetStream::Tcp(stream), SocketAddr::Ip(peer_addr))) } Err(e) => Err(e), }, Poll::Pending => return Poll::Pending, }, Self::Unix(fut) => match Pin::new(fut).poll(cx) { Poll::Ready(ret) => match ret { Ok((stream, peer_addr)) => { Ok((NetStream::Unix(stream), SocketAddr::Unix(peer_addr))) } Err(e) => Err(e), }, Poll::Pending => return Poll::Pending, }, }; Poll::Ready(ret) } } pub struct TcpConnectFuture<'a> { s: &'a mut AsyncTcpStream, } impl Future for TcpConnectFuture<'_> { type Output = Result<(), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } // mio documentation says to use take_error() and peer_addr() to // check for connected if let Ok(Some(e)) | Err(e) = f.s.evented.io().take_error() { return Poll::Ready(Err(e)); } match f.s.evented.io().peer_addr() { Ok(_) => Poll::Ready(Ok(())), Err(e) if e.kind() == io::ErrorKind::NotConnected => { f.s.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } } impl Drop for TcpConnectFuture<'_> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } pub struct UnixConnectFuture<'a> { s: &'a mut AsyncUnixStream, } impl Future for UnixConnectFuture<'_> { type Output = Result<(), io::Error>; fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll { let f = &mut *self; f.s.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f.s.evented.registration().is_ready() { return Poll::Pending; } // mio documentation says to use take_error() and peer_addr() to // check for connected if let Ok(Some(e)) | Err(e) = f.s.evented.io().take_error() { return Poll::Ready(Err(e)); } match f.s.evented.io().peer_addr() { Ok(_) => Poll::Ready(Ok(())), Err(e) if e.kind() == io::ErrorKind::NotConnected => { f.s.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } } impl Drop for UnixConnectFuture<'_> { fn drop(&mut self) { self.s.evented.registration().clear_waker(); } } impl AsyncRead for AsyncTcpStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::READABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match f.evented.io().read(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::READABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn cancel(&mut self) { self.evented .registration() .clear_waker_interest(mio::Interest::READABLE); } } impl AsyncWrite for AsyncTcpStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match 
f.evented.io().write(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context, bufs: &[io::IoSlice<'_>], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match f.evented.io().write_vectored(bufs) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn is_writable(&self) -> bool { self.evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) } fn cancel(&mut self) { self.evented .registration() .clear_waker_interest(mio::Interest::WRITABLE); } } impl AsyncRead for AsyncUnixStream { fn poll_read( mut self: Pin<&mut Self>, cx: &mut Context, buf: &mut [u8], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::READABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::READABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match f.evented.io().read(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::READABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn cancel(&mut self) { self.evented .registration() .clear_waker_interest(mio::Interest::READABLE); } } impl AsyncWrite for AsyncUnixStream { fn poll_write( mut self: Pin<&mut Self>, cx: &mut Context, buf: &[u8], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match f.evented.io().write(buf) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn poll_write_vectored( mut self: Pin<&mut Self>, cx: &mut Context, bufs: &[io::IoSlice<'_>], ) -> Poll> { let f = &mut *self; f.evented .registration() .set_waker(cx.waker(), mio::Interest::WRITABLE); if !f .evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) { return Poll::Pending; } if !f.evented.registration().pull_from_budget() { return Poll::Pending; } match f.evented.io().write_vectored(bufs) { Ok(size) => Poll::Ready(Ok(size)), Err(ref e) if e.kind() == io::ErrorKind::WouldBlock => { f.evented .registration() .clear_readiness(mio::Interest::WRITABLE); Poll::Pending } Err(e) => Poll::Ready(Err(e)), } } fn poll_close(self: Pin<&mut Self>, _cx: &mut Context) -> Poll> { Poll::Ready(Ok(())) } fn is_writable(&self) -> bool { self.evented .registration() .readiness() .contains_any(mio::Interest::WRITABLE) } fn cancel(&mut self) { self.evented .registration() .clear_waker_interest(mio::Interest::WRITABLE); } } mod ffi { use 
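// The async listener/stream types above are normally driven from inside the
// executor, as the tests at the end of this module show. A compact sketch of the
// unified AsyncNetListener API (assuming the surrounding module's types plus the
// AsyncWriteExt trait from crate::core::io are in scope):
async fn greet_clients(l: NetListener) -> Result<(), std::io::Error> {
    let l = AsyncNetListener::new(l);
    loop {
        // a single future type accepts either a TCP or a Unix connection
        let (stream, _peer) = l.accept().await?;
        match stream {
            NetStream::Tcp(s) => {
                let mut s = AsyncTcpStream::new(s);
                s.write("hello\n".as_bytes()).await?;
            }
            NetStream::Unix(s) => {
                let mut s = AsyncUnixStream::new(s);
                s.write("hello\n".as_bytes()).await?;
            }
        }
    }
}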
std::convert::TryInto; use std::ffi::{CStr, OsStr}; use std::io::{Read, Write}; use std::os::fd::{AsRawFd, FromRawFd, IntoRawFd}; use std::os::raw::{c_char, c_int}; use std::os::unix::ffi::OsStrExt; use std::path::Path; use std::ptr; use std::slice; pub struct TcpListener(std::net::TcpListener); pub struct TcpStream(std::net::TcpStream); pub struct UnixListener(std::os::unix::net::UnixListener); pub struct UnixStream(std::os::unix::net::UnixStream); #[allow(clippy::missing_safety_doc)] unsafe fn io_read( r: &mut R, buf: *mut u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { assert!(!buf.is_null()); let buf = slice::from_raw_parts_mut(buf, size); assert!(!out_errno.is_null()); let size = match r.read(buf) { Ok(size) => size, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } }; size.try_into().expect("read size should fit in a ssize_t") } #[allow(clippy::missing_safety_doc)] unsafe fn io_write( w: &mut W, buf: *const u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { assert!(!buf.is_null()); let buf = slice::from_raw_parts(buf, size); assert!(!out_errno.is_null()); let size = match w.write(buf) { Ok(size) => size, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } }; size.try_into().expect("write size should fit in a ssize_t") } #[no_mangle] pub extern "C" fn tcp_listener_bind( ip: *const c_char, port: u16, out_errno: *mut c_int, ) -> *mut TcpListener { assert!(!out_errno.is_null()); let ip = unsafe { CStr::from_ptr(ip) }; let ip = match ip.to_str() { Ok(s) => s, Err(_) => { unsafe { out_errno.write(libc::EINVAL) }; return ptr::null_mut(); } }; let ip: std::net::IpAddr = match ip.parse() { Ok(ip) => ip, Err(_) => { unsafe { out_errno.write(libc::EINVAL) }; return ptr::null_mut(); } }; let addr = std::net::SocketAddr::new(ip, port); let l = match std::net::TcpListener::bind(addr) { Ok(l) => l, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } }; if let Err(e) = l.set_nonblocking(true) { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } Box::into_raw(Box::new(TcpListener(l))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_listener_destroy(l: *mut TcpListener) { if !l.is_null() { drop(Box::from_raw(l)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_listener_local_addr( l: *const TcpListener, out_ip: *mut c_char, out_ip_size: *mut libc::size_t, out_port: *mut u16, ) -> c_int { let l = l.as_ref().unwrap(); let out_ip_size = out_ip_size.as_mut().unwrap(); assert!(!out_port.is_null()); let addr = match l.0.local_addr() { Ok(addr) => addr, Err(_) => return -1, }; let ip = addr.ip().to_string(); if ip.len() > *out_ip_size { // if value doesn't fit, return success with empty value *out_ip_size = 0; return 0; } ptr::copy(ip.as_bytes().as_ptr() as *const c_char, out_ip, ip.len()); *out_ip_size = ip.len(); out_port.write(addr.port()); 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_listener_as_raw_fd(l: *const TcpListener) -> c_int { let l = l.as_ref().unwrap(); l.0.as_raw_fd() } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_listener_accept( l: *const TcpListener, out_errno: *mut c_int, ) -> *mut TcpStream { let l = l.as_ref().unwrap(); assert!(!out_errno.is_null()); let s = 
match l.0.accept() { Ok((s, _)) => s, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); out_errno.write(code); return ptr::null_mut(); } }; if let Err(e) = s.set_nonblocking(true) { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } Box::into_raw(Box::new(TcpStream(s))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_connect( ip: *const c_char, port: u16, out_errno: *mut c_int, ) -> *mut TcpStream { assert!(!out_errno.is_null()); let ip = unsafe { CStr::from_ptr(ip) }; let ip = match ip.to_str() { Ok(s) => s, Err(_) => { unsafe { out_errno.write(libc::EINVAL) }; return ptr::null_mut(); } }; let ip: std::net::IpAddr = match ip.parse() { Ok(ip) => ip, Err(_) => { unsafe { out_errno.write(libc::EINVAL) }; return ptr::null_mut(); } }; let addr = std::net::SocketAddr::new(ip, port); // use mio to ensure socket begins in non-blocking mode let s = match mio::net::TcpStream::connect(addr) { Ok(s) => s, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } }; // SAFETY: converting from valid object let s = unsafe { std::net::TcpStream::from_raw_fd(s.into_raw_fd()) }; Box::into_raw(Box::new(TcpStream(s))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_destroy(s: *mut TcpStream) { if !s.is_null() { drop(Box::from_raw(s)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_check_connected( s: *const TcpStream, out_errno: *mut c_int, ) -> c_int { let s = s.as_ref().unwrap(); assert!(!out_errno.is_null()); // mio documentation says to use take_error() and peer_addr() to // check for connected if let Ok(Some(e)) | Err(e) = s.0.take_error() { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } // returns libc::ENOTCONN if not yet connected if let Err(e) = s.0.peer_addr() { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_as_raw_fd(s: *const TcpStream) -> c_int { let s = s.as_ref().unwrap(); s.0.as_raw_fd() } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_read( s: *mut TcpStream, buf: *mut u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { let s = s.as_mut().unwrap(); io_read(&mut s.0, buf, size, out_errno) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn tcp_stream_write( s: *mut TcpStream, buf: *const u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { let s = s.as_mut().unwrap(); io_write(&mut s.0, buf, size, out_errno) } #[no_mangle] pub extern "C" fn unix_listener_bind( path: *const c_char, out_errno: *mut c_int, ) -> *mut UnixListener { assert!(!out_errno.is_null()); let path = unsafe { CStr::from_ptr(path) }; let path = Path::new(OsStr::from_bytes(path.to_bytes())); let l = match std::os::unix::net::UnixListener::bind(path) { Ok(l) => l, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } }; if let Err(e) = l.set_nonblocking(true) { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } Box::into_raw(Box::new(UnixListener(l))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn 
unix_listener_destroy(l: *mut UnixListener) { if !l.is_null() { drop(Box::from_raw(l)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_listener_as_raw_fd(l: *const UnixListener) -> c_int { let l = l.as_ref().unwrap(); l.0.as_raw_fd() } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_listener_accept( l: *const UnixListener, out_errno: *mut c_int, ) -> *mut UnixStream { let l = l.as_ref().unwrap(); assert!(!out_errno.is_null()); let s = match l.0.accept() { Ok((s, _)) => s, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); out_errno.write(code); return ptr::null_mut(); } }; if let Err(e) = s.set_nonblocking(true) { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } Box::into_raw(Box::new(UnixStream(s))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_connect( path: *const c_char, out_errno: *mut c_int, ) -> *mut UnixStream { assert!(!out_errno.is_null()); let path = unsafe { CStr::from_ptr(path) }; let path = Path::new(OsStr::from_bytes(path.to_bytes())); // use mio to ensure socket begins in non-blocking mode let s = match mio::net::UnixStream::connect(path) { Ok(s) => s, Err(e) => { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return ptr::null_mut(); } }; // SAFETY: converting from valid object let s = unsafe { std::os::unix::net::UnixStream::from_raw_fd(s.into_raw_fd()) }; Box::into_raw(Box::new(UnixStream(s))) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_destroy(s: *mut UnixStream) { if !s.is_null() { drop(Box::from_raw(s)); } } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_check_connected( s: *const UnixStream, out_errno: *mut c_int, ) -> c_int { let s = s.as_ref().unwrap(); assert!(!out_errno.is_null()); // mio documentation says to use take_error() and peer_addr() to // check for connected if let Ok(Some(e)) | Err(e) = s.0.take_error() { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } // returns libc::ENOTCONN if not yet connected if let Err(e) = s.0.peer_addr() { let code = e.raw_os_error().unwrap_or(libc::EINVAL); unsafe { out_errno.write(code) }; return -1; } 0 } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_as_raw_fd(s: *const UnixStream) -> c_int { let s = s.as_ref().unwrap(); s.0.as_raw_fd() } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_read( s: *mut UnixStream, buf: *mut u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { let s = s.as_mut().unwrap(); io_read(&mut s.0, buf, size, out_errno) } #[allow(clippy::missing_safety_doc)] #[no_mangle] pub unsafe extern "C" fn unix_stream_write( s: *mut UnixStream, buf: *const u8, size: libc::size_t, out_errno: *mut c_int, ) -> libc::ssize_t { let s = s.as_mut().unwrap(); io_write(&mut s.0, buf, size, out_errno) } } #[cfg(test)] mod tests { use super::*; use crate::core::executor::Executor; use crate::core::io::{AsyncReadExt, AsyncWriteExt}; use crate::core::reactor::Reactor; use std::fs; use std::str; #[test] fn async_tcpstream() { let reactor = Reactor::new(3); // 3 registrations let executor = Executor::new(2); // 2 tasks let spawner = executor.spawner(); executor .spawn(async move { let addr = "127.0.0.1:0".parse().unwrap(); let listener = 
AsyncTcpListener::bind(addr).expect("failed to bind"); let addr = listener.local_addr().unwrap(); spawner .spawn(async move { let mut stream = AsyncTcpStream::connect(&[addr]).await.unwrap(); let size = stream.write("hello".as_bytes()).await.unwrap(); assert_eq!(size, 5); }) .unwrap(); let (stream, _) = listener.accept().await.unwrap(); let mut stream = AsyncTcpStream::new(stream); let mut resp = [0u8; 1024]; let mut resp = io::Cursor::new(&mut resp[..]); loop { let mut buf = [0; 1024]; let size = stream.read(&mut buf).await.unwrap(); if size == 0 { break; } resp.write(&buf[..size]).unwrap(); } let size = resp.position() as usize; let resp = str::from_utf8(&resp.get_ref()[..size]).unwrap(); assert_eq!(resp, "hello"); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); } #[test] fn async_unixstream() { // ensure pipe file doesn't exist match fs::remove_file("test-unixstream") { Ok(()) => {} Err(e) if e.kind() == io::ErrorKind::NotFound => {} Err(e) => panic!("{}", e), } let reactor = Reactor::new(3); // 3 registrations let executor = Executor::new(2); // 2 tasks let spawner = executor.spawner(); executor .spawn(async move { let listener = AsyncUnixListener::bind("test-unixstream").expect("failed to bind"); spawner .spawn(async move { let mut stream = AsyncUnixStream::connect("test-unixstream").await.unwrap(); let size = stream.write("hello".as_bytes()).await.unwrap(); assert_eq!(size, 5); }) .unwrap(); let (stream, _) = listener.accept().await.unwrap(); let mut stream = AsyncUnixStream::new(stream); let mut resp = Vec::new(); loop { let mut buf = [0; 1024]; let size = stream.read(&mut buf).await.unwrap(); if size == 0 { break; } resp.extend(&buf[..size]); } let resp = str::from_utf8(&resp).unwrap(); assert_eq!(resp, "hello"); }) .unwrap(); executor.run(|timeout| reactor.poll(timeout)).unwrap(); fs::remove_file("test-unixstream").unwrap(); } } pushpin-1.41.0/src/core/packet/000077500000000000000000000000001504671364300162625ustar00rootroot00000000000000pushpin-1.41.0/src/core/packet/httprequestdata.h000066400000000000000000000016231504671364300216570ustar00rootroot00000000000000/* * Copyright (C) 2012-2013 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef HTTPREQUESTDATA_H #define HTTPREQUESTDATA_H #include "../httpheaders.h" #include class HttpRequestData { public: QString method; QUrl uri; HttpHeaders headers; QByteArray body; }; #endif pushpin-1.41.0/src/core/packet/httpresponsedata.h000066400000000000000000000016611504671364300220270ustar00rootroot00000000000000/* * Copyright (C) 2012-2013 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. 
* You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef HTTPRESPONSEDATA_H #define HTTPRESPONSEDATA_H #include "../httpheaders.h" class HttpResponseData { public: int code; QByteArray reason; HttpHeaders headers; QByteArray body; HttpResponseData() : code(-1) { } }; #endif pushpin-1.41.0/src/core/packet/retryrequestpacket.cpp000066400000000000000000000214661504671364300227450ustar00rootroot00000000000000/* * Copyright (C) 2012-2023 Fanout, Inc. * Copyright (C) 2023-2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "retryrequestpacket.h" #include "qtcompat.h" RetryRequestPacket::RetryRequestPacket() : haveInspectInfo(false), retrySeq(-1) { } QVariant RetryRequestPacket::toVariant() const { QVariantHash obj; QVariantList vrequests; foreach(const Request &r, requests) { QVariantHash vrequest; QVariantHash vrid; vrid["sender"] = r.rid.first; vrid["id"] = r.rid.second; vrequest["rid"] = vrid; if(r.https) vrequest["https"] = true; if(!r.peerAddress.isNull()) vrequest["peer-address"] = r.peerAddress.toString().toUtf8(); if(r.debug) vrequest["debug"] = true; if(r.autoCrossOrigin) vrequest["auto-cross-origin"] = true; if(!r.jsonpCallback.isEmpty()) vrequest["jsonp-callback"] = r.jsonpCallback; if(r.jsonpExtendedResponse) vrequest["jsonp-extended-response"] = true; if(r.unreportedTime > 0) vrequest["unreported-time"] = r.unreportedTime; vrequest["in-seq"] = r.inSeq; vrequest["out-seq"] = r.outSeq; vrequest["out-credits"] = r.outCredits; if(r.routerResp) vrequest["router-resp"] = r.routerResp; if(r.userData.isValid()) vrequest["user-data"] = r.userData; vrequests += vrequest; } obj["requests"] = vrequests; QVariantHash vrequestData; vrequestData["method"] = requestData.method.toLatin1(); vrequestData["uri"] = requestData.uri.toEncoded(); QVariantList vheaders; foreach(const HttpHeader &h, requestData.headers) { QVariantList vheader; vheader += h.first; vheader += h.second; vheaders += QVariant(vheader); } vrequestData["headers"] = vheaders; vrequestData["body"] = requestData.body; obj["request-data"] = vrequestData; if(haveInspectInfo) { QVariantHash vinspect; vinspect["no-proxy"] = !inspectInfo.doProxy; if(!inspectInfo.sharingKey.isEmpty()) vinspect["sharing-key"] = inspectInfo.sharingKey; if(!inspectInfo.sid.isEmpty()) vinspect["sid"] = inspectInfo.sid; if(!inspectInfo.lastIds.isEmpty()) { QVariantHash vlastIds; QHashIterator it(inspectInfo.lastIds); while(it.hasNext()) { it.next(); vlastIds[QString::fromUtf8(it.key())] = it.value(); } vinspect["last-ids"] = vlastIds; } if(inspectInfo.userData.isValid()) 
vinspect["user-data"] = inspectInfo.userData; obj["inspect"] = vinspect; } if(!route.isEmpty()) obj["route"] = route; if(retrySeq >= 0) obj["retry-seq"] = retrySeq; return obj; } bool RetryRequestPacket::fromVariant(const QVariant &in) { if(typeId(in) != QMetaType::QVariantHash) return false; QVariantHash obj = in.toHash(); if(!obj.contains("requests") || typeId(obj["requests"]) != QMetaType::QVariantList) return false; requests.clear(); foreach(const QVariant &i, obj["requests"].toList()) { if(typeId(i) != QMetaType::QVariantHash) return false; QVariantHash vrequest = i.toHash(); Request r; if(!vrequest.contains("rid") || typeId(vrequest["rid"]) != QMetaType::QVariantHash) return false; QVariantHash vrid = vrequest["rid"].toHash(); QByteArray sender, id; if(!vrid.contains("sender") || typeId(vrid["sender"]) != QMetaType::QByteArray) return false; sender = vrid["sender"].toByteArray(); if(!vrid.contains("id") || typeId(vrid["id"]) != QMetaType::QByteArray) return false; id = vrid["id"].toByteArray(); r.rid = Rid(sender, id); if(vrequest.contains("https")) { if(typeId(vrequest["https"]) != QMetaType::Bool) return false; r.https = vrequest["https"].toBool(); } if(vrequest.contains("peer-address")) { if(typeId(vrequest["peer-address"]) != QMetaType::QByteArray) return false; r.peerAddress = QHostAddress(QString::fromUtf8(vrequest["peer-address"].toByteArray())); } if(vrequest.contains("debug")) { if(typeId(vrequest["debug"]) != QMetaType::Bool) return false; r.debug = vrequest["debug"].toBool(); } if(vrequest.contains("auto-cross-origin")) { if(typeId(vrequest["auto-cross-origin"]) != QMetaType::Bool) return false; r.autoCrossOrigin = vrequest["auto-cross-origin"].toBool(); } if(vrequest.contains("jsonp-callback")) { if(typeId(vrequest["jsonp-callback"]) != QMetaType::QByteArray) return false; r.jsonpCallback = vrequest["jsonp-callback"].toByteArray(); if(vrequest.contains("jsonp-extended-response")) { if(typeId(vrequest["jsonp-extended-response"]) != QMetaType::Bool) return false; r.jsonpExtendedResponse = vrequest["jsonp-extended-response"].toBool(); } } if(vrequest.contains("unreported-time")) { if(!canConvert(vrequest["unreported-time"], QMetaType::Int)) return false; r.unreportedTime = vrequest["unreported-time"].toInt(); } if(!vrequest.contains("in-seq") || !canConvert(vrequest["in-seq"], QMetaType::Int)) return false; r.inSeq = vrequest["in-seq"].toInt(); if(!vrequest.contains("out-seq") || !canConvert(vrequest["out-seq"], QMetaType::Int)) return false; r.outSeq = vrequest["out-seq"].toInt(); if(!vrequest.contains("out-credits") || !canConvert(vrequest["out-credits"], QMetaType::Int)) return false; r.outCredits = vrequest["out-credits"].toInt(); if(vrequest.contains("router-resp")) { if(typeId(vrequest["router-resp"]) != QMetaType::Bool) return false; r.routerResp = vrequest["router-resp"].toBool(); } if(vrequest.contains("user-data")) r.userData = vrequest["user-data"]; requests += r; } if(!obj.contains("request-data") || typeId(obj["request-data"]) != QMetaType::QVariantHash) return false; QVariantHash vrequestData = obj["request-data"].toHash(); if(!vrequestData.contains("method") || typeId(vrequestData["method"]) != QMetaType::QByteArray) return false; requestData.method = QString::fromLatin1(vrequestData["method"].toByteArray()); if(!vrequestData.contains("uri") || typeId(vrequestData["uri"]) != QMetaType::QByteArray) return false; requestData.uri = QUrl::fromEncoded(vrequestData["uri"].toByteArray(), QUrl::StrictMode); requestData.headers.clear(); 
if(vrequestData.contains("headers")) { if(typeId(vrequestData["headers"]) != QMetaType::QVariantList) return false; foreach(const QVariant &i, vrequestData["headers"].toList()) { QVariantList list = i.toList(); if(list.count() != 2) return false; if(typeId(list[0]) != QMetaType::QByteArray || typeId(list[1]) != QMetaType::QByteArray) return false; requestData.headers += QPair(list[0].toByteArray(), list[1].toByteArray()); } } if(!vrequestData.contains("body") || typeId(vrequestData["body"]) != QMetaType::QByteArray) return false; requestData.body = vrequestData["body"].toByteArray(); if(obj.contains("inspect")) { if(typeId(obj["inspect"]) != QMetaType::QVariantHash) return false; QVariantHash vinspect = obj["inspect"].toHash(); if(!vinspect.contains("no-proxy") || typeId(vinspect["no-proxy"]) != QMetaType::Bool) return false; inspectInfo.doProxy = !vinspect["no-proxy"].toBool(); inspectInfo.sharingKey.clear(); if(vinspect.contains("sharing-key")) { if(typeId(vinspect["sharing-key"]) != QMetaType::QByteArray) return false; inspectInfo.sharingKey = vinspect["sharing-key"].toByteArray(); } if(vinspect.contains("sid")) { if(typeId(vinspect["sid"]) != QMetaType::QByteArray) return false; inspectInfo.sid = vinspect["sid"].toByteArray(); } if(vinspect.contains("last-ids")) { if(typeId(vinspect["last-ids"]) != QMetaType::QVariantHash) return false; QVariantHash vlastIds = vinspect["last-ids"].toHash(); QHashIterator it(vlastIds); while(it.hasNext()) { it.next(); if(typeId(it.value()) != QMetaType::QByteArray) return false; QByteArray key = it.key().toUtf8(); QByteArray val = it.value().toByteArray(); inspectInfo.lastIds.insert(key, val); } } inspectInfo.userData = vinspect["user-data"]; haveInspectInfo = true; } if(obj.contains("route")) { if(typeId(obj["route"]) != QMetaType::QByteArray) return false; route = obj["route"].toByteArray(); } if(obj.contains("retry-seq")) { if(!canConvert(obj["retry-seq"], QMetaType::Int)) return false; retrySeq = obj["retry-seq"].toInt(); } return true; } pushpin-1.41.0/src/core/packet/retryrequestpacket.h000066400000000000000000000035731504671364300224110ustar00rootroot00000000000000/* * Copyright (C) 2012-2023 Fanout, Inc. * Copyright (C) 2023-2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #ifndef RETRYREQUESTPACKET_H #define RETRYREQUESTPACKET_H #include #include #include "httprequestdata.h" class RetryRequestPacket { public: typedef QPair Rid; class Request { public: Rid rid; bool https; QHostAddress peerAddress; bool debug; bool autoCrossOrigin; QByteArray jsonpCallback; bool jsonpExtendedResponse; int unreportedTime; // zhttp int inSeq; int outSeq; int outCredits; bool routerResp; QVariant userData; Request() : https(false), debug(false), autoCrossOrigin(false), jsonpExtendedResponse(false), unreportedTime(-1), inSeq(-1), outSeq(-1), outCredits(-1), routerResp(false) { } }; class InspectInfo { public: bool doProxy; QByteArray sharingKey; QByteArray sid; QHash lastIds; QVariant userData; InspectInfo() : doProxy(false) { } }; QList requests; HttpRequestData requestData; bool haveInspectInfo; InspectInfo inspectInfo; QByteArray route; int retrySeq; RetryRequestPacket(); QVariant toVariant() const; bool fromVariant(const QVariant &in); }; #endif pushpin-1.41.0/src/core/packet/statspacket.cpp000066400000000000000000000257641504671364300213320ustar00rootroot00000000000000/* * Copyright (C) 2014-2023 Fanout, Inc. * Copyright (C) 2023-2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include "statspacket.h" #include "qtcompat.h" static bool tryGetInt(const QVariantHash &obj, const QString &name, int *result) { if(obj.contains(name)) { if(!canConvert(obj[name], QMetaType::Int)) return false; *result = obj[name].toInt(); } return true; } QVariant StatsPacket::toVariant() const { QVariantHash obj; if(!from.isEmpty()) obj["from"] = from; if(!route.isEmpty()) obj["route"] = route; if(type == Activity) { int x = count; if(x < 0) x = 0; obj["count"] = x; } else if(type == Message) { obj["channel"] = channel; if(!itemId.isNull()) obj["item-id"] = itemId; int x = count; if(x < 0) x = 0; obj["count"] = x; if(blocks >= 0) obj["blocks"] = blocks; obj["transport"] = transport; } else if(type == Connected || type == Disconnected) { obj["id"] = connectionId; if(type == Connected) { if(connectionType == WebSocket) obj["type"] = QByteArray("ws"); else // Http obj["type"] = QByteArray("http"); if(!peerAddress.isNull()) obj["peer-address"] = peerAddress.toString().toUtf8(); if(ssl) obj["ssl"] = true; obj["ttl"] = ttl; } else // Disconnected { obj["unavailable"] = true; } } else if(type == Subscribed || type == Unsubscribed) { obj["mode"] = mode; obj["channel"] = channel; if(type == Subscribed) { obj["ttl"] = ttl; if(subscribers >= 0) obj["subscribers"] = subscribers; } else // Unsubscribed { obj["unavailable"] = true; } } else if(type == Report) { if(connectionsMax != -1) obj["connections"] = connectionsMax; if(connectionsMinutes != -1) obj["minutes"] = connectionsMinutes; if(messagesReceived != -1) obj["received"] = messagesReceived; if(messagesSent != -1) obj["sent"] = messagesSent; if(httpResponseMessagesSent != -1) obj["http-response-sent"] = httpResponseMessagesSent; if(blocksReceived >= 0) obj["blocks-received"] = blocksReceived; if(blocksSent >= 0) obj["blocks-sent"] = blocksSent; if(duration >= 0) obj["duration"] = duration; if(clientHeaderBytesReceived >= 0) obj["client-header-bytes-received"] = clientHeaderBytesReceived; if(clientHeaderBytesSent >= 0) obj["client-header-bytes-sent"] = clientHeaderBytesSent; if(clientContentBytesReceived >= 0) obj["client-content-bytes-received"] = clientContentBytesReceived; if(clientContentBytesSent >= 0) obj["client-content-bytes-sent"] = clientContentBytesSent; if(clientMessagesReceived >= 0) obj["client-messages-received"] = clientMessagesReceived; if(clientMessagesSent >= 0) obj["client-messages-sent"] = clientMessagesSent; if(serverHeaderBytesReceived >= 0) obj["server-header-bytes-received"] = serverHeaderBytesReceived; if(serverHeaderBytesSent >= 0) obj["server-header-bytes-sent"] = serverHeaderBytesSent; if(serverContentBytesReceived >= 0) obj["server-content-bytes-received"] = serverContentBytesReceived; if(serverContentBytesSent >= 0) obj["server-content-bytes-sent"] = serverContentBytesSent; if(serverMessagesReceived >= 0) obj["server-messages-received"] = serverMessagesReceived; if(serverMessagesSent >= 0) obj["server-messages-sent"] = serverMessagesSent; } else if(type == Counts) { if(requestsReceived > 0) obj["requests-received"] = requestsReceived; } else // ConnectionsMax { obj["max"] = qMax(connectionsMax, 0); obj["ttl"] = qMax(ttl, 0); if(retrySeq >= 0) obj["retry-seq"] = retrySeq; } return obj; } bool StatsPacket::fromVariant(const QByteArray &_type, const QVariant &in) { if(typeId(in) != QMetaType::QVariantHash) return false; QVariantHash obj = in.toHash(); if(obj.contains("from")) { if(typeId(obj["from"]) != QMetaType::QByteArray) return false; from = obj["from"].toByteArray(); } 
if(obj.contains("route")) { if(typeId(obj["route"]) != QMetaType::QByteArray) return false; route = obj["route"].toByteArray(); } if(_type == "activity") { type = Activity; if(!obj.contains("count") || !canConvert(obj["count"], QMetaType::Int)) return false; count = obj["count"].toInt(); if(count < 0) return false; } else if(_type == "message") { type = Message; if(!obj.contains("channel") || typeId(obj["channel"]) != QMetaType::QByteArray) return false; channel = obj["channel"].toByteArray(); if(obj.contains("item-id")) { if(typeId(obj["item-id"]) != QMetaType::QByteArray) return false; itemId = obj["item-id"].toByteArray(); } if(!obj.contains("count") || !canConvert(obj["count"], QMetaType::Int)) return false; count = obj["count"].toInt(); if(count < 0) return false; if(obj.contains("blocks")) { if(!canConvert(obj["blocks"], QMetaType::Int)) return false; blocks = obj["blocks"].toInt(); } if(!obj.contains("transport") || typeId(obj["transport"]) != QMetaType::QByteArray) return false; transport = obj["transport"].toByteArray(); } else if(_type == "conn") { if(!obj.contains("id") || typeId(obj["id"]) != QMetaType::QByteArray) return false; connectionId = obj["id"].toByteArray(); type = Connected; if(obj.contains("unavailable")) { if(typeId(obj["unavailable"]) != QMetaType::Bool) return false; if(obj["unavailable"].toBool()) type = Disconnected; } if(type == Connected) { if(!obj.contains("type") || typeId(obj["type"]) != QMetaType::QByteArray) return false; QByteArray typeStr = obj["type"].toByteArray(); if(typeStr == "ws") connectionType = WebSocket; else if(typeStr == "http") connectionType = Http; else return false; if(obj.contains("peer-address")) { if(typeId(obj["peer-address"]) != QMetaType::QByteArray) return false; QByteArray peerAddressStr = obj["peer-address"].toByteArray(); if(!peerAddress.setAddress(QString::fromUtf8(peerAddressStr))) return false; } if(obj.contains("ssl")) { if(typeId(obj["ssl"]) != QMetaType::Bool) return false; ssl = obj["ssl"].toBool(); } if(!obj.contains("ttl") || !canConvert(obj["ttl"], QMetaType::Int)) return false; ttl = obj["ttl"].toInt(); if(ttl < 0) return false; } } else if(_type == "sub") { if(!obj.contains("mode") || typeId(obj["mode"]) != QMetaType::QByteArray) return false; mode = obj["mode"].toByteArray(); if(!obj.contains("channel") || typeId(obj["channel"]) != QMetaType::QByteArray) return false; channel = obj["channel"].toByteArray(); type = Subscribed; if(obj.contains("unavailable")) { if(typeId(obj["unavailable"]) != QMetaType::Bool) return false; if(obj["unavailable"].toBool()) type = Unsubscribed; } if(type == Subscribed) { if(!obj.contains("ttl") || !canConvert(obj["ttl"], QMetaType::Int)) return false; ttl = obj["ttl"].toInt(); if(ttl < 0) return false; if(obj.contains("subscribers")) { if(!canConvert(obj["subscribers"], QMetaType::Int)) return false; subscribers = obj["subscribers"].toInt(); if(subscribers < 0) return false; } } } else if(_type == "report") { type = Report; if(obj.contains("connections")) { if(!canConvert(obj["connections"], QMetaType::Int)) return false; connectionsMax = obj["connections"].toInt(); } if(obj.contains("minutes")) { if(!canConvert(obj["minutes"], QMetaType::Int)) return false; connectionsMinutes = obj["minutes"].toInt(); } if(obj.contains("received")) { if(!canConvert(obj["received"], QMetaType::Int)) return false; messagesReceived = obj["received"].toInt(); } if(obj.contains("sent")) { if(!canConvert(obj["sent"], QMetaType::Int)) return false; messagesSent = obj["sent"].toInt(); } 
if(obj.contains("http-response-sent")) { if(!canConvert(obj["http-response-sent"], QMetaType::Int)) return false; httpResponseMessagesSent = obj["http-response-sent"].toInt(); } if(obj.contains("blocks-received")) { if(!canConvert(obj["blocks-received"], QMetaType::Int)) return false; blocksReceived = obj["blocks-received"].toInt(); } if(obj.contains("blocks-sent")) { if(!canConvert(obj["blocks-sent"], QMetaType::Int)) return false; blocksSent = obj["blocks-sent"].toInt(); } if(obj.contains("duration")) { if(!canConvert(obj["duration"], QMetaType::Int)) return false; duration = obj["duration"].toInt(); } if(!tryGetInt(obj, "client-header-bytes-received", &clientHeaderBytesReceived)) return false; if(!tryGetInt(obj, "client-header-bytes-sent", &clientHeaderBytesSent)) return false; if(!tryGetInt(obj, "client-content-bytes-received", &clientContentBytesReceived)) return false; if(!tryGetInt(obj, "client-content-bytes-sent", &clientContentBytesSent)) return false; if(!tryGetInt(obj, "client-messages-received", &clientMessagesReceived)) return false; if(!tryGetInt(obj, "client-messages-sent", &clientMessagesSent)) return false; if(!tryGetInt(obj, "server-header-bytes-received", &serverHeaderBytesReceived)) return false; if(!tryGetInt(obj, "server-header-bytes-sent", &serverHeaderBytesSent)) return false; if(!tryGetInt(obj, "server-content-bytes-received", &serverContentBytesReceived)) return false; if(!tryGetInt(obj, "server-content-bytes-sent", &serverContentBytesSent)) return false; if(!tryGetInt(obj, "server-messages-received", &serverMessagesReceived)) return false; if(!tryGetInt(obj, "server-messages-sent", &serverMessagesSent)) return false; } else if(_type == "counts") { type = Counts; if(obj.contains("requests-received")) { if(!canConvert(obj["requests-received"], QMetaType::Int)) return false; int x = obj["requests-received"].toInt(); if(x < 0) return false; requestsReceived = x; } } else if(_type == "conn-max") { type = ConnectionsMax; if(!obj.contains("max") || !canConvert(obj["max"], QMetaType::Int)) return false; int x = obj["max"].toInt(); if(x < 0) return false; connectionsMax = x; if(!obj.contains("ttl") || !canConvert(obj["ttl"], QMetaType::Int)) return false; x = obj["ttl"].toInt(); if(x < 0) return false; ttl = x; if(obj.contains("retry-seq")) { if(!canConvert(obj["retry-seq"], QMetaType::LongLong)) return false; int x = obj["retry-seq"].toLongLong(); if(x < 0) return false; retrySeq = x; } } else return false; return true; } pushpin-1.41.0/src/core/packet/statspacket.h000066400000000000000000000062471504671364300207720ustar00rootroot00000000000000/* * Copyright (C) 2014-2022 Fanout, Inc. * Copyright (C) 2023 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #ifndef STATSPACKET_H #define STATSPACKET_H #include #include #include class StatsPacket { public: enum Type { Activity, Message, Connected, Disconnected, Subscribed, Unsubscribed, Report, Counts, ConnectionsMax, }; enum ConnectionType { Http, WebSocket }; Type type; QByteArray from; QByteArray route; qint64 retrySeq; // connections max int count; // activity, message QByteArray connectionId; // connected, disconnected ConnectionType connectionType; // connected QHostAddress peerAddress; // connected bool ssl; // connected int ttl; // connected, subscribed, connections max QByteArray mode; // subscribed, unsubscribed QByteArray channel; // message, subscribed, unsubscribed QByteArray itemId; // message QByteArray transport; // message int blocks; // message int subscribers; // subscribed int connectionsMax; // report, connections max int connectionsMinutes; // report int messagesReceived; // report int messagesSent; // report int httpResponseMessagesSent; // report int blocksReceived; // report int blocksSent; // report int duration; // report int requestsReceived; // counts int clientHeaderBytesReceived; // report int clientHeaderBytesSent; // report int clientContentBytesReceived; // report int clientContentBytesSent; // report int clientMessagesReceived; // report int clientMessagesSent; // report int serverHeaderBytesReceived; // report int serverHeaderBytesSent; // report int serverContentBytesReceived; // report int serverContentBytesSent; // report int serverMessagesReceived; // report int serverMessagesSent; // report StatsPacket() : type((Type)-1), retrySeq(-1), count(-1), connectionType((ConnectionType)-1), ssl(false), ttl(-1), blocks(-1), subscribers(-1), connectionsMax(-1), connectionsMinutes(-1), messagesReceived(-1), messagesSent(-1), httpResponseMessagesSent(-1), blocksReceived(-1), blocksSent(-1), duration(-1), requestsReceived(-1), clientHeaderBytesReceived(-1), clientHeaderBytesSent(-1), clientContentBytesReceived(-1), clientContentBytesSent(-1), clientMessagesReceived(-1), clientMessagesSent(-1), serverHeaderBytesReceived(-1), serverHeaderBytesSent(-1), serverContentBytesReceived(-1), serverContentBytesSent(-1), serverMessagesReceived(-1), serverMessagesSent(-1) { } QVariant toVariant() const; bool fromVariant(const QByteArray &type, const QVariant &in); }; #endif pushpin-1.41.0/src/core/packet/wscontrolpacket.cpp000066400000000000000000000244541504671364300222210ustar00rootroot00000000000000/* * Copyright (C) 2014-2022 Fanout, Inc. * Copyright (C) 2024-2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "wscontrolpacket.h" #include #include "qtcompat.h" // FIXME: rewrite packet class using this code? 
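// Illustrative usage sketch (field names taken from wscontrolpacket.h; the
// values and surrounding transport code are invented for the example):
//
//   WsControlPacket packet;
//   packet.from = "proxy";
//   WsControlPacket::Item item;
//   item.cid = "1";
//   item.type = WsControlPacket::Item::Subscribe;
//   item.channel = "room:lobby";
//   packet.items += item;
//
//   QVariant v = packet.toVariant();
//   WsControlPacket parsed;
//   bool ok = parsed.fromVariant(v); // ok == true; the item round-trips intact
//
// The commented-out draft below is the candidate rewrite referenced by the
// FIXME above; it reports parse errors instead of only returning a bool.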
/*class WsControlPacket { public: class Message { public: enum Type { Here, Gone, Cancel, Grip }; Type type; QString cid; QString channelPrefix; // here only QByteArray message; // grip only }; QString channelPrefix; QList messages; static WsControlPacket fromVariant(const QVariant &in, bool *ok = 0, QString *errorMessage = 0) { QString pn = "wscontrol packet"; if(!isKeyedObject(in)) { setError(ok, errorMessage, QString("%1 is not an object").arg(pn)); return WsControlPacket(); } pn = "wscontrol object"; bool ok_; QVariantList vitems = getList(in, pn, "items", false, &ok_, errorMessage); if(!ok_) { if(ok) *ok = false; return WsControlPacket(); } WsControlPacket out; foreach(const QVariant &vitem, vitems) { Message msg; pn = "wscontrol item"; QString type = getString(vitem, pn, "type", true, &ok_, errorMessage); if(!ok_) { if(ok) *ok = false; return WsControlPacket(); } if(type == "here") msg.type = Message::Here; else if(type == "gone") msg.type = Message::Gone; else if(type == "cancel") msg.type = Message::Cancel; else if(type == "grip") msg.type = Message::Grip; else { setError(ok, errorMessage, QString("'type' contains unknown value: %1").arg(type)); return WsControlPacket(); } msg.cid = getString(vitem, pn, "cid", true, &ok_, errorMessage); if(!ok_) { if(ok) *ok = false; return WsControlPacket(); } msg.uri = QUrl::fromEncoded(getString(vitem, pn, "uri", false, &ok_, errorMessage).toUtf8(), QUrl::StrictMode); if(!ok_) { if(ok) *ok = false; return WsControlPacket(); } msg.channelPrefix = getString(vitem, pn, "channel-prefix", false, &ok_, errorMessage); if(!ok_) { if(ok) *ok = false; return WsControlPacket(); } if(msg.type == Message::Grip) { if(!keyedObjectContains(vitem, "message")) { setError(ok, errorMessage, QString("'%1' does not contain 'message'").arg(pn)); return WsControlPacket(); } QVariant vmessage = keyedObjectGetValue(vitem, "message"); if(vmessage.type() != QVariant::ByteArray) { setError(ok, errorMessage, QString("'%1' contains 'message' with wrong type").arg(pn)); return WsControlPacket(); } msg.message = vmessage.toByteArray(); } out.messages += msg; } setSuccess(ok, errorMessage); return out; } };*/ QVariant WsControlPacket::toVariant() const { QVariantHash obj; obj["from"] = from; QVariantList vitems; foreach(const Item &item, items) { QVariantHash vitem; vitem["cid"] = item.cid; QByteArray typeStr; switch(item.type) { case Item::Here: typeStr = "here"; break; case Item::KeepAlive: typeStr = "keep-alive"; break; case Item::Gone: typeStr = "gone"; break; case Item::Grip: typeStr = "grip"; break; case Item::KeepAliveSetup: typeStr = "keep-alive-setup"; break; case Item::Cancel: typeStr = "cancel"; break; case Item::Send: typeStr = "send"; break; case Item::NeedKeepAlive: typeStr = "need-keep-alive"; break; case Item::Subscribe: typeStr = "subscribe"; break; case Item::Refresh: typeStr = "refresh"; break; case Item::Close: typeStr = "close"; break; case Item::Detach: typeStr = "detach"; break; case Item::Ack: typeStr = "ack"; break; default: assert(0); } vitem["type"] = typeStr; if(!item.requestId.isEmpty()) vitem["req-id"] = item.requestId; if(!item.uri.isEmpty()) vitem["uri"] = item.uri.toEncoded(); if(!item.contentType.isEmpty()) vitem["content-type"] = item.contentType; if(!item.message.isNull()) vitem["message"] = item.message; if(item.queue) vitem["queue"] = true; if(item.code >= 0) vitem["code"] = item.code; if(!item.reason.isEmpty()) vitem["reason"] = item.reason; if(item.debug) vitem["debug"] = true; if(!item.route.isEmpty()) vitem["route"] = item.route; 
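		// The remaining optional fields follow the same pattern: only values
		// that differ from Item's constructor defaults (-1, false or empty)
		// are written, so absent keys imply those defaults on the receiver.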
if(item.separateStats) vitem["separate-stats"] = true; if(!item.channelPrefix.isEmpty()) vitem["channel-prefix"] = item.channelPrefix; if(item.logLevel >= 0) vitem["log-level"] = item.logLevel; if(item.trusted) vitem["trusted"] = true; if(!item.channel.isEmpty()) vitem["channel"] = item.channel; if(item.ttl >= 0) vitem["ttl"] = item.ttl; if(item.timeout >= 0) vitem["timeout"] = item.timeout; if(!item.keepAliveMode.isEmpty()) vitem["keep-alive-mode"] = item.keepAliveMode; vitems += vitem; } obj["items"] = vitems; return obj; } bool WsControlPacket::fromVariant(const QVariant &in) { if(typeId(in) != QMetaType::QVariantHash) return false; QVariantHash obj = in.toHash(); if(!obj.contains("from") || typeId(obj["from"]) != QMetaType::QByteArray) return false; from = obj["from"].toByteArray(); if(!obj.contains("items") || typeId(obj["items"]) != QMetaType::QVariantList) return false; QVariantList vitems = obj["items"].toList(); items.clear(); foreach(const QVariant &v, vitems) { if(typeId(v) != QMetaType::QVariantHash) return false; QVariantHash vitem = v.toHash(); Item item; if(!vitem.contains("cid") || typeId(vitem["cid"]) != QMetaType::QByteArray) return false; item.cid = vitem["cid"].toByteArray(); if(!vitem.contains("type") || typeId(vitem["type"]) != QMetaType::QByteArray) return false; QByteArray typeStr = vitem["type"].toByteArray(); if(typeStr == "here") item.type = Item::Here; else if(typeStr == "keep-alive") item.type = Item::KeepAlive; else if(typeStr == "gone") item.type = Item::Gone; else if(typeStr == "grip") item.type = Item::Grip; else if(typeStr == "keep-alive-setup") item.type = Item::KeepAliveSetup; else if(typeStr == "cancel") item.type = Item::Cancel; else if(typeStr == "send") item.type = Item::Send; else if(typeStr == "need-keep-alive") item.type = Item::NeedKeepAlive; else if(typeStr == "subscribe") item.type = Item::Subscribe; else if(typeStr == "refresh") item.type = Item::Refresh; else if(typeStr == "close") item.type = Item::Close; else if(typeStr == "detach") item.type = Item::Detach; else if(typeStr == "ack") item.type = Item::Ack; else return false; if(vitem.contains("req-id")) { if(typeId(vitem["req-id"]) != QMetaType::QByteArray) return false; item.requestId = vitem["req-id"].toByteArray(); } if(vitem.contains("uri")) { if(typeId(vitem["uri"]) != QMetaType::QByteArray) return false; item.uri = QUrl::fromEncoded(vitem["uri"].toByteArray(), QUrl::StrictMode); } if(vitem.contains("content-type")) { if(typeId(vitem["content-type"]) != QMetaType::QByteArray) return false; QByteArray contentType = vitem["content-type"].toByteArray(); if(!contentType.isEmpty()) item.contentType = contentType; } if(vitem.contains("message")) { if(typeId(vitem["message"]) != QMetaType::QByteArray) return false; item.message = vitem["message"].toByteArray(); } if(vitem.contains("queue")) { if(typeId(vitem["queue"]) != QMetaType::Bool) return false; item.queue = vitem["queue"].toBool(); } if(vitem.contains("code")) { if(!canConvert(vitem["code"], QMetaType::Int)) return false; item.code = vitem["code"].toInt(); } if(vitem.contains("reason")) { if(typeId(vitem["reason"]) != QMetaType::QByteArray) return false; item.reason = vitem["reason"].toByteArray(); } if(vitem.contains("debug")) { if(typeId(vitem["debug"]) != QMetaType::Bool) return false; item.debug = vitem["debug"].toBool(); } if(vitem.contains("route")) { if(typeId(vitem["route"]) != QMetaType::QByteArray) return false; QByteArray route = vitem["route"].toByteArray(); if(!route.isEmpty()) item.route = route; } 
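		// The remaining checks mirror toVariant(): booleans must be Bool,
		// numeric fields must convert to int, byte-array fields are ignored
		// when empty, and negative ttl/timeout values are clamped to 0.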
if(vitem.contains("separate-stats")) { if(typeId(vitem["separate-stats"]) != QMetaType::Bool) return false; item.separateStats = vitem["separate-stats"].toBool(); } if(vitem.contains("channel-prefix")) { if(typeId(vitem["channel-prefix"]) != QMetaType::QByteArray) return false; QByteArray channelPrefix = vitem["channel-prefix"].toByteArray(); if(!channelPrefix.isEmpty()) item.channelPrefix = channelPrefix; } if(vitem.contains("log-level")) { if(!canConvert(vitem["log-level"], QMetaType::Int)) return false; item.logLevel = vitem["log-level"].toInt(); } if(vitem.contains("trusted")) { if(typeId(vitem["trusted"]) != QMetaType::Bool) return false; item.trusted = vitem["trusted"].toBool(); } if(vitem.contains("channel")) { if(typeId(vitem["channel"]) != QMetaType::QByteArray) return false; QByteArray channel = vitem["channel"].toByteArray(); if(!channel.isEmpty()) item.channel = channel; } if(vitem.contains("ttl")) { if(!canConvert(vitem["ttl"], QMetaType::Int)) return false; item.ttl = vitem["ttl"].toInt(); if(item.ttl < 0) item.ttl = 0; } if(vitem.contains("timeout")) { if(!canConvert(vitem["timeout"], QMetaType::Int)) return false; item.timeout = vitem["timeout"].toInt(); if(item.timeout < 0) item.timeout = 0; } if(vitem.contains("keep-alive-mode")) { if(!canConvert(vitem["keep-alive-mode"], QMetaType::QByteArray)) return false; QByteArray keepAliveMode = vitem["keep-alive-mode"].toByteArray(); if(!keepAliveMode.isEmpty()) item.keepAliveMode = keepAliveMode; } items += item; } return true; } pushpin-1.41.0/src/core/packet/wscontrolpacket.h000066400000000000000000000033341504671364300216600ustar00rootroot00000000000000/* * Copyright (C) 2014-2022 Fanout, Inc. * Copyright (C) 2024-2025 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef WSCONTROLPACKET_H #define WSCONTROLPACKET_H #include #include #include #include class WsControlPacket { public: class Item { public: enum Type { Here, KeepAlive, Gone, Grip, NeedKeepAlive, Subscribe, Cancel, Send, KeepAliveSetup, Refresh, Close, Detach, Ack }; QByteArray cid; Type type; QByteArray requestId; QUrl uri; QByteArray contentType; QByteArray message; bool queue; int code; QByteArray reason; bool debug; QByteArray route; bool separateStats; QByteArray channelPrefix; int logLevel; bool trusted; QByteArray channel; int ttl; int timeout; QByteArray keepAliveMode; Item() : type((Type)-1), queue(false), code(-1), debug(false), separateStats(false), logLevel(-1), trusted(false), ttl(-1), timeout(-1) { } }; QByteArray from; QList items; QVariant toVariant() const; bool fromVariant(const QVariant &in); }; #endif pushpin-1.41.0/src/core/packet/zrpcrequestpacket.cpp000066400000000000000000000033411504671364300225460ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. 
* * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #include "zrpcrequestpacket.h" #include "qtcompat.h" QVariant ZrpcRequestPacket::toVariant() const { QVariantHash obj; if(!from.isEmpty()) obj["from"] = from; if(!id.isEmpty()) obj["id"] = id; obj["method"] = method.toUtf8(); if(!args.isEmpty()) obj["args"] = args; return obj; } bool ZrpcRequestPacket::fromVariant(const QVariant &in) { if(typeId(in) != QMetaType::QVariantHash) return false; QVariantHash obj = in.toHash(); if(obj.contains("from")) { if(typeId(obj["from"]) != QMetaType::QByteArray) return false; from = obj["from"].toByteArray(); } if(obj.contains("id")) { if(typeId(obj["id"]) != QMetaType::QByteArray) return false; id = obj["id"].toByteArray(); } if(!obj.contains("method") || typeId(obj["method"]) != QMetaType::QByteArray) return false; method = QString::fromUtf8(obj["method"].toByteArray()); if(obj.contains("args")) { if(typeId(obj["args"]) != QMetaType::QVariantHash) return false; args = obj["args"].toHash(); } return true; } pushpin-1.41.0/src/core/packet/zrpcrequestpacket.h000066400000000000000000000017751504671364300222240ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef ZRPCREQUESTPACKET_H #define ZRPCREQUESTPACKET_H #include #include class ZrpcRequestPacket { public: QByteArray from; QByteArray id; QString method; QVariantHash args; QVariant toVariant() const; bool fromVariant(const QVariant &in); }; #endif pushpin-1.41.0/src/core/packet/zrpcresponsepacket.cpp000066400000000000000000000040001504671364300227050ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include "zrpcresponsepacket.h" #include "qtcompat.h" QVariant ZrpcResponsePacket::toVariant() const { QVariantHash obj; if(!id.isEmpty()) obj["id"] = id; obj["success"] = success; if(success) { if(typeId(value) == QMetaType::QString) obj["value"] = value.toString().toUtf8(); else obj["value"] = value; } else { obj["condition"] = condition; if(value.isValid()) { if(typeId(value) == QMetaType::QString) obj["value"] = value.toString().toUtf8(); else obj["value"] = value; } } return obj; } bool ZrpcResponsePacket::fromVariant(const QVariant &in) { if(typeId(in) != QMetaType::QVariantHash) return false; QVariantHash obj = in.toHash(); if(obj.contains("id")) { if(typeId(obj["id"]) != QMetaType::QByteArray) return false; id = obj["id"].toByteArray(); } if(!obj.contains("success") || typeId(obj["success"]) != QMetaType::Bool) return false; success = obj["success"].toBool(); value.clear(); condition.clear(); if(success) { if(!obj.contains("value")) return false; value = obj["value"]; } else { if(!obj.contains("condition") || typeId(obj["condition"]) != QMetaType::QByteArray) return false; condition = obj["condition"].toByteArray(); if(obj.contains("value")) value = obj["value"]; } return true; } pushpin-1.41.0/src/core/packet/zrpcresponsepacket.h000066400000000000000000000020151504671364300223560ustar00rootroot00000000000000/* * Copyright (C) 2014 Fanout, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef ZRPCRESPONSEPACKET_H #define ZRPCRESPONSEPACKET_H #include #include class ZrpcResponsePacket { public: QByteArray id; bool success; QVariant value; QByteArray condition; ZrpcResponsePacket() : success(false) { } QVariant toVariant() const; bool fromVariant(const QVariant &in); }; #endif pushpin-1.41.0/src/core/processquit.cpp000066400000000000000000000065041504671364300201050ustar00rootroot00000000000000/* * Copyright (C) 2006 Justin Karneges * Copyright (C) 2017 Fanout, Inc. * Copyright (C) 2025 Fastly, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include "processquit.h" #include #include #include #include #include "socketnotifier.h" Q_GLOBAL_STATIC(QMutex, pq_mutex) static ProcessQuit *g_pq = nullptr; class ProcessQuit::Private { public: ProcessQuit *q; Connection activatedConnection; bool done; int sig_pipe[2]; std::unique_ptr sig_notifier; Private(ProcessQuit *_q) : q(_q) { done = false; if(pipe(sig_pipe) == -1) { // no support then return; } sig_notifier = std::make_unique(sig_pipe[0], SocketNotifier::Read); activatedConnection = sig_notifier->activated.connect(boost::bind(&Private::sig_activated, this, boost::placeholders::_1)); sig_notifier->clearReadiness(SocketNotifier::Read); unixWatchAdd(SIGINT); unixWatchAdd(SIGHUP); unixWatchAdd(SIGTERM); } ~Private() { unixWatchRemove(SIGINT); unixWatchRemove(SIGHUP); unixWatchRemove(SIGTERM); activatedConnection.disconnect(); sig_notifier.reset(); close(sig_pipe[0]); close(sig_pipe[1]); } static void unixHandler(int sig) { Q_UNUSED(sig); unsigned char c = 0; if(sig == SIGHUP) c = 1; if(::write(g_pq->d->sig_pipe[1], &c, 1) == -1) { // TODO: error handling? return; } } void unixWatchAdd(int sig) { struct sigaction sa; sigaction(sig, NULL, &sa); // if the signal is ignored, don't take it over. this is // recommended by the glibc manual if(sa.sa_handler == SIG_IGN) return; sigemptyset(&(sa.sa_mask)); sa.sa_flags = 0; sa.sa_handler = unixHandler; sigaction(sig, &sa, 0); } void unixWatchRemove(int sig) { struct sigaction sa; sigaction(sig, NULL, &sa); // ignored means we skipped it earlier, so we should // skip it again if(sa.sa_handler == SIG_IGN) return; sigemptyset(&(sa.sa_mask)); sa.sa_flags = 0; sa.sa_handler = SIG_DFL; sigaction(sig, &sa, 0); } void sig_activated(int) { sig_notifier->clearReadiness(SocketNotifier::Read); unsigned char c; if(::read(sig_pipe[0], &c, 1) == -1) { // TODO: error handling? return; } if(c == 1) // SIGHUP { q->hup(); return; } do_emit(); } private: void do_emit() { // only signal once if(!done) { done = true; q->quit(); } } }; ProcessQuit::ProcessQuit() { d = new Private(this); } ProcessQuit::~ProcessQuit() { delete d; } ProcessQuit *ProcessQuit::instance() { QMutexLocker locker(pq_mutex()); if(!g_pq) g_pq = new ProcessQuit; return g_pq; } void ProcessQuit::reset() { QMutexLocker locker(pq_mutex()); if(g_pq) g_pq->d->done = false; } void ProcessQuit::cleanup() { delete g_pq; g_pq = nullptr; } pushpin-1.41.0/src/core/processquit.h000066400000000000000000000065231504671364300175530ustar00rootroot00000000000000/* * Copyright (C) 2006 Justin Karneges * Copyright (C) 2017 Fanout, Inc. * Copyright (C) 2025 Fastly, Inc. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. * * $FANOUT_END_LICENSE$ */ #ifndef PROCESSQUIT_H #define PROCESSQUIT_H #include using Signal = boost::signals2::signal; using SignalInt = boost::signals2::signal; using Connection = boost::signals2::scoped_connection; /** \brief Listens for termination requests ProcessQuit listens for requests to terminate the application process. 
These are the signals SIGINT, SIGHUP, and SIGTERM. For GUI programs, ProcessQuit is not a substitute for QSessionManager. The only safe way to handle termination of a GUI program in the usual way is to use QSessionManager. However, ProcessQuit does give additional benefit to GUI programs that might be terminated unconventionally, so it can't hurt to support both. When a termination request is received, the application should exit gracefully, and generally without user interaction. Otherwise, it is at risk of being terminated outside of its control. Using ProcessQuit is easy, and it usually amounts to a single line: \code myapp.connect(ProcessQuit::instance(), SIGNAL(quit()), SLOT(do_quit())); \endcode Calling instance() returns a pointer to the global ProcessQuit instance, which will be created if necessary. The quit() signal is emitted when a request to terminate is received. The quit() signal is only emitted once, future termination requests are ignored. Call reset() to allow the quit() signal to be emitted again. */ class ProcessQuit { public: /** \brief Returns the global ProcessQuit instance If the global instance does not exist yet, it will be created, and the termination handlers will be installed. \sa cleanup */ static ProcessQuit *instance(); /** \brief Allows the quit() signal to be emitted again ProcessQuit only emits the quit() signal once, so that if a user repeatedly presses Ctrl-C or sends SIGTERM, your shutdown slot will not be called multiple times. This is normally the desired behavior, but if you are ignoring the termination request then you may want to allow future notifications. Calling this function will allow the quit() signal to be emitted again, if a new termination request arrives. \sa quit */ static void reset(); /** \brief Frees all resources used by ProcessQuit This function will free any resources used by ProcessQuit, including the global instance, and the termination handlers will be uninstalled (reverted to default). Future termination requests will cause the application to exit abruptly. \sa instance */ static void cleanup(); Signal quit; Signal hup; private: class Private; friend class Private; Private *d; ProcessQuit(); ~ProcessQuit(); }; #endif pushpin-1.41.0/src/core/qtcompat.h000066400000000000000000000021021504671364300170070ustar00rootroot00000000000000/* * Copyright (C) 2024 Fastly, Inc. * * This file is part of Pushpin. * * $FANOUT_BEGIN_LICENSE:APACHE2$ * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
* * $FANOUT_END_LICENSE$ */ #include #include inline QMetaType::Type typeId(const QVariant &v) { #if QT_VERSION >= 0x060000 return (QMetaType::Type)v.typeId(); #else return (QMetaType::Type)v.type(); #endif } inline bool canConvert(const QVariant &v, QMetaType::Type type) { #if QT_VERSION >= 0x060000 return v.canConvert(QMetaType(type)); #else return v.canConvert(type); #endif } pushpin-1.41.0/src/core/qzmqcontext.cpp000066400000000000000000000025411504671364300201160ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "qzmqcontext.h" #include #include "rust/bindings.h" using namespace ffi; namespace QZmq { Context::Context(int ioThreads) { context_ = wzmq_init(ioThreads); assert(context_); } Context::~Context() { wzmq_term(context_); } } pushpin-1.41.0/src/core/qzmqcontext.h000066400000000000000000000025111504671364300175600ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
*/ #ifndef QZMQCONTEXT_H #define QZMQCONTEXT_H namespace QZmq { class Context { public: Context(int ioThreads = 1); ~Context(); // the zmq context void *context() { return context_; } private: void *context_; }; } #endif pushpin-1.41.0/src/core/qzmqreprouter.cpp000066400000000000000000000045151504671364300204640ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "qzmqreprouter.h" #include "qzmqsocket.h" #include "qzmqreqmessage.h" namespace QZmq { class RepRouter::Private { public: RepRouter *q; std::unique_ptr sock; Connection mWConnection; Connection rrConnection; Private(RepRouter *_q) : q(_q) { sock = std::make_unique(Socket::Router); rrConnection = sock->readyRead.connect(boost::bind(&Private::sock_readyRead, this)); mWConnection = sock->messagesWritten.connect(boost::bind(&Private::sock_messagesWritten, this, boost::placeholders::_1)); } void sock_messagesWritten(int count) { q->messagesWritten(count); } void sock_readyRead() { q->readyRead(); } }; RepRouter::RepRouter() { d = std::make_unique(this); } RepRouter::~RepRouter() = default; void RepRouter::setShutdownWaitTime(int msecs) { d->sock->setShutdownWaitTime(msecs); } void RepRouter::connectToAddress(const QString &addr) { d->sock->connectToAddress(addr); } bool RepRouter::bind(const QString &addr) { return d->sock->bind(addr); } bool RepRouter::canRead() const { return d->sock->canRead(); } ReqMessage RepRouter::read() { return ReqMessage(d->sock->read()); } void RepRouter::write(const ReqMessage &message) { d->sock->write(message.toRawMessage()); } } pushpin-1.41.0/src/core/qzmqreprouter.h000066400000000000000000000035611504671364300201310ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. 
* * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #ifndef QZMQREPROUTER_H #define QZMQREPROUTER_H #include class QString; using Signal = boost::signals2::signal; using SignalInt = boost::signals2::signal; using Connection = boost::signals2::scoped_connection; namespace QZmq { class ReqMessage; class RepRouter { public: RepRouter(); ~RepRouter(); void setShutdownWaitTime(int msecs); void connectToAddress(const QString &addr); bool bind(const QString &addr); bool canRead() const; ReqMessage read(); void write(const ReqMessage &message); Signal readyRead; SignalInt messagesWritten; private: RepRouter(const RepRouter &) = delete; RepRouter &operator=(const RepRouter &) = delete; class Private; friend class Private; std::unique_ptr d; }; } #endif pushpin-1.41.0/src/core/qzmqreqmessage.h000066400000000000000000000042051504671364300202320ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #ifndef QZMQREQMESSAGE_H #define QZMQREQMESSAGE_H namespace QZmq { class ReqMessage { public: ReqMessage() { } ReqMessage(const QList &headers, const QList &content) : headers_(headers), content_(content) { } ReqMessage(const QList &rawMessage) { bool collectHeaders = true; foreach(const QByteArray &part, rawMessage) { if(part.isEmpty()) { collectHeaders = false; continue; } if(collectHeaders) headers_ += part; else content_ += part; } } bool isNull() const { return headers_.isEmpty() && content_.isEmpty(); } QList headers() const { return headers_; } QList content() const { return content_; } ReqMessage createReply(const QList &content) { return ReqMessage(headers_, content); } QList toRawMessage() const { QList out; out += headers_; out += QByteArray(); out += content_; return out; } private: QList headers_; QList content_; }; } #endif pushpin-1.41.0/src/core/qzmqsocket.cpp000066400000000000000000000354761504671364300177370ustar00rootroot00000000000000/* * Copyright (C) 2012-2020 Justin Karneges * Copyright (C) 2024-2025 Fastly, Inc. 
* * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "qzmqsocket.h" #include #include #include #include #include #include "rust/bindings.h" #include "qzmqcontext.h" #include "timer.h" #include "socketnotifier.h" using Connection = boost::signals2::scoped_connection; using namespace ffi; namespace QZmq { static int get_fd(void *sock) { int fd; size_t opt_len = sizeof(fd); int ret = wzmq_getsockopt(sock, WZMQ_FD, &fd, &opt_len); assert(ret == 0); return fd; } static void set_subscribe(void *sock, const char *data, int size) { size_t opt_len = size; int ret = wzmq_setsockopt(sock, WZMQ_SUBSCRIBE, data, opt_len); assert(ret == 0); } static void set_unsubscribe(void *sock, const char *data, int size) { size_t opt_len = size; wzmq_setsockopt(sock, WZMQ_UNSUBSCRIBE, data, opt_len); // note: we ignore errors, such as unsubscribing a nonexisting filter } static void set_linger(void *sock, int value) { size_t opt_len = sizeof(value); int ret = wzmq_setsockopt(sock, WZMQ_LINGER, &value, opt_len); assert(ret == 0); } static int get_identity(void *sock, char *data, int size) { size_t opt_len = size; int ret = wzmq_getsockopt(sock, WZMQ_IDENTITY, data, &opt_len); assert(ret == 0); return (int)opt_len; } static void set_identity(void *sock, const char *data, int size) { size_t opt_len = size; int ret = wzmq_setsockopt(sock, WZMQ_IDENTITY, data, opt_len); if(ret != 0) printf("%d\n", errno); assert(ret == 0); } #if WZMQ_VERSION_MAJOR >= 4 static void set_immediate(void *sock, bool on) { int v = on ? 1 : 0; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_IMMEDIATE, &v, opt_len); assert(ret == 0); } static void set_router_mandatory(void *sock, bool on) { int v = on ? 1 : 0; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_ROUTER_MANDATORY, &v, opt_len); assert(ret == 0); } static void set_probe_router(void *sock, bool on) { int v = on ? 1 : 0; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_PROBE_ROUTER, &v, opt_len); assert(ret == 0); } #else static void set_immediate(void *sock, bool on) { int v = on ? 1 : 0; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_DELAY_ATTACH_ON_CONNECT, &v, opt_len); assert(ret == 0); } #endif #if (WZMQ_VERSION_MAJOR >= 4) || ((WZMQ_VERSION_MAJOR >= 3) && (WZMQ_VERSION_MINOR >= 2)) #define USE_MSG_IO static bool get_rcvmore(void *sock) { int more; size_t opt_len = sizeof(more); int ret = wzmq_getsockopt(sock, WZMQ_RCVMORE, &more, &opt_len); assert(ret == 0); return more ? 
true : false; } static int get_events(void *sock) { while(true) { int events; size_t opt_len = sizeof(events); int ret = wzmq_getsockopt(sock, WZMQ_EVENTS, &events, &opt_len); if(ret == 0) { return (int)events; } assert(errno == EINTR); } } static int get_sndhwm(void *sock) { int hwm; size_t opt_len = sizeof(hwm); int ret = wzmq_getsockopt(sock, WZMQ_SNDHWM, &hwm, &opt_len); assert(ret == 0); return (int)hwm; } static void set_sndhwm(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_SNDHWM, &v, opt_len); assert(ret == 0); } static int get_rcvhwm(void *sock) { int hwm; size_t opt_len = sizeof(hwm); int ret = wzmq_getsockopt(sock, WZMQ_RCVHWM, &hwm, &opt_len); assert(ret == 0); return (int)hwm; } static void set_rcvhwm(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_RCVHWM, &v, opt_len); assert(ret == 0); } static int get_hwm(void *sock) { return get_sndhwm(sock); } static void set_hwm(void *sock, int value) { set_sndhwm(sock, value); set_rcvhwm(sock, value); } static void set_tcp_keepalive(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_TCP_KEEPALIVE, &v, opt_len); assert(ret == 0); } static void set_tcp_keepalive_idle(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_TCP_KEEPALIVE_IDLE, &v, opt_len); assert(ret == 0); } static void set_tcp_keepalive_cnt(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_TCP_KEEPALIVE_CNT, &v, opt_len); assert(ret == 0); } static void set_tcp_keepalive_intvl(void *sock, int value) { int v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_TCP_KEEPALIVE_INTVL, &v, opt_len); assert(ret == 0); } #else static bool get_rcvmore(void *sock) { qint64 more; size_t opt_len = sizeof(more); int ret = wzmq_getsockopt(sock, WZMQ_RCVMORE, &more, &opt_len); assert(ret == 0); return more ? 
true : false; } static int get_events(void *sock) { while(true) { quint32 events; size_t opt_len = sizeof(events); int ret = wzmq_getsockopt(sock, WZMQ_EVENTS, &events, &opt_len); if(ret == 0) { return (int)events; } assert(errno == EINTR); } } static int get_hwm(void *sock) { quint64 hwm; size_t opt_len = sizeof(hwm); int ret = wzmq_getsockopt(sock, WZMQ_HWM, &hwm, &opt_len); assert(ret == 0); return (int)hwm; } static void set_hwm(void *sock, int value) { quint64 v = value; size_t opt_len = sizeof(v); int ret = wzmq_setsockopt(sock, WZMQ_HWM, &v, opt_len); assert(ret == 0); } static int get_sndhwm(void *sock) { return get_hwm(sock); } static void set_sndhwm(void *sock, int value) { set_hwm(sock, value); } static int get_rcvhwm(void *sock) { return get_hwm(sock); } static void set_rcvhwm(void *sock, int value) { set_hwm(sock, value); } static void set_tcp_keepalive(void *sock, int value) { // not supported for this zmq version Q_UNUSED(sock); Q_UNUSED(on); } static void set_tcp_keepalive_idle(void *sock, int value) { // not supported for this zmq version Q_UNUSED(sock); Q_UNUSED(on); } static void set_tcp_keepalive_cnt(void *sock, int value) { // not supported for this zmq version Q_UNUSED(sock); Q_UNUSED(on); } static void set_tcp_keepalive_intvl(void *sock, int value) { // not supported for this zmq version Q_UNUSED(sock); Q_UNUSED(on); } #endif Q_GLOBAL_STATIC(QMutex, g_mutex) class Global { public: Context context; int refs; Global() : refs(0) { } }; static Global *global = 0; static Context *addGlobalContextRef() { QMutexLocker locker(g_mutex()); if(!global) global = new Global; ++(global->refs); return &(global->context); } static void removeGlobalContextRef() { QMutexLocker locker(g_mutex()); assert(global); assert(global->refs > 0); --(global->refs); if(global->refs == 0) { delete global; global = 0; } } class Socket::Private { public: Socket *q; bool usingGlobalContext; Context *context; void *sock; std::unique_ptr sn_read; bool canWrite, canRead; QList< QList > pendingWrites; int pendingWritten; std::unique_ptr updateTimer; Connection updateTimerConnection; bool pendingUpdate; int shutdownWaitTime; bool writeQueueEnabled; Private(Socket *_q, Socket::Type type, Context *_context) : q(_q), canWrite(false), canRead(false), pendingWritten(0), pendingUpdate(false), shutdownWaitTime(-1), writeQueueEnabled(true) { if(_context) { usingGlobalContext = false; context = _context; } else { usingGlobalContext = true; context = addGlobalContextRef(); } int ztype = 0; switch(type) { case Socket::Pair: ztype = WZMQ_PAIR; break; case Socket::Dealer: ztype = WZMQ_DEALER; break; case Socket::Router: ztype = WZMQ_ROUTER; break; case Socket::Req: ztype = WZMQ_REQ; break; case Socket::Rep: ztype = WZMQ_REP; break; case Socket::Push: ztype = WZMQ_PUSH; break; case Socket::Pull: ztype = WZMQ_PULL; break; case Socket::Pub: ztype = WZMQ_PUB; break; case Socket::Sub: ztype = WZMQ_SUB; break; default: assert(0); } sock = wzmq_socket(context->context(), ztype); assert(sock != NULL); sn_read = std::make_unique(get_fd(sock), SocketNotifier::Read); sn_read->activated.connect(boost::bind(&Private::sn_read_activated, this)); sn_read->setReadEnabled(true); updateTimer = std::make_unique(); updateTimerConnection = updateTimer->timeout.connect(boost::bind(&Private::update_timeout, this)); updateTimer->setSingleShot(true); // socket notifier starts out ready. 
attempt to read events if(processEvents()) { // if there are events, queue them for processing update(); } } ~Private() { sn_read.reset(); set_linger(sock, shutdownWaitTime); wzmq_close(sock); if(usingGlobalContext) removeGlobalContextRef(); } void update() { if(!pendingUpdate) { pendingUpdate = true; updateTimer->start(); } } QList read() { if(canRead) { QList out; bool ok = true; do { wzmq_msg_t msg; int ret = wzmq_msg_init(&msg); assert(ret == 0); #ifdef USE_MSG_IO ret = wzmq_msg_recv(&msg, sock, WZMQ_DONTWAIT); #else ret = wzmq_recv(sock, &msg, WZMQ_NOBLOCK); #endif if(ret < 0) { ret = wzmq_msg_close(&msg); assert(ret == 0); ok = false; break; } QByteArray buf((const char *)wzmq_msg_data(&msg), wzmq_msg_size(&msg)); ret = wzmq_msg_close(&msg); assert(ret == 0); out += buf; } while(get_rcvmore(sock)); processEvents(); if((canWrite && !pendingWrites.isEmpty()) || canRead) update(); if(ok) return out; else return QList(); } else return QList(); } void write(const QList &message) { assert(!message.isEmpty()); if(writeQueueEnabled) { pendingWrites += message; if(canWrite) update(); } else { if(zmqWrite(message)) { ++pendingWritten; } processEvents(); if(pendingWritten > 0 || canRead) update(); } } // return true if flags changed bool processEvents() { int flags = get_events(sock); sn_read->clearReadiness(SocketNotifier::Read); bool canWriteOld = canWrite; bool canReadOld = canRead; canWrite = (flags & WZMQ_POLLOUT); canRead = (flags & WZMQ_POLLIN); return (canWrite != canWriteOld || canRead != canReadOld); } bool zmqWrite(const QList &message) { for(int n = 0; n < message.count(); ++n) { const QByteArray &buf = message[n]; wzmq_msg_t msg; int ret = wzmq_msg_init_size(&msg, buf.size()); assert(ret == 0); memcpy(wzmq_msg_data(&msg), buf.data(), buf.size()); #ifdef USE_MSG_IO ret = wzmq_msg_send(&msg, sock, WZMQ_DONTWAIT | (n + 1 < message.count() ? WZMQ_SNDMORE : 0)); #else ret = wzmq_send(sock, &msg, WZMQ_NOBLOCK | (n + 1 < message.count() ? 
WZMQ_SNDMORE : 0)); #endif if(ret < 0) { ret = wzmq_msg_close(&msg); assert(ret == 0); return false; } ret = wzmq_msg_close(&msg); assert(ret == 0); } return true; } void tryWrite() { while(canWrite && !pendingWrites.isEmpty()) { // whether this write succeeds or not, we assume we // can't write afterwards canWrite = false; if(zmqWrite(pendingWrites.first())) { pendingWrites.removeFirst(); ++pendingWritten; } processEvents(); } } void doUpdate() { tryWrite(); if(canRead) { std::weak_ptr self = q->d; q->readyRead(); if(self.expired()) return; } if(pendingWritten > 0) { int count = pendingWritten; pendingWritten = 0; q->messagesWritten(count); } } void update_timeout() { pendingUpdate = false; doUpdate(); } void sn_read_activated() { if(!processEvents()) return; if(pendingUpdate) { pendingUpdate = false; updateTimer->stop(); } doUpdate(); } }; Socket::Socket(Type type) { d = std::make_shared(this, type, nullptr); } Socket::Socket(Type type, Context *context) { d = std::make_shared(this, type, context); } Socket::~Socket() = default; void Socket::setShutdownWaitTime(int msecs) { d->shutdownWaitTime = msecs; } void Socket::setWriteQueueEnabled(bool enable) { d->writeQueueEnabled = enable; } void Socket::subscribe(const QByteArray &filter) { set_subscribe(d->sock, filter.data(), filter.size()); } void Socket::unsubscribe(const QByteArray &filter) { set_unsubscribe(d->sock, filter.data(), filter.size()); } QByteArray Socket::identity() const { QByteArray buf(255, 0); buf.resize(get_identity(d->sock, buf.data(), buf.size())); return buf; } void Socket::setIdentity(const QByteArray &id) { set_identity(d->sock, id.data(), id.size()); } int Socket::hwm() const { return get_hwm(d->sock); } void Socket::setHwm(int hwm) { set_hwm(d->sock, hwm); } int Socket::sendHwm() const { return get_sndhwm(d->sock); } int Socket::receiveHwm() const { return get_rcvhwm(d->sock); } void Socket::setSendHwm(int hwm) { set_sndhwm(d->sock, hwm); } void Socket::setReceiveHwm(int hwm) { set_rcvhwm(d->sock, hwm); } void Socket::setImmediateEnabled(bool on) { set_immediate(d->sock, on); } void Socket::setRouterMandatoryEnabled(bool on) { set_router_mandatory(d->sock, on); } void Socket::setProbeRouterEnabled(bool on) { set_probe_router(d->sock, on); } void Socket::setTcpKeepAliveEnabled(bool on) { set_tcp_keepalive(d->sock, on ? 1 : 0); } void Socket::setTcpKeepAliveParameters(int idle, int count, int interval) { set_tcp_keepalive_idle(d->sock, idle); set_tcp_keepalive_cnt(d->sock, count); set_tcp_keepalive_intvl(d->sock, interval); } void Socket::connectToAddress(const QString &addr) { int ret = wzmq_connect(d->sock, addr.toUtf8().data()); assert(ret == 0); } bool Socket::bind(const QString &addr) { int ret = wzmq_bind(d->sock, addr.toUtf8().data()); if(ret != 0) return false; return true; } bool Socket::canRead() const { return d->canRead; } bool Socket::canWriteImmediately() const { return d->canWrite; } QList Socket::read() { return d->read(); } void Socket::write(const QList &message) { d->write(message); } } pushpin-1.41.0/src/core/qzmqsocket.h000066400000000000000000000067241504671364300173760ustar00rootroot00000000000000/* * Copyright (C) 2012-2015 Justin Karneges * Copyright (C) 2024-2025 Fastly, Inc. 
* * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #ifndef QZMQSOCKET_H #define QZMQSOCKET_H #include #include #include class QString; using Signal = boost::signals2::signal; using SignalInt = boost::signals2::signal; using Connection = boost::signals2::scoped_connection; namespace QZmq { class Context; class Socket { public: enum Type { Pair, Dealer, Router, Req, Rep, Push, Pull, Pub, Sub }; Socket(Type type); Socket(Type type, Context *context); ~Socket(); // 0 means drop queue and don't block, -1 means infinite (default = -1) void setShutdownWaitTime(int msecs); // if enabled, messages are queued internally until the socket is able // to accept them. the messagesWritten signal is emitted once writes // have succeeded. otherwise, messages are passed directly to // zmq_send and dropped if they can't be written. default enabled. // disabling the queue is good for socket types where the HWM has a // drop policy. enabling the queue is good when the HWM has a // blocking policy. void setWriteQueueEnabled(bool enable); void subscribe(const QByteArray &filter); void unsubscribe(const QByteArray &filter); QByteArray identity() const; void setIdentity(const QByteArray &id); // deprecated, zmq 2.x int hwm() const; void setHwm(int hwm); int sendHwm() const; int receiveHwm() const; void setSendHwm(int hwm); void setReceiveHwm(int hwm); void setImmediateEnabled(bool on); void setRouterMandatoryEnabled(bool on); void setProbeRouterEnabled(bool on); void setTcpKeepAliveEnabled(bool on); void setTcpKeepAliveParameters(int idle = -1, int count = -1, int interval = -1); void connectToAddress(const QString &addr); bool bind(const QString &addr); bool canRead() const; // returns true if this object believes the next write to zmq will // succeed immediately. note that it starts out false until the // value is discovered. also note that the write could still end up // needing to be queued, if the conditions change in between. bool canWriteImmediately() const; QList read(); void write(const QList &message); Signal readyRead; SignalInt messagesWritten; private: Socket(const Socket &) = delete; Socket &operator=(const Socket &) = delete; class Private; friend class Private; std::shared_ptr d; }; } #endif pushpin-1.41.0/src/core/qzmqvalve.cpp000066400000000000000000000052411504671364300175470ustar00rootroot00000000000000/* * Copyright (C) 2012-2020 Justin Karneges * Copyright (C) 2025 Fastly, Inc. 
* * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. * IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #include "qzmqvalve.h" #include "qzmqsocket.h" #include "defercall.h" namespace QZmq { class Valve::Private { public: Valve *q; QZmq::Socket *sock; bool isOpen; bool pendingRead; int maxReadsPerEvent; boost::signals2::scoped_connection rrConnection; DeferCall deferCall; Private(Valve *_q) : q(_q), sock(0), isOpen(false), pendingRead(false), maxReadsPerEvent(100) { } void setup(QZmq::Socket *_sock) { sock = _sock; rrConnection = sock->readyRead.connect(boost::bind(&Private::sock_readyRead, this)); } void queueRead() { if(pendingRead) return; pendingRead = true; deferCall.defer([=] { queuedRead(); }); } void tryRead() { std::weak_ptr self = q->d; int count = 0; while(isOpen && sock->canRead()) { if(count >= maxReadsPerEvent) { queueRead(); return; } QList msg = sock->read(); if(!msg.isEmpty()) { q->readyRead(msg); if(self.expired()) return; } ++count; } } void sock_readyRead() { if(pendingRead) return; tryRead(); } void queuedRead() { pendingRead = false; tryRead(); } }; Valve::Valve(QZmq::Socket *sock) { d = std::make_shared(this); d->setup(sock); } Valve::~Valve() = default; bool Valve::isOpen() const { return d->isOpen; } void Valve::setMaxReadsPerEvent(int max) { d->maxReadsPerEvent = max; } void Valve::open() { if(!d->isOpen) { d->isOpen = true; if(!d->pendingRead && d->sock->canRead()) d->queueRead(); } } void Valve::close() { d->isOpen = false; } } pushpin-1.41.0/src/core/qzmqvalve.h000066400000000000000000000032461504671364300172170ustar00rootroot00000000000000/* * Copyright (C) 2012 Justin Karneges * Copyright (C) 2025 Fastly, Inc. * * Permission is hereby granted, free of charge, to any person obtaining a * copy of this software and associated documentation files (the * "Software"), to deal in the Software without restriction, including * without limitation the rights to use, copy, modify, merge, publish, * distribute, sublicense, and/or sell copies of the Software, and to * permit persons to whom the Software is furnished to do so, subject to * the following conditions: * * The above copyright notice and this permission notice shall be included * in all copies or substantial portions of the Software. * * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS * OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF * MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
* IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY * CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, * TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE * SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. */ #ifndef QZMQVALVE_H #define QZMQVALVE_H #include #include #include using SignalList = boost::signals2::signal&)>; using Connection = boost::signals2::scoped_connection; namespace QZmq { class Socket; class Valve { public: Valve(QZmq::Socket *sock); ~Valve(); bool isOpen() const; void setMaxReadsPerEvent(int max); void open(); void close(); SignalList readyRead; private: class Private; friend class Private; std::shared_ptr d; }; } #endif pushpin-1.41.0/src/core/reactor.rs000066400000000000000000001110141504671364300170160ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ use crate::core::arena; use crate::core::event; use crate::core::event::ReadinessExt; use crate::core::timer::TimerWheel; use slab::Slab; use std::cell::{Cell, RefCell}; use std::cmp; use std::io; use std::os::unix::io::RawFd; use std::rc::{Rc, Weak}; use std::task::Waker; use std::time::{Duration, Instant}; const TICK_DURATION_MS: u64 = 10; const EXPIRE_MAX: usize = 100; thread_local! { static REACTOR: RefCell>> = const { RefCell::new(None) }; } fn duration_to_ticks_round_down(d: Duration) -> u64 { (d.as_millis() / (TICK_DURATION_MS as u128)) as u64 } fn duration_to_ticks_round_up(d: Duration) -> u64 { d.as_millis().div_ceil(TICK_DURATION_MS as u128) as u64 } fn ticks_to_duration(t: u64) -> Duration { Duration::from_millis(t * TICK_DURATION_MS) } enum WakerInterest { Single(Waker, mio::Interest), Separate(Waker, Waker), } impl WakerInterest { fn interest(&self) -> mio::Interest { match self { Self::Single(_, interest) => *interest, Self::Separate(_, _) => mio::Interest::READABLE | mio::Interest::WRITABLE, } } fn change(self, waker: &Waker, interest: mio::Interest) -> Self { match self { Self::Single(current_waker, current_interest) => { if (interest.is_readable() && interest.is_writable()) || current_interest == interest { // all interest or interest unchanged. stay using a // single waker let waker = if current_waker.will_wake(waker) { // keep the current waker current_waker } else { // switch to the new waker waker.clone() }; Self::Single(waker, interest) } else { assert!(interest.is_readable() != interest.is_writable()); // one interest was specified when we had at least the // opposite interest. 
switch to separate match (interest.is_readable(), interest.is_writable()) { (true, false) => Self::Separate(waker.clone(), current_waker), (false, true) => Self::Separate(current_waker, waker.clone()), _ => unreachable!(), } } } Self::Separate(read_waker, write_waker) => { match (interest.is_readable(), interest.is_writable()) { (true, true) => { // if multiple interests on one waker, switch to single let waker = if read_waker.will_wake(waker) { read_waker } else if write_waker.will_wake(waker) { write_waker } else { waker.clone() }; Self::Single(waker, interest) } (true, false) => { let read_waker = if read_waker.will_wake(waker) { // keep the current waker read_waker } else { // switch to the new waker waker.clone() }; Self::Separate(read_waker, write_waker) } (false, true) => { let write_waker = if write_waker.will_wake(waker) { // keep the current waker write_waker } else { // switch to the new waker waker.clone() }; Self::Separate(read_waker, write_waker) } (false, false) => unreachable!(), // interest always has a value } } } } fn merge(self, other: Self) -> Self { match self { Self::Single(waker, interest) => { if (interest.is_readable() && interest.is_writable()) || interest == other.interest() { // there is already a single waker of both interests or // of one interest that is the same interest as the // other. leave alone Self::Single(waker, interest) } else { assert!(interest.is_readable() != interest.is_writable()); // there is a single waker of one interest, and the other // has at least the opposite interest. switch to separate match (interest.is_readable(), interest.is_writable()) { (true, false) => { let other_waker = match other { Self::Single(waker, _) => waker, Self::Separate(_, waker) => waker, }; Self::Separate(waker, other_waker) } (false, true) => { let other_waker = match other { Self::Single(waker, _) => waker, Self::Separate(waker, _) => waker, }; Self::Separate(other_waker, waker) } _ => unreachable!(), } } } separate => { // there are already separate wakers for both interests. 
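// Editor's note (illustrative, not part of the original source): with three
// distinct wakers w1, w2 and w3, change() above evolves the state like so:
//
//   (no waker)  set_waker(w1, READABLE)            -> Single(w1, READABLE)
//   Single      set_waker(w2, WRITABLE)            -> Separate(read: w1, write: w2)
//   Separate    set_waker(w3, READABLE | WRITABLE) -> Single(w3, READABLE | WRITABLE)
//
// merge(), by contrast, favors the wakers that are already registered: a
// direction that already has a waker keeps it, and the incoming waker only
// fills directions that were missing.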
// leave alone separate } } } fn clear_interest(self, interest: mio::Interest) -> Option { match self { Self::Single(waker, cur) => cur.remove(interest).map(|i| Self::Single(waker, i)), Self::Separate(read_waker, write_waker) => { match (interest.is_readable(), interest.is_writable()) { (true, true) => None, // clear all (true, false) => Some(Self::Single(write_waker, mio::Interest::WRITABLE)), (false, true) => Some(Self::Single(read_waker, mio::Interest::READABLE)), (false, false) => unreachable!(), // interest always has a value } } } } fn wake(self, readiness: mio::Interest) -> Option { match self { Self::Single(waker, interest) => { if (interest.is_readable() && readiness.is_readable()) || (interest.is_writable() && readiness.is_writable()) { waker.wake(); None } else { Some(Self::Single(waker, interest)) } } Self::Separate(read_waker, write_waker) => { match (readiness.is_readable(), readiness.is_writable()) { (true, true) => { read_waker.wake(); write_waker.wake(); None } (true, false) => { read_waker.wake(); Some(Self::Single(write_waker, mio::Interest::WRITABLE)) } (false, true) => { write_waker.wake(); Some(Self::Single(read_waker, mio::Interest::READABLE)) } (false, false) => unreachable!(), // interest always has a value } } } } fn wake_by_ref(&self, readiness: mio::Interest) { match self { Self::Single(waker, interest) => { if (interest.is_readable() && readiness.is_readable()) || (interest.is_writable() && readiness.is_writable()) { waker.wake_by_ref(); } } Self::Separate(read_waker, write_waker) => { if readiness.is_readable() { read_waker.wake_by_ref(); } if readiness.is_writable() { write_waker.wake_by_ref(); } } } } } pub struct Registration { reactor: Weak, key: usize, } impl Registration { pub fn reactor(&self) -> Reactor { let reactor = self.reactor.upgrade().expect("reactor is gone"); Reactor { inner: reactor } } pub fn set_waker_persistent(&self, enabled: bool) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.waker_persistent = enabled; } pub fn readiness(&self) -> event::Readiness { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &*reactor.registrations.borrow(); let reg_data = ®istrations[self.key]; reg_data.readiness } pub fn set_readiness(&self, readiness: event::Readiness) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.readiness = readiness; } pub fn clear_readiness(&self, readiness: mio::Interest) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; if let Some(cur) = reg_data.readiness.take() { reg_data.readiness = cur.remove(readiness); } } pub fn is_ready(&self) -> bool { self.readiness().is_some() } pub fn set_ready(&self, ready: bool) { let readiness = if ready { Some(mio::Interest::READABLE) } else { None }; self.set_readiness(readiness); } pub fn set_waker(&self, waker: &Waker, interest: mio::Interest) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; match reg_data.waker.take() { Some(wi) => reg_data.waker = Some(wi.change(waker, interest)), None => reg_data.waker = Some(WakerInterest::Single(waker.clone(), 
interest)), } } pub fn clear_waker(&self) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; reg_data.waker = None; } pub fn clear_waker_interest(&self, interest: mio::Interest) { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; if let Some(wi) = reg_data.waker.take() { reg_data.waker = wi.clear_interest(interest); } } pub fn deregister_io(&self, source: &mut S) -> Result<(), io::Error> { let reactor = self.reactor.upgrade().expect("reactor is gone"); let poll = &reactor.poll.borrow(); poll.deregister(source) } pub fn deregister_custom(&self, handle: &event::Registration) -> Result<(), io::Error> { let reactor = self.reactor.upgrade().expect("reactor is gone"); let poll = &reactor.poll.borrow(); poll.deregister_custom(handle) } pub fn deregister_custom_local( &self, handle: &event::LocalRegistration, ) -> Result<(), io::Error> { let reactor = self.reactor.upgrade().expect("reactor is gone"); let poll = &reactor.poll.borrow(); poll.deregister_custom_local(handle) } pub fn pull_from_budget(&self) -> bool { let reactor = self.reactor.upgrade().expect("reactor is gone"); let mut registrations = reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; if reg_data.waker.is_none() { panic!("pull_from_budget requires a waker to be set"); } let ok = self.pull_from_budget_inner(); if !ok { let wi = reg_data.waker.take().unwrap(); let persistent = reg_data.waker_persistent; drop(registrations); let wi_remaining = if persistent { wi.wake_by_ref(mio::Interest::READABLE | mio::Interest::WRITABLE); Some(wi) } else { wi.wake(mio::Interest::READABLE | mio::Interest::WRITABLE) }; if let Some(wi_remaining) = wi_remaining { let mut registrations = reactor.registrations.borrow_mut(); if let Some(event_reg) = registrations.get_mut(self.key) { match event_reg.waker.take() { Some(wi) => event_reg.waker = Some(wi.merge(wi_remaining)), None => event_reg.waker = Some(wi_remaining), } } } } ok } pub fn pull_from_budget_with_waker(&self, waker: &Waker) -> bool { let ok = self.pull_from_budget_inner(); if !ok { waker.wake_by_ref(); } ok } fn pull_from_budget_inner(&self) -> bool { let reactor = self.reactor.upgrade().expect("reactor is gone"); let budget = &mut *reactor.budget.borrow_mut(); match budget { Some(budget) => { if *budget > 0 { *budget -= 1; true } else { false } } None => true, } } pub fn reregister_timer(&self, expires: Instant) -> Result<(), io::Error> { let reactor = self.reactor.upgrade().expect("reactor is gone"); let registrations = &mut *reactor.registrations.borrow_mut(); let reg_data = &mut registrations[self.key]; let timer = &mut *reactor.timer.borrow_mut(); if let Some(timer_key) = reg_data.timer_key { timer.wheel.remove(timer_key); } let expires_ticks = duration_to_ticks_round_up(expires - timer.start); let timer_key = match timer.wheel.add(expires_ticks, self.key) { Ok(timer_key) => timer_key, Err(_) => return Err(io::Error::from(io::ErrorKind::Other)), }; reg_data.timer_key = Some(timer_key); Ok(()) } } impl Drop for Registration { fn drop(&mut self) { if let Some(reactor) = self.reactor.upgrade() { let registrations = &mut *reactor.registrations.borrow_mut(); if let Some(timer_key) = registrations[self.key].timer_key { let timer = &mut *reactor.timer.borrow_mut(); timer.wheel.remove(timer_key); } registrations.remove(self.key); 
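// Editor's sketch (hypothetical, not part of the original source): the
// cleanup above is what makes the following usage pattern safe -- the slab
// slot and any pending timer are released as soon as the Registration is
// dropped:
//
//   let reactor = Reactor::new(16);
//   let reg = reactor.register_timer(Instant::now() + Duration::from_millis(100))?;
//   reg.set_waker(&waker, mio::Interest::READABLE);
//   reactor.poll(None)?; // wakes `waker` once the timer expires
//   drop(reg);           // frees the slot; no further wakeups for this key
//
// where `waker` is any std::task::Waker supplied by the caller.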
} } } struct RegistrationData { readiness: event::Readiness, waker: Option, timer_key: Option, waker_persistent: bool, } struct TimerData { wheel: TimerWheel, start: Instant, current_ticks: u64, } struct ReactorData { registrations: RefCell>, poll: RefCell, timer: RefCell, budget: RefCell>, } #[derive(Clone)] pub struct Reactor { inner: Rc, } impl Reactor { pub fn new(registrations_max: usize) -> Self { Self::new_with_time(registrations_max, Instant::now()) } pub fn new_with_time(registrations_max: usize, start_time: Instant) -> Self { let timer_data = TimerData { wheel: TimerWheel::new(registrations_max), start: start_time, current_ticks: 0, }; let inner = Rc::new(ReactorData { registrations: RefCell::new(Slab::with_capacity(registrations_max)), poll: RefCell::new(event::Poller::new(registrations_max).unwrap()), timer: RefCell::new(timer_data), budget: RefCell::new(None), }); REACTOR.with(|r| { if r.borrow().is_some() { panic!("thread already has a Reactor"); } r.replace(Some(Rc::downgrade(&inner))); }); Self { inner } } pub fn register_io( &self, source: &mut S, interest: mio::Interest, ) -> Result where S: mio::event::Source + ?Sized, { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, waker_persistent: false, }); if let Err(e) = self .inner .poll .borrow() .register(source, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_custom( &self, handle: &event::Registration, interest: mio::Interest, ) -> Result { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, waker_persistent: false, }); if let Err(e) = self.inner .poll .borrow() .register_custom(handle, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_custom_local( &self, handle: &event::LocalRegistration, interest: mio::Interest, ) -> Result { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, waker_persistent: false, }); if let Err(e) = self.inner .poll .borrow() .register_custom_local(handle, mio::Token(key + 1), interest) { registrations.remove(key); return Err(e); } Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } pub fn register_timer(&self, expires: Instant) -> Result { let registrations = &mut *self.inner.registrations.borrow_mut(); if registrations.len() == registrations.capacity() { return Err(io::Error::from(io::ErrorKind::WriteZero)); } let key = registrations.insert(RegistrationData { readiness: None, waker: None, timer_key: None, waker_persistent: false, }); let timer = &mut *self.inner.timer.borrow_mut(); let expires_ticks = duration_to_ticks_round_up(expires - timer.start); let timer_key = match timer.wheel.add(expires_ticks, key) { Ok(timer_key) => timer_key, Err(_) => { registrations.remove(key); return 
Err(io::Error::from(io::ErrorKind::Other)); } }; registrations[key].timer_key = Some(timer_key); Ok(Registration { reactor: Rc::downgrade(&self.inner), key, }) } // we advance time after polling. this way, Reactor::now() is accurate // during task processing. we assume the actual time doesn't change much // between task processing and the next poll pub fn poll(&self, timeout: Option) -> Result<(), io::Error> { self.poll_for_events(self.next_timeout(timeout))?; self.advance_time(Instant::now()); self.process_events(); Ok(()) } // return the timeout that would have been used for a blocking poll pub fn poll_nonblocking(&self, current_time: Instant) -> Result, io::Error> { let timeout = self.next_timeout(None); self.poll_for_events(Some(Duration::from_millis(0)))?; self.advance_time(current_time); self.process_events(); Ok(timeout) } pub fn now(&self) -> Instant { let timer = &*self.inner.timer.borrow(); timer.start + ticks_to_duration(timer.current_ticks) } pub fn set_budget(&self, budget: Option) { *self.inner.budget.borrow_mut() = budget; } pub fn current() -> Option { REACTOR.with(|r| { (*r.borrow_mut()).as_mut().map(|inner| Self { inner: inner.upgrade().unwrap(), }) }) } pub fn local_registration_memory(&self) -> Rc> { self.inner.poll.borrow().local_registration_memory().clone() } fn next_timeout(&self, user_timeout: Option) -> Option { let timer = &mut *self.inner.timer.borrow_mut(); let timer_timeout = timer.wheel.timeout().map(ticks_to_duration); match user_timeout { Some(user_timeout) => Some(match timer_timeout { Some(timer_timeout) => cmp::min(user_timeout, timer_timeout), None => user_timeout, }), None => timer_timeout, } } fn poll_for_events(&self, timeout: Option) -> Result<(), io::Error> { let poll = &mut *self.inner.poll.borrow_mut(); poll.poll(timeout) } fn advance_time(&self, current_time: Instant) { let timer = &mut *self.inner.timer.borrow_mut(); timer.current_ticks = duration_to_ticks_round_down(current_time - timer.start); timer.wheel.update(timer.current_ticks); } fn process_events(&self) { let poll = &mut *self.inner.poll.borrow_mut(); for event in poll.iter_events() { let key = usize::from(event.token()); assert!(key > 0); let key = key - 1; let mut registrations = self.inner.registrations.borrow_mut(); if let Some(event_reg) = registrations.get_mut(key) { let event_readiness = event.readiness(); let (became_readable, became_writable) = { let prev_readiness = event_reg.readiness; event_reg.readiness.merge(event_readiness); ( !prev_readiness.contains_any(mio::Interest::READABLE) && event_reg.readiness.contains_any(mio::Interest::READABLE), !prev_readiness.contains_any(mio::Interest::WRITABLE) && event_reg.readiness.contains_any(mio::Interest::WRITABLE), ) }; if became_readable || became_writable { if let Some(wi) = event_reg.waker.take() { let interest = wi.interest(); if (became_readable && interest.is_readable()) || (became_writable && interest.is_writable()) { let persistent = event_reg.waker_persistent; drop(registrations); let wi_remaining = if persistent { wi.wake_by_ref(event_readiness); Some(wi) } else { wi.wake(event_readiness) }; if let Some(wi_remaining) = wi_remaining { let mut registrations = self.inner.registrations.borrow_mut(); if let Some(event_reg) = registrations.get_mut(key) { match event_reg.waker.take() { Some(wi) => event_reg.waker = Some(wi.merge(wi_remaining)), None => event_reg.waker = Some(wi_remaining), } } } } } } } } let timer = &mut *self.inner.timer.borrow_mut(); let mut expire_count = 0; while let Some((_, key)) = 
timer.wheel.take_expired() { let mut registrations = self.inner.registrations.borrow_mut(); if let Some(event_reg) = registrations.get_mut(key) { event_reg.readiness = Some(mio::Interest::READABLE); event_reg.timer_key = None; if let Some(wi) = event_reg.waker.take() { let persistent = event_reg.waker_persistent; drop(registrations); let wi_remaining = if persistent { wi.wake_by_ref(mio::Interest::READABLE); Some(wi) } else { wi.wake(mio::Interest::READABLE) }; if let Some(wi_remaining) = wi_remaining { let mut registrations = self.inner.registrations.borrow_mut(); if let Some(event_reg) = registrations.get_mut(key) { match event_reg.waker.take() { Some(wi) => event_reg.waker = Some(wi.merge(wi_remaining)), None => event_reg.waker = Some(wi_remaining), } } } } } expire_count += 1; if expire_count >= EXPIRE_MAX { break; } } } } impl Drop for Reactor { fn drop(&mut self) { REACTOR.with(|r| { if Rc::strong_count(&self.inner) == 1 { r.replace(None); } }); } } struct IoEventedInner { registration: Registration, io: S, } pub struct IoEvented { inner: Option>, } impl IoEvented { pub fn new(mut io: S, interest: mio::Interest, reactor: &Reactor) -> Result { let registration = reactor.register_io(&mut io, interest)?; Ok(Self { inner: Some(IoEventedInner { registration, io }), }) } pub fn registration(&self) -> &Registration { &self.inner.as_ref().unwrap().registration } pub fn io(&self) -> &S { &self.inner.as_ref().unwrap().io } // return registration and io object, without deregistering it pub fn into_parts(mut self) -> (Registration, S) { let inner = self.inner.take().unwrap(); (inner.registration, inner.io) } pub fn into_inner(mut self) -> S { let mut inner = self.inner.take().unwrap(); inner.registration.deregister_io(&mut inner.io).unwrap(); inner.io } } impl Drop for IoEvented { fn drop(&mut self) { if let Some(mut inner) = self.inner.take() { inner.registration.deregister_io(&mut inner.io).unwrap(); } } } pub struct FdEvented { registration: Registration, fd: RawFd, } impl FdEvented { pub fn new(fd: RawFd, interest: mio::Interest, reactor: &Reactor) -> Result { let registration = reactor.register_io(&mut mio::unix::SourceFd(&fd), interest)?; Ok(Self { registration, fd }) } pub fn registration(&self) -> &Registration { &self.registration } pub fn fd(&self) -> &RawFd { &self.fd } } impl Drop for FdEvented { fn drop(&mut self) { self.registration() .deregister_io(&mut mio::unix::SourceFd(&self.fd)) .unwrap(); } } pub struct CustomEvented { registration: Registration, } impl CustomEvented { pub fn new( event_reg: &event::Registration, interest: mio::Interest, reactor: &Reactor, ) -> Result { let registration = reactor.register_custom(event_reg, interest)?; Ok(Self { registration }) } pub fn new_local( event_reg: &event::LocalRegistration, interest: mio::Interest, reactor: &Reactor, ) -> Result { let registration = reactor.register_custom_local(event_reg, interest)?; Ok(Self { registration }) } pub fn registration(&self) -> &Registration { &self.registration } } pub struct TimerEvented { registration: Registration, expires: Cell, } impl TimerEvented { pub fn new(expires: Instant, reactor: &Reactor) -> Result { let registration = reactor.register_timer(expires)?; Ok(Self { registration, expires: Cell::new(expires), }) } pub fn registration(&self) -> &Registration { &self.registration } pub fn expires(&self) -> Instant { self.expires.get() } pub fn set_expires(&self, expires: Instant) -> Result<(), io::Error> { if self.expires.get() == expires { // no change return Ok(()); } 
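// Editor's note (illustrative, not part of the original source): expirations
// live on a coarse tick wheel (TICK_DURATION_MS = 10), so the expiry passed
// to reregister_timer() below is rounded up to the next tick relative to the
// reactor's start time. For example:
//
//   duration_to_ticks_round_up(Duration::from_millis(25))   == 3  // fires once 30 ms have elapsed
//   duration_to_ticks_round_down(Duration::from_millis(25)) == 2  // used when advancing current_ticks
//
// so a timer can fire up to one tick late, but never early.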
self.expires.set(expires); self.registration.reregister_timer(expires) } } #[cfg(test)] mod tests { use super::*; use crate::core::waker; use std::cell::Cell; use std::mem; use std::os::unix::io::AsRawFd; use std::rc::Rc; use std::thread; struct TestWaker { waked: Cell, } impl TestWaker { fn new() -> Self { Self { waked: Cell::new(false), } } fn into_std(self: Rc) -> Waker { waker::into_std(self) } fn was_waked(&self) -> bool { self.waked.get() } } impl waker::RcWake for TestWaker { fn wake(self: Rc) { self.waked.set(true); } } #[test] fn test_reactor_io() { let reactor = Reactor::new(1); let addr = "127.0.0.1:0".parse().unwrap(); let listener = mio::net::TcpListener::bind(addr).unwrap(); let evented = IoEvented::new(listener, mio::Interest::READABLE, &reactor).unwrap(); let addr = evented.io().local_addr().unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { std::net::TcpStream::connect(addr).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap(); assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_fd() { let reactor = Reactor::new(1); let addr: std::net::SocketAddr = "127.0.0.1:0".parse().unwrap(); let listener = std::net::TcpListener::bind(addr).unwrap(); let evented = FdEvented::new(listener.as_raw_fd(), mio::Interest::READABLE, &reactor).unwrap(); let addr = listener.local_addr().unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { std::net::TcpStream::connect(addr).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap(); assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_custom() { let reactor = Reactor::new(1); let (reg, sr) = event::Registration::new(); let evented = CustomEvented::new(®, mio::Interest::READABLE, &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); let thread = thread::spawn(move || { sr.set_readiness(mio::Interest::READABLE).unwrap(); }); assert_eq!(waker.was_waked(), false); reactor.poll(None).unwrap(); assert_eq!(waker.was_waked(), true); thread.join().unwrap(); } #[test] fn test_reactor_timer() { let now = Instant::now(); let reactor = Reactor::new_with_time(1, now); let evented = TimerEvented::new(now + Duration::from_millis(100), &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(waker.was_waked(), false); assert_eq!(reactor.now(), now); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(20)) .unwrap(); assert_eq!(timeout, Some(Duration::from_millis(100))); assert_eq!(reactor.now(), now + Duration::from_millis(20)); assert_eq!(waker.was_waked(), false); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(40)) .unwrap(); assert_eq!(timeout, Some(Duration::from_millis(80))); assert_eq!(reactor.now(), now + Duration::from_millis(40)); assert_eq!(waker.was_waked(), false); let timeout = reactor .poll_nonblocking(now + Duration::from_millis(100)) .unwrap(); assert_eq!(timeout, Some(Duration::from_millis(60))); assert_eq!(waker.was_waked(), true); assert_eq!(reactor.now(), now + Duration::from_millis(100)); } #[test] fn test_reactor_current() { 
assert!(Reactor::current().is_none()); let reactor = Reactor::new(1); let current = Reactor::current().unwrap(); mem::drop(reactor); assert!(Reactor::current().is_some()); mem::drop(current); assert!(Reactor::current().is_none()); } #[test] fn test_reactor_budget() { let reactor = Reactor::new(1); let (reg, _) = event::Registration::new(); let evented = CustomEvented::new(®, mio::Interest::READABLE, &reactor).unwrap(); let waker = Rc::new(TestWaker::new()); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(evented.registration().pull_from_budget(), true); assert_eq!(waker.was_waked(), false); reactor.set_budget(Some(0)); assert_eq!(evented.registration().pull_from_budget(), false); assert_eq!(waker.was_waked(), true); let waker = Rc::new(TestWaker::new()); reactor.set_budget(Some(1)); evented .registration() .set_waker(&waker.clone().into_std(), mio::Interest::READABLE); assert_eq!(evented.registration().pull_from_budget(), true); assert_eq!(waker.was_waked(), false); assert_eq!(evented.registration().pull_from_budget(), false); assert_eq!(waker.was_waked(), true); } } pushpin-1.41.0/src/core/readwrite.h000066400000000000000000000022421504671364300171520ustar00rootroot00000000000000/* * Copyright (C) 2025 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. */ #ifndef READWRITE_H #define READWRITE_H #include #include class ReadWrite { public: virtual ~ReadWrite() = default; // size < 0 means default read size // returns buffer of bytes read. null buffer means error. empty means end virtual QByteArray read(int size = -1) = 0; // returns amount accepted, or -1 for error virtual int write(const QByteArray &buf) = 0; // returns errno of latest operation virtual int errorCondition() const = 0; boost::signals2::signal readReady; boost::signals2::signal writeReady; }; #endif pushpin-1.41.0/src/core/select.rs000066400000000000000000000120331504671364300166370ustar00rootroot00000000000000/* * Copyright (C) 2020-2023 Fanout, Inc. * Copyright (C) 2024 Fastly, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. * You may obtain a copy of the License at * * http://www.apache.org/licenses/LICENSE-2.0 * * Unless required by applicable law or agreed to in writing, software * distributed under the License is distributed on an "AS IS" BASIS, * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. * See the License for the specific language governing permissions and * limitations under the License. 
 */

use crate::core::shuffle::shuffle;
use paste::paste;
use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};

fn range_unordered(dest: &mut [usize]) -> &[usize] {
    for (index, v) in dest.iter_mut().enumerate() {
        *v = index;
    }

    shuffle(dest);

    dest
}

fn map_poll<F, W, V>(cx: &mut Context, fut: &mut F, wrap_func: W) -> Poll<V>
where
    F: Future + Unpin,
    W: FnOnce(F::Output) -> V,
{
    match Pin::new(fut).poll(cx) {
        Poll::Ready(v) => Poll::Ready(wrap_func(v)),
        Poll::Pending => Poll::Pending,
    }
}

macro_rules! declare_select {
    ($count: literal, ( $($num:literal),* )) => {
        paste! {
            pub enum []<$([], )*> { $( []: [], )* }

            impl<$([], )*> Future for []<$([]::Output, )*>;

            fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
                let mut indexes = [0; $count];

                for i in range_unordered(&mut indexes) {
                    let s = &mut *self;

                    let p = match i + 1 {
                        $(
                            $num => map_poll(cx, &mut s.[], |v| []<$([], )*> where $( []: Future + Unpin, )* { [
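// Editor's sketch (not part of the original file): the macro body above is
// garbled and truncated in this dump, so the following is a minimal
// hand-written approximation of what one instantiation of this select
// pattern could look like for two futures. The names Select2,
// Select2Futures and select_2 are assumptions for illustration only and are
// not confirmed by the original source.

pub enum Select2<O1, O2> {
    R1(O1),
    R2(O2),
}

pub struct Select2Futures<F1, F2> {
    f1: F1,
    f2: F2,
}

impl<F1, F2> Future for Select2Futures<F1, F2>
where
    F1: Future + Unpin,
    F2: Future + Unpin,
{
    type Output = Select2<F1::Output, F2::Output>;

    fn poll(mut self: Pin<&mut Self>, cx: &mut Context) -> Poll<Self::Output> {
        let mut indexes = [0; 2];

        // poll in random order so that neither future can starve the other
        for i in range_unordered(&mut indexes) {
            let s = &mut *self;

            let p = match i + 1 {
                1 => map_poll(cx, &mut s.f1, Select2::R1),
                2 => map_poll(cx, &mut s.f2, Select2::R2),
                _ => unreachable!(),
            };

            if p.is_ready() {
                return p;
            }
        }

        Poll::Pending
    }
}

// helper in the style suggested by the garbled signature above; awaiting the
// returned future yields whichever input future completes first
pub fn select_2<F1, F2>(f1: F1, f2: F2) -> Select2Futures<F1, F2>
where
    F1: Future + Unpin,
    F2: Future + Unpin,
{
    Select2Futures { f1, f2 }
}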