pax_global_header00006660000000000000000000000064151225641060014514gustar00rootroot0000000000000052 comment=767b9c469fac4af7c5495d95b174b82b84ef6835 sctp-1.9.0/000077500000000000000000000000001512256410600124745ustar00rootroot00000000000000sctp-1.9.0/.github/000077500000000000000000000000001512256410600140345ustar00rootroot00000000000000sctp-1.9.0/.github/.gitignore000066400000000000000000000001561512256410600160260ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT .goassets sctp-1.9.0/.github/fetch-scripts.sh000077500000000000000000000016001512256410600171460ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT set -eu SCRIPT_PATH="$(realpath "$(dirname "$0")")" GOASSETS_PATH="${SCRIPT_PATH}/.goassets" GOASSETS_REF=${GOASSETS_REF:-master} if [ -d "${GOASSETS_PATH}" ]; then if ! git -C "${GOASSETS_PATH}" diff --exit-code; then echo "${GOASSETS_PATH} has uncommitted changes" >&2 exit 1 fi git -C "${GOASSETS_PATH}" fetch origin git -C "${GOASSETS_PATH}" checkout ${GOASSETS_REF} git -C "${GOASSETS_PATH}" reset --hard origin/${GOASSETS_REF} else git clone -b ${GOASSETS_REF} https://github.com/pion/.goassets.git "${GOASSETS_PATH}" fi sctp-1.9.0/.github/install-hooks.sh000077500000000000000000000011221512256410600171560ustar00rootroot00000000000000#!/bin/sh # # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT SCRIPT_PATH="$(realpath "$(dirname "$0")")" . ${SCRIPT_PATH}/fetch-scripts.sh cp "${GOASSETS_PATH}/hooks/commit-msg.sh" "${SCRIPT_PATH}/../.git/hooks/commit-msg" cp "${GOASSETS_PATH}/hooks/pre-commit.sh" "${SCRIPT_PATH}/../.git/hooks/pre-commit" sctp-1.9.0/.github/workflows/000077500000000000000000000000001512256410600160715ustar00rootroot00000000000000sctp-1.9.0/.github/workflows/api.yaml000066400000000000000000000011141512256410600175230ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: API on: pull_request: jobs: check: uses: pion/.goassets/.github/workflows/api.reusable.yml@master sctp-1.9.0/.github/workflows/codeql-analysis.yml000066400000000000000000000013201512256410600217000ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: CodeQL on: workflow_dispatch: schedule: - cron: '23 5 * * 0' pull_request: branches: - master paths: - '**.go' jobs: analyze: uses: pion/.goassets/.github/workflows/codeql-analysis.reusable.yml@master sctp-1.9.0/.github/workflows/fuzz.yaml000066400000000000000000000013421512256410600177530ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Fuzz on: push: branches: - master schedule: - cron: "0 */8 * * *" jobs: fuzz: uses: pion/.goassets/.github/workflows/fuzz.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version fuzz-time: "60s" sctp-1.9.0/.github/workflows/lint.yaml000066400000000000000000000011151512256410600177210ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Lint on: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/lint.reusable.yml@master sctp-1.9.0/.github/workflows/release.yml000066400000000000000000000012501512256410600202320ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Release on: push: tags: - 'v*' jobs: release: uses: pion/.goassets/.github/workflows/release.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version sctp-1.9.0/.github/workflows/renovate-go-sum-fix.yaml000066400000000000000000000012671512256410600225770ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Fix go.sum on: push: branches: - renovate/* jobs: fix: uses: pion/.goassets/.github/workflows/renovate-go-sum-fix.reusable.yml@master secrets: token: ${{ secrets.PIONBOT_PRIVATE_KEY }} sctp-1.9.0/.github/workflows/reuse.yml000066400000000000000000000011511512256410600177350ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. 
# If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: REUSE Compliance Check on: push: pull_request: jobs: lint: uses: pion/.goassets/.github/workflows/reuse.reusable.yml@master sctp-1.9.0/.github/workflows/test.yaml000066400000000000000000000033271512256410600177410ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Test on: push: branches: - master pull_request: jobs: test: uses: pion/.goassets/.github/workflows/test.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} secrets: inherit test-i386: uses: pion/.goassets/.github/workflows/test-i386.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-windows: uses: pion/.goassets/.github/workflows/test-windows.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-macos: uses: pion/.goassets/.github/workflows/test-macos.reusable.yml@master strategy: matrix: go: ["1.25", "1.24"] # auto-update/supported-go-version-list fail-fast: false with: go-version: ${{ matrix.go }} test-wasm: uses: pion/.goassets/.github/workflows/test-wasm.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version secrets: inherit sctp-1.9.0/.github/workflows/tidy-check.yaml000066400000000000000000000013021512256410600207750ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # If this repository should have package specific CI config, # remove the repository name from .goassets/.github/workflows/assets-sync.yml. # # If you want to update the shared CI config, send a PR to # https://github.com/pion/.goassets instead of this repository. 
# # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT name: Go mod tidy on: pull_request: push: branches: - master jobs: tidy: uses: pion/.goassets/.github/workflows/tidy-check.reusable.yml@master with: go-version: "1.25" # auto-update/latest-go-version sctp-1.9.0/.gitignore000066400000000000000000000006321512256410600144650ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT ### JetBrains IDE ### ##################### .idea/ ### Emacs Temporary Files ### ############################# *~ ### Folders ### ############### bin/ vendor/ node_modules/ ### Files ### ############# *.ivf *.ogg tags cover.out *.sw[poe] *.wasm examples/sfu-ws/cert.pem examples/sfu-ws/key.pem wasm_exec.js sctp-1.9.0/.golangci.yml000066400000000000000000000202661512256410600150660ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT version: "2" linters: enable: - asciicheck # Simple linter to check that your code does not contain non-ASCII identifiers - bidichk # Checks for dangerous unicode character sequences - bodyclose # checks whether HTTP response body is closed successfully - containedctx # containedctx is a linter that detects struct contained context.Context field - contextcheck # check the function whether use a non-inherited context - cyclop # checks function and package cyclomatic complexity - decorder # check declaration order and count of types, constants, variables and functions - dogsled # Checks assignments with too many blank identifiers (e.g. x, _, _, _, := f()) - dupl # Tool for code clone detection - durationcheck # check for two durations multiplied together - err113 # Golang linter to check the errors handling expressions - errcheck # Errcheck is a program for checking for unchecked errors in go programs. These unchecked errors can be critical bugs in some cases - errchkjson # Checks types passed to the json encoding functions. Reports unsupported types and optionally reports occations, where the check for the returned error can be omitted. - errname # Checks that sentinel errors are prefixed with the `Err` and error types are suffixed with the `Error`. - errorlint # errorlint is a linter for that can be used to find code that will cause problems with the error wrapping scheme introduced in Go 1.13. - exhaustive # check exhaustiveness of enum switch statements - forbidigo # Forbids identifiers - forcetypeassert # finds forced type assertions - gochecknoglobals # Checks that no globals are present in Go code - gocognit # Computes and checks the cognitive complexity of functions - goconst # Finds repeated strings that could be replaced by a constant - gocritic # The most opinionated Go source code linter - gocyclo # Computes and checks the cyclomatic complexity of functions - godot # Check if comments end in a period - godox # Tool for detection of FIXME, TODO and other comment keywords - goheader # Checks is file header matches to pattern - gomoddirectives # Manage the use of 'replace', 'retract', and 'excludes' directives in go.mod. - goprintffuncname # Checks that printf-like functions are named with `f` at the end - gosec # Inspects source code for security problems - govet # Vet examines Go source code and reports suspicious constructs, such as Printf calls whose arguments do not align with the format string - grouper # An analyzer to analyze expression groups. 
- importas # Enforces consistent import aliases - ineffassign # Detects when assignments to existing variables are not used - lll # Reports long lines - maintidx # maintidx measures the maintainability index of each function. - makezero # Finds slice declarations with non-zero initial length - misspell # Finds commonly misspelled English words in comments - nakedret # Finds naked returns in functions greater than a specified function length - nestif # Reports deeply nested if statements - nilerr # Finds the code that returns nil even if it checks that the error is not nil. - nilnil # Checks that there is no simultaneous return of `nil` error and an invalid value. - nlreturn # nlreturn checks for a new line before return and branch statements to increase code clarity - noctx # noctx finds sending http request without context.Context - predeclared # find code that shadows one of Go's predeclared identifiers - revive # golint replacement, finds style mistakes - staticcheck # Staticcheck is a go vet on steroids, applying a ton of static analysis checks - tagliatelle # Checks the struct tags. - thelper # thelper detects golang test helpers without t.Helper() call and checks the consistency of test helpers - unconvert # Remove unnecessary type conversions - unparam # Reports unused function parameters - unused # Checks Go code for unused constants, variables, functions and types - varnamelen # checks that the length of a variable's name matches its scope - wastedassign # wastedassign finds wasted assignment statements - whitespace # Tool for detection of leading and trailing whitespace disable: - depguard # Go linter that checks if package imports are in a list of acceptable packages - funlen # Tool for detection of long functions - gochecknoinits # Checks that no init functions are present in Go code - gomodguard # Allow and block list linter for direct Go module dependencies. This is different from depguard where there are different block types for example version constraints and module recommendations. - interfacebloat # A linter that checks length of interface. - ireturn # Accept Interfaces, Return Concrete Types - mnd # An analyzer to detect magic numbers - nolintlint # Reports ill-formed or insufficient nolint directives - paralleltest # paralleltest detects missing usage of t.Parallel() method in your Go test - prealloc # Finds slice declarations that could potentially be preallocated - promlinter # Check Prometheus metrics naming via promlint - rowserrcheck # checks whether Err of rows is checked successfully - sqlclosecheck # Checks that sql.Rows and sql.Stmt are closed. - testpackage # linter that makes you use a separate _test package - tparallel # tparallel detects inappropriate usage of t.Parallel() method in your Go test codes - wrapcheck # Checks that errors returned from external packages are wrapped - wsl # Whitespace Linter - Forces you to use empty lines! settings: staticcheck: checks: - all - -QF1008 # "could remove embedded field", to keep it explicit! - -QF1003 # "could use tagged switch on enum", Cases conflicts with exhaustive! 
exhaustive: default-signifies-exhaustive: true forbidigo: forbid: - pattern: ^fmt.Print(f|ln)?$ - pattern: ^log.(Panic|Fatal|Print)(f|ln)?$ - pattern: ^os.Exit$ - pattern: ^panic$ - pattern: ^print(ln)?$ - pattern: ^testing.T.(Error|Errorf|Fatal|Fatalf|Fail|FailNow)$ pkg: ^testing$ msg: use testify/assert instead analyze-types: true gomodguard: blocked: modules: - github.com/pkg/errors: recommendations: - errors govet: enable: - shadow revive: rules: # Prefer 'any' type alias over 'interface{}' for Go 1.18+ compatibility - name: use-any severity: warning disabled: false misspell: locale: US varnamelen: max-distance: 12 min-name-length: 2 ignore-type-assert-ok: true ignore-map-index-ok: true ignore-chan-recv-ok: true ignore-decls: - i int - n int - w io.Writer - r io.Reader - b []byte exclusions: generated: lax rules: - linters: - forbidigo - gocognit path: (examples|main\.go) - linters: - gocognit path: _test\.go - linters: - forbidigo path: cmd formatters: enable: - gci # Gci control golang package import order and make it always deterministic. - gofmt # Gofmt checks whether code was gofmt-ed. By default this tool runs with -s option to check for code simplification - gofumpt # Gofumpt checks whether code was gofumpt-ed. - goimports # Goimports does everything that gofmt does. Additionally it checks unused imports exclusions: generated: lax sctp-1.9.0/.goreleaser.yml000066400000000000000000000001711512256410600154240ustar00rootroot00000000000000# SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT builds: - skip: true sctp-1.9.0/.reuse/000077500000000000000000000000001512256410600136755ustar00rootroot00000000000000sctp-1.9.0/.reuse/dep5000066400000000000000000000011141512256410600144520ustar00rootroot00000000000000Format: https://www.debian.org/doc/packaging-manuals/copyright-format/1.0/ Upstream-Name: Pion Source: https://github.com/pion/ Files: README.md DESIGN.md **/README.md AUTHORS.txt renovate.json go.mod go.sum **/go.mod **/go.sum .eslintrc.json package.json examples.json sfu-ws/flutter/.gitignore sfu-ws/flutter/pubspec.yaml c-data-channels/webrtc.h examples/examples.json yarn.lock Copyright: 2023 The Pion community License: MIT Files: testdata/seed/* testdata/fuzz/* **/testdata/fuzz/* api/*.txt Copyright: 2023 The Pion community License: CC0-1.0 sctp-1.9.0/LICENSE000066400000000000000000000021051512256410600134770ustar00rootroot00000000000000MIT License Copyright (c) 2023 The Pion community Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
sctp-1.9.0/LICENSES/000077500000000000000000000000001512256410600137015ustar00rootroot00000000000000sctp-1.9.0/LICENSES/MIT.txt000066400000000000000000000020661512256410600150770ustar00rootroot00000000000000MIT License Copyright (c) Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions: The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software. THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. sctp-1.9.0/README.md000066400000000000000000000115351512256410600137600ustar00rootroot00000000000000


# Pion SCTP

A Go implementation of SCTP
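
Below is a minimal usage sketch, not an official example from this README. It assumes the SCTP packets ride on an already-established `net.Conn` (in WebRTC this is the DTLS connection; `net.Pipe` stands in for it here), and the stream-level calls (`AcceptStream`, `OpenStream` with `PayloadTypeWebRTCString`, `Stream.Read`/`Write`) come from this package's stream API, whose exact signatures may differ between releases.

```go
package main

import (
	"context"
	"log"
	"net"
	"time"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

func main() {
	// net.Pipe stands in for the real transport (normally a DTLS conn).
	clientConn, serverConn := net.Pipe()
	loggerFactory := logging.NewDefaultLoggerFactory()

	// Server side: accept the association, then the first inbound stream.
	go func() {
		assoc, err := sctp.Server(sctp.Config{NetConn: serverConn, LoggerFactory: loggerFactory})
		if err != nil {
			log.Println("server:", err)
			return
		}
		stream, err := assoc.AcceptStream()
		if err != nil {
			log.Println("server:", err)
			return
		}
		buf := make([]byte, 1024)
		n, _ := stream.Read(buf)
		log.Printf("server received: %q", buf[:n])
	}()

	// Client side: associate, open stream 1 and send a single message.
	assoc, err := sctp.Client(sctp.Config{NetConn: clientConn, LoggerFactory: loggerFactory})
	if err != nil {
		log.Fatal(err)
	}
	stream, err := assoc.OpenStream(1, sctp.PayloadTypeWebRTCString)
	if err != nil {
		log.Fatal(err)
	}
	if _, err := stream.Write([]byte("hello")); err != nil {
		log.Fatal(err)
	}

	// Gracefully shut the association down, bounded by a timeout.
	ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
	defer cancel()
	if err := assoc.Shutdown(ctx); err != nil {
		log.Println(err)
	}
}
```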



### Implemented

- [RFC 6525](https://www.rfc-editor.org/rfc/rfc6525.html) — Stream Control Transmission Protocol (SCTP) Stream Reconfiguration
- [RFC 3758](https://www.rfc-editor.org/rfc/rfc3758.html) — Stream Control Transmission Protocol (SCTP) Partial Reliability Extension
- [RFC 5061](https://www.rfc-editor.org/rfc/rfc5061.html) — Stream Control Transmission Protocol (SCTP) Dynamic Address Reconfiguration
- [RFC 4895](https://www.rfc-editor.org/rfc/rfc4895.html) — Authenticated Chunks for the Stream Control Transmission Protocol (SCTP)
- [RFC 1982](https://www.rfc-editor.org/rfc/rfc1982.html) — Serial Number Arithmetic

### Partial implementations

Pion only implements the subset of RFC 4960 that is required for WebRTC.

- [RFC 4960](https://www.rfc-editor.org/rfc/rfc4960.html) — Stream Control Transmission Protocol [Obsoleted by RFC 9260, see below]
- [RFC 2960](https://www.rfc-editor.org/rfc/rfc2960.html) — Stream Control Transmission Protocol [Obsoleted by 4960, above]

The update to [RFC 9260](https://www.rfc-editor.org/rfc/rfc9260) — Stream Control Transmission Protocol is currently a [work in progress](https://github.com/pion/sctp/issues/402).

### Potential future implementations

Ideally, we would like to add the following features as part of a [v2 refresh](https://github.com/pion/sctp/issues/314):

Feature | Reference | Progress
--- | --- | ---
RACK (tail loss probing) | [Paper](https://icnp20.cs.ucr.edu/proceedings/nipaa/RACK%20for%20SCTP.pdf), [Comment](https://github.com/pion/sctp/issues/206#issuecomment-968265853) | [In review](https://github.com/pion/sctp/pull/390)
Adaptive burst mitigation | [Paper, see section 5A](https://icnp20.cs.ucr.edu/proceedings/nipaa/RACK%20for%20SCTP.pdf) | [In review](https://github.com/pion/sctp/pull/394)
Update to RFC 9260 | [Parent issue](https://github.com/pion/sctp/issues/402) | [In progress](https://github.com/pion/sctp/issues/402)
Implement RFC 8260 | [Issue](https://github.com/pion/sctp/issues/435) | In progress (no PR available yet)
Blocking writes | [1](https://github.com/pion/sctp/issues/77), [2](https://github.com/pion/sctp/issues/357) | [Potentially in progress](https://github.com/pion/sctp/issues/357#issuecomment-3382050767)
association.listener (and better docs) | [1](https://github.com/pion/sctp/issues/74), [2](https://github.com/pion/sctp/issues/173) | Not started, [blocked by above](https://github.com/pion/sctp/issues/74#issuecomment-545550714)

RFCs of interest:

- [RFC 9438](https://datatracker.ietf.org/doc/rfc9438/) as it addresses the low utilization problem of [RFC 4960](https://www.rfc-editor.org/rfc/rfc4960.html) in fast long-distance networks as mentioned [here](https://github.com/pion/sctp/issues/218#issuecomment-3329690797).

### Roadmap

The library is used as a part of our WebRTC implementation. Please refer to that [roadmap](https://github.com/pion/webrtc/issues/9) to track our major milestones.

### Community

Pion has an active community on the [Discord](https://discord.gg/PngbdqpFbt).

Follow the [Pion Bluesky](https://bsky.app/profile/pion.ly) or [Pion Twitter](https://twitter.com/_pion) for project updates and important WebRTC news.

We are always looking to support **your projects**. Please reach out if you have something to build!
If you need commercial support or don't want to use public methods you can contact us at [team@pion.ly](mailto:team@pion.ly)

### Contributing

Check out the [contributing wiki](https://github.com/pion/webrtc/wiki/Contributing) to join the group of amazing people making this project possible

### License

MIT License - see [LICENSE](LICENSE) for full text

sctp-1.9.0/ack_timer.go

// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package sctp

import (
	"math"
	"sync"
	"time"
)

const (
	ackInterval time.Duration = 200 * time.Millisecond
)

// ackTimerObserver is the interface to an ack timer observer.
type ackTimerObserver interface {
	onAckTimeout()
}

type ackTimerState uint8

const (
	ackTimerStopped ackTimerState = iota
	ackTimerStarted
	ackTimerClosed
)

// ackTimer provides the delayed acknowledgement (SACK) timer described in RFC 4960 Sec 6.2.
type ackTimer struct {
	timer    *time.Timer
	observer ackTimerObserver
	mutex    sync.Mutex
	state    ackTimerState
	pending  uint8
}

// newAckTimer creates a new acknowledgement timer used to enable delayed ack.
func newAckTimer(observer ackTimerObserver) *ackTimer {
	t := &ackTimer{observer: observer}
	t.timer = time.AfterFunc(math.MaxInt64, t.timeout)
	t.timer.Stop()

	return t
}

func (t *ackTimer) timeout() {
	t.mutex.Lock()
	if t.pending--; t.pending == 0 && t.state == ackTimerStarted {
		t.state = ackTimerStopped
		defer t.observer.onAckTimeout()
	}
	t.mutex.Unlock()
}

// start starts the timer.
func (t *ackTimer) start() bool {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	// this timer is already closed or already running
	if t.state != ackTimerStopped {
		return false
	}

	t.state = ackTimerStarted
	t.pending++
	t.timer.Reset(ackInterval)

	return true
}

// stop stops the timer. Unlike close(), the timer remains usable and
// can be started again with a subsequent start() call.
func (t *ackTimer) stop() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.state == ackTimerStarted {
		if t.timer.Stop() {
			t.pending--
		}
		t.state = ackTimerStopped
	}
}

// close closes the timer. This is similar to stop() but a subsequent start() call
// will fail (the timer is no longer usable).
func (t *ackTimer) close() {
	t.mutex.Lock()
	defer t.mutex.Unlock()

	if t.state == ackTimerStarted && t.timer.Stop() {
		t.pending--
	}

	t.state = ackTimerClosed
}

// isRunning tests if the timer is running.
// Debug purpose only.
func (t *ackTimer) isRunning() bool { t.mutex.Lock() defer t.mutex.Unlock() return t.state == ackTimerStarted } sctp-1.9.0/ack_timer_test.go000066400000000000000000000046251512256410600160270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) type onAckTO func() type testAckTimerObserver struct { onAckTO onAckTO } func (o *testAckTimerObserver) onAckTimeout() { o.onAckTO() } func TestAckTimer(t *testing.T) { t.Run("start and close", func(t *testing.T) { var nCbs uint32 rt := newAckTimer(&testAckTimerObserver{ onAckTO: func() { t.Log("ack timed out") atomic.AddUint32(&nCbs, 1) }, }) for i := 0; i < 2; i++ { // should start ok ok := rt.start() assert.True(t, ok, "start() should succeed") assert.True(t, rt.isRunning(), "should be running") // subsequent start is a noop ok = rt.start() assert.False(t, ok, "start() should NOT succeed once closed") assert.True(t, rt.isRunning(), "should be running") // Sleep more than 2 * 200msec interval to test if it times out only once time.Sleep(ackInterval*2 + 50*time.Millisecond) assert.Equalf(t, uint32(1), atomic.LoadUint32(&nCbs), "should be called once (actual: %d)", atomic.LoadUint32(&nCbs)) atomic.StoreUint32(&nCbs, 0) } // should close ok rt.close() assert.False(t, rt.isRunning(), "should not be running") // once closed, it cannot start ok := rt.start() assert.False(t, ok, "start() should NOT succeed once closed") assert.False(t, rt.isRunning(), "should not be running") }) t.Run("start and stop", func(t *testing.T) { var nCbs uint32 rt := newAckTimer(&testAckTimerObserver{ onAckTO: func() { t.Log("ack timed out") atomic.AddUint32(&nCbs, 1) }, }) for i := 0; i < 2; i++ { // should start ok ok := rt.start() assert.True(t, ok, "start() should succeed") assert.True(t, rt.isRunning(), "should be running") // stop immedidately rt.stop() assert.False(t, rt.isRunning(), "should not be running") } // Sleep more than 200msec of interval to test if it never times out time.Sleep(ackInterval + 50*time.Millisecond) assert.Equalf(t, uint32(0), atomic.LoadUint32(&nCbs), "should not be timed out (actual: %d)", atomic.LoadUint32(&nCbs)) // can start again ok := rt.start() assert.True(t, ok, "start() should succeed again") assert.True(t, rt.isRunning(), "should be running") // should close ok rt.close() assert.False(t, rt.isRunning(), "should not be running") }) } sctp-1.9.0/association.go000066400000000000000000003363431512256410600153530ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "bytes" "context" "encoding/binary" "errors" "fmt" "io" "math" "net" "sync" "sync/atomic" "time" "github.com/pion/logging" "github.com/pion/randutil" "github.com/pion/transport/v3/deadline" ) // Port 5000 shows up in examples for SDPs used by WebRTC. Since this implementation // assumes it will be used by DTLS over UDP, the port is only meaningful for de-multiplexing // but more-so verification. // Example usage: https://www.rfc-editor.org/rfc/rfc8841.html#section-13.1-2 const defaultSCTPSrcDstPort = 5000 // Use global random generator to properly seed by crypto grade random. var globalMathRandomGenerator = randutil.NewMathRandomGenerator() // nolint:gochecknoglobals // Association errors. 
var ( ErrChunk = errors.New("abort chunk, with following errors") ErrShutdownNonEstablished = errors.New("shutdown called in non-established state") ErrAssociationClosedBeforeConn = errors.New("association closed before connecting") ErrAssociationClosed = errors.New("association closed") ErrSilentlyDiscard = errors.New("silently discard") ErrInitNotStoredToSend = errors.New("the init not stored to send") ErrCookieEchoNotStoredToSend = errors.New("cookieEcho not stored to send") ErrSCTPPacketSourcePortZero = errors.New("sctp packet must not have a source port of 0") ErrSCTPPacketDestinationPortZero = errors.New("sctp packet must not have a destination port of 0") ErrInitChunkBundled = errors.New("init chunk must not be bundled with any other chunk") ErrInitChunkVerifyTagNotZero = errors.New( "init chunk expects a verification tag of 0 on the packet when out-of-the-blue", ) ErrHandleInitState = errors.New("todo: handle Init when in state") ErrInitAckNoCookie = errors.New("no cookie in InitAck") ErrInflightQueueTSNPop = errors.New("unable to be popped from inflight queue TSN") ErrTSNRequestNotExist = errors.New("requested non-existent TSN") ErrResetPacketInStateNotExist = errors.New("sending reset packet in non-established state") ErrParamterType = errors.New("unexpected parameter type") ErrPayloadDataStateNotExist = errors.New("sending payload data in non-established state") ErrChunkTypeUnhandled = errors.New("unhandled chunk type") ErrHandshakeInitAck = errors.New("handshake failed (INIT ACK)") ErrHandshakeCookieEcho = errors.New("handshake failed (COOKIE ECHO)") ErrTooManyReconfigRequests = errors.New("too many outstanding reconfig requests") ) const ( receiveMTU uint32 = 8192 // MTU for inbound packet (from DTLS) initialMTU uint32 = 1228 // initial MTU for outgoing packets (to DTLS) initialRecvBufSize uint32 = 1024 * 1024 commonHeaderSize uint32 = 12 dataChunkHeaderSize uint32 = 16 defaultMaxMessageSize uint32 = 65536 ) // association state enums. const ( closed uint32 = iota cookieWait cookieEchoed established shutdownAckSent shutdownPending shutdownReceived shutdownSent ) // retransmission timer IDs. const ( timerT1Init int = iota timerT1Cookie timerT2Shutdown timerT3RTX timerReconfig ) // ack mode (for testing). const ( ackModeNormal int = iota ackModeNoDelay ackModeAlwaysDelay ) // ack transmission state. const ( ackStateIdle int = iota // ack timer is off ackStateImmediate // will send ack immediately ackStateDelay // ack timer is on (ack is being delayed) ) // other constants. const ( acceptChSize = 16 // avgChunkSize is an estimate of the average chunk size. There is no theory behind // this estimate. avgChunkSize = 500 // minTSNOffset is the minimum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size // see getMaxTSNOffset. minTSNOffset = 2000 // maxTSNOffset is the maximum offset over the cummulative TSN that we will enqueue // irrespective of the receive buffer size // see getMaxTSNOffset. maxTSNOffset = 40000 // maxReconfigRequests is the maximum number of reconfig requests we will keep outstanding. maxReconfigRequests = 1000 // TLR Adaptive burst mitigation uses quarter-MTU units. // 1 MTU == 4 units, 0.25 MTU == 1 unit. tlrUnitsPerMTU = 4 // Default burst limits. tlrBurstDefaultFirstRTT = 16 // 4.0 MTU tlrBurstDefaultLaterRTT = 8 // 2.0 MTU // Minimum burst limits. tlrBurstMinFirstRTT = 8 // 2.0 MTU tlrBurstMinLaterRTT = 5 // 1.25 MTU // Adaptation steps. 
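	// In quarter-MTU units: the default first-RTT budget of 16 units equals 4 full
	// MTUs and the later-RTT default of 8 units equals 2 MTUs, so stepping the
	// first-RTT budget down by 4 units removes one MTU per step, while the
	// later-RTT budget steps down a quarter MTU at a time.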
tlrBurstStepDownFirstRTT = 4 // reduce by 1.0 MTU tlrBurstStepDownLaterRTT = 1 // reduce by 0.25 MTU tlrGoodOpsResetThreshold = 16 ) func getAssociationStateString(assoc uint32) string { switch assoc { case closed: return "Closed" case cookieWait: return "CookieWait" case cookieEchoed: return "CookieEchoed" case established: return "Established" case shutdownPending: return "ShutdownPending" case shutdownSent: return "ShutdownSent" case shutdownReceived: return "ShutdownReceived" case shutdownAckSent: return "ShutdownAckSent" default: return fmt.Sprintf("Invalid association state %d", assoc) } } // Association represents an SCTP association // 13.2. Parameters Necessary per Association (i.e., the TCB) // // Peer : Tag value to be sent in every packet and is received // Verification: in the INIT or INIT ACK chunk. // Tag : // State : A state variable indicating what state the association // : is in, i.e., COOKIE-WAIT, COOKIE-ECHOED, ESTABLISHED, // : SHUTDOWN-PENDING, SHUTDOWN-SENT, SHUTDOWN-RECEIVED, // : SHUTDOWN-ACK-SENT. // // Note: No "CLOSED" state is illustrated since if a // association is "CLOSED" its TCB SHOULD be removed. // Note: By nature of an Association being constructed with one net.Conn, // it is not a multi-home supporting implementation of SCTP. type Association struct { bytesReceived uint64 bytesSent uint64 lock sync.RWMutex netConn net.Conn peerVerificationTag uint32 myVerificationTag uint32 state uint32 initialTSN uint32 myNextTSN uint32 // nextTSN minTSN2MeasureRTT uint32 // for RTT measurement willSendForwardTSN bool willRetransmitFast bool willRetransmitReconfig bool willSendShutdown bool willSendShutdownAck bool willSendShutdownComplete bool willSendAbort bool willSendAbortCause errorCause // Reconfig myNextRSN uint32 reconfigs map[uint32]*chunkReconfig reconfigRequests map[uint32]*paramOutgoingResetRequest // Non-RFC internal data sourcePort uint16 destinationPort uint16 myMaxNumInboundStreams uint16 myMaxNumOutboundStreams uint16 myCookie *paramStateCookie payloadQueue *receivePayloadQueue inflightQueue *payloadQueue pendingQueue *pendingQueue controlQueue *controlQueue mtu uint32 maxPayloadSize uint32 // max DATA chunk payload size srtt atomic.Value // type float64 cumulativeTSNAckPoint uint32 advancedPeerTSNAckPoint uint32 useForwardTSN bool sendZeroChecksum bool recvZeroChecksum bool // Congestion control parameters maxReceiveBufferSize uint32 maxMessageSize uint32 cwnd uint32 // my congestion window size rwnd uint32 // calculated peer's receiver windows size ssthresh uint32 // slow start threshold partialBytesAcked uint32 inFastRecovery bool fastRecoverExitPoint uint32 minCwnd uint32 // Minimum congestion window fastRtxWnd uint32 // Send window for fast retransmit cwndCAStep uint32 // Step of congestion window increase at Congestion Avoidance // RTX & Ack timer rtoMgr *rtoManager t1Init *rtxTimer t1Cookie *rtxTimer t2Shutdown *rtxTimer t3RTX *rtxTimer tReconfig *rtxTimer ackTimer *ackTimer // RACK / TLP state rackReoWnd time.Duration // dynamic reordering window rackMinRTT time.Duration // min observed RTT rackMinRTTWnd *windowedMin // the window used to determine minRTT, defaults to 30s rackDeliveredTime time.Time // send time of most recently delivered original chunk rackHighestDeliveredOrigTSN uint32 rackReorderingSeen bool // ever observed reordering for this association rackKeepInflatedRecoveries int // keep inflated reoWnd for 16 loss recoveries // RACK xmit-time ordered list rackHead *chunkPayloadData rackTail *chunkPayloadData // Unified timer for 
RACK and PTO driven by a single goroutine. // Deadlines are protected with timerMu. timerMu sync.Mutex timerUpdateCh chan struct{} rackDeadline time.Time ptoDeadline time.Time rackWCDelAck time.Duration // 200ms default rackReoWndFloor time.Duration // Chunks stored for retransmission storedInit *chunkInit storedCookieEcho *chunkCookieEcho streams map[uint16]*Stream acceptCh chan *Stream readLoopCloseCh chan struct{} awakeWriteLoopCh chan struct{} closeWriteLoopCh chan struct{} handshakeCompletedCh chan error closeWriteLoopOnce sync.Once // local error silentError error ackState int ackMode int // for testing // stats stats *associationStats // per inbound packet context delayedAckTriggered bool immediateAckTriggered bool blockWrite bool writePending bool writeNotify chan struct{} name string log logging.LeveledLogger // Adaptive burst mitigation variables tlrActive bool tlrFirstRTT bool // first RTT of this TLR operation tlrHadAdditionalLoss bool tlrEndTSN uint32 // recovery is done when cumAck >= tlrEndTSN tlrBurstFirstRTTUnits int64 // quarter-MTU units tlrBurstLaterRTTUnits int64 // quarter-MTU units tlrGoodOps uint32 // count of TLR ops completed w/o additional loss tlrStartTime time.Time // time of first recovery RTT } // Config collects the arguments to createAssociation construction into // a single structure. type Config struct { Name string NetConn net.Conn MaxReceiveBufferSize uint32 MaxMessageSize uint32 EnableZeroChecksum bool LoggerFactory logging.LoggerFactory BlockWrite bool MTU uint32 // congestion control configuration // RTOMax is the maximum retransmission timeout in milliseconds RTOMax float64 // Minimum congestion window MinCwnd uint32 // Send window for fast retransmit FastRtxWnd uint32 // Step of congestion window increase at Congestion Avoidance CwndCAStep uint32 // The RACK configs are currently private as SCTP will be reworked to use the // modern options pattern in a future release. // Optional: size of window used to determine minimum RTT for RACK (defaults to 30s) rackMinRTTWnd time.Duration // Optional: cap the minimum reordering window: 0 = use quarter-RTT rackReoWndFloor time.Duration // Optional: receiver worst-case delayed-ACK for PTO when only one packet is in flight rackWCDelAck time.Duration } // Server accepts a SCTP stream over a conn. func Server(config Config) (*Association, error) { a := createAssociation(config) a.init(false) select { case err := <-a.handshakeCompletedCh: if err != nil { return nil, err } return a, nil case <-a.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } // Client opens a SCTP stream over a conn. 
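// MaxReceiveBufferSize, MaxMessageSize, and MTU may be left zero to use the
// package defaults chosen in createAssociation (1 MiB receive buffer, 64 KiB
// maximum message size, and a 1228-byte MTU); NetConn and LoggerFactory must
// always be supplied by the caller.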
func Client(config Config) (*Association, error) { return createClientWithContext(context.Background(), config) } func createClientWithContext(ctx context.Context, config Config) (*Association, error) { assoc := createAssociation(config) assoc.init(true) select { case <-ctx.Done(): assoc.log.Errorf("[%s] client handshake canceled: state=%s", assoc.name, getAssociationStateString(assoc.getState())) assoc.Close() // nolint:errcheck,gosec return nil, ctx.Err() case err := <-assoc.handshakeCompletedCh: if err != nil { return nil, err } return assoc, nil case <-assoc.readLoopCloseCh: return nil, ErrAssociationClosedBeforeConn } } func createAssociation(config Config) *Association { maxReceiveBufferSize := config.MaxReceiveBufferSize if maxReceiveBufferSize == 0 { maxReceiveBufferSize = initialRecvBufSize } maxMessageSize := config.MaxMessageSize if maxMessageSize == 0 { maxMessageSize = defaultMaxMessageSize } mtu := config.MTU if mtu == 0 { mtu = initialMTU } tsn := globalMathRandomGenerator.Uint32() assoc := &Association{ netConn: config.NetConn, maxReceiveBufferSize: maxReceiveBufferSize, maxMessageSize: maxMessageSize, minCwnd: config.MinCwnd, fastRtxWnd: config.FastRtxWnd, cwndCAStep: config.CwndCAStep, // These two max values have us not need to follow // 5.1.1 where this peer may be incapable of supporting // the requested amount of outbound streams from the other // peer. myMaxNumOutboundStreams: math.MaxUint16, myMaxNumInboundStreams: math.MaxUint16, payloadQueue: newReceivePayloadQueue(getMaxTSNOffset(maxReceiveBufferSize)), inflightQueue: newPayloadQueue(), pendingQueue: newPendingQueue(), controlQueue: newControlQueue(), mtu: mtu, maxPayloadSize: mtu - (commonHeaderSize + dataChunkHeaderSize), myVerificationTag: globalMathRandomGenerator.Uint32(), initialTSN: tsn, myNextTSN: tsn, myNextRSN: tsn, minTSN2MeasureRTT: tsn, state: closed, rtoMgr: newRTOManager(config.RTOMax), streams: map[uint16]*Stream{}, reconfigs: map[uint32]*chunkReconfig{}, reconfigRequests: map[uint32]*paramOutgoingResetRequest{}, acceptCh: make(chan *Stream, acceptChSize), readLoopCloseCh: make(chan struct{}), awakeWriteLoopCh: make(chan struct{}, 1), closeWriteLoopCh: make(chan struct{}), handshakeCompletedCh: make(chan error), cumulativeTSNAckPoint: tsn - 1, advancedPeerTSNAckPoint: tsn - 1, recvZeroChecksum: config.EnableZeroChecksum, silentError: ErrSilentlyDiscard, stats: &associationStats{}, log: config.LoggerFactory.NewLogger("sctp"), name: config.Name, blockWrite: config.BlockWrite, writeNotify: make(chan struct{}, 1), } // adaptive burst mitigation defaults assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT // RACK defaults assoc.rackWCDelAck = config.rackWCDelAck if assoc.rackWCDelAck == 0 { assoc.rackWCDelAck = 200 * time.Millisecond // WCDelAckT, RACK for SCTP section 2C } // defaults to 30s window to determine minRTT assoc.rackMinRTTWnd = newWindowedMin(config.rackMinRTTWnd) assoc.timerUpdateCh = make(chan struct{}, 1) go assoc.timerLoop() assoc.rackReoWndFloor = config.rackReoWndFloor // optional floor; usually 0 assoc.rackKeepInflatedRecoveries = 0 if assoc.name == "" { assoc.name = fmt.Sprintf("%p", assoc) } // RFC 4690 Sec 7.2.1 // o The initial cwnd before DATA transmission or after a sufficiently // long idle period MUST be set to min(4*MTU, max (2*MTU, 4380 // bytes)). 
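	// With the default MTU of 1228 bytes this works out to
	// min(4*1228, max(2*1228, 4380)) = min(4912, 4380) = 4380 bytes,
	// so the initial cwnd is bounded by the 4380-byte term rather than by 4*MTU.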
assoc.setCWND(min32(4*assoc.MTU(), max32(2*assoc.MTU(), 4380))) assoc.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", assoc.name, assoc.CWND(), assoc.ssthresh, assoc.inflightQueue.getNumBytes()) assoc.srtt.Store(float64(0)) assoc.t1Init = newRTXTimer(timerT1Init, assoc, maxInitRetrans, config.RTOMax) assoc.t1Cookie = newRTXTimer(timerT1Cookie, assoc, maxInitRetrans, config.RTOMax) assoc.t2Shutdown = newRTXTimer(timerT2Shutdown, assoc, noMaxRetrans, config.RTOMax) assoc.t3RTX = newRTXTimer(timerT3RTX, assoc, noMaxRetrans, config.RTOMax) assoc.tReconfig = newRTXTimer(timerReconfig, assoc, noMaxRetrans, config.RTOMax) assoc.ackTimer = newAckTimer(assoc) return assoc } func (a *Association) init(isClient bool) { a.lock.Lock() defer a.lock.Unlock() go a.readLoop() go a.writeLoop() if isClient { init := &chunkInit{} init.initialTSN = a.myNextTSN init.numOutboundStreams = a.myMaxNumOutboundStreams init.numInboundStreams = a.myMaxNumInboundStreams init.initiateTag = a.myVerificationTag init.advertisedReceiverWindowCredit = a.maxReceiveBufferSize setSupportedExtensions(&init.chunkInitCommon) if a.recvZeroChecksum { init.params = append(init.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) } a.storedInit = init err := a.sendInit() if err != nil { a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) } // After sending the INIT chunk, "A" starts the T1-init timer and enters the COOKIE-WAIT state. // Note: ideally we would set state after the timer starts but since we don't do this in an atomic // set + timer-start, it's safer to just set the state first so that we don't have a timer expiration // race. a.setState(cookieWait) a.t1Init.start(a.rtoMgr.getRTO()) } } // caller must hold a.lock. func (a *Association) sendInit() error { a.log.Debugf("[%s] sending INIT", a.name) if a.storedInit == nil { return ErrInitNotStoredToSend } outbound := &packet{} outbound.verificationTag = 0 a.sourcePort = defaultSCTPSrcDstPort a.destinationPort = defaultSCTPSrcDstPort outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{a.storedInit} a.controlQueue.push(outbound) a.awakeWriteLoop() return nil } // caller must hold a.lock. func (a *Association) sendCookieEcho() error { if a.storedCookieEcho == nil { return ErrCookieEchoNotStoredToSend } a.log.Debugf("[%s] sending COOKIE-ECHO", a.name) outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{a.storedCookieEcho} a.controlQueue.push(outbound) a.awakeWriteLoop() return nil } // Shutdown initiates the shutdown sequence. The method blocks until the // shutdown sequence is completed and the connection is closed, or until the // passed context is done, in which case the context's error is returned. func (a *Association) Shutdown(ctx context.Context) error { a.log.Debugf("[%s] closing association..", a.name) state := a.getState() if state != established { return fmt.Errorf("%w: shutdown %s", ErrShutdownNonEstablished, a.name) } // Attempt a graceful shutdown. a.setState(shutdownPending) a.lock.Lock() if a.inflightQueue.size() == 0 { // No more outstanding, send shutdown. a.willSendShutdown = true a.awakeWriteLoop() a.setState(shutdownSent) } a.lock.Unlock() select { case <-a.closeWriteLoopCh: return nil case <-ctx.Done(): return ctx.Err() } } // Close ends the SCTP Association and cleans up any state. 
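// Unlike Shutdown, Close does not wait for outstanding DATA chunks to be
// acknowledged: it closes the underlying net.Conn, stops all timers, and only
// waits for the read loop to exit before returning.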
func (a *Association) Close() error { a.log.Debugf("[%s] closing association..", a.name) err := a.close() // Wait for readLoop to end <-a.readLoopCloseCh a.log.Debugf("[%s] association closed", a.name) a.log.Debugf("[%s] stats nPackets (in) : %d", a.name, a.stats.getNumPacketsReceived()) a.log.Debugf("[%s] stats nPackets (out) : %d", a.name, a.stats.getNumPacketsSent()) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) a.log.Debugf("[%s] stats nSACKs (out) : %d", a.name, a.stats.getNumSACKsSent()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) return err } func (a *Association) close() error { a.log.Debugf("[%s] closing association..", a.name) a.setState(closed) err := a.netConn.Close() a.closeAllTimers() // awake writeLoop to exit a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) return err } // Abort sends the abort packet with user initiated abort and immediately // closes the connection. func (a *Association) Abort(reason string) { a.log.Debugf("[%s] aborting association: %s", a.name, reason) a.lock.Lock() a.willSendAbort = true a.willSendAbortCause = &errorCauseUserInitiatedAbort{ upperLayerAbortReason: []byte(reason), } a.lock.Unlock() // short bound for abort flush. _ = a.netConn.SetWriteDeadline(time.Now().Add(200 * time.Millisecond)) a.awakeWriteLoop() // unblock readLoop even if the underlying TCP connection is half-open. // We want Abort to return promptly during shutdown. _ = a.netConn.SetReadDeadline(time.Now()) // Wait for readLoop to end <-a.readLoopCloseCh } func (a *Association) closeAllTimers() { // Close all retransmission & ack timers a.t1Init.close() a.t1Cookie.close() a.t2Shutdown.close() a.t3RTX.close() a.tReconfig.close() a.ackTimer.close() a.stopRackTimer() a.stopPTOTimer() } func (a *Association) readLoop() { var closeErr error defer func() { // also stop writeLoop, otherwise writeLoop can be leaked // if connection is lost when there is no writing packet. a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) a.lock.Lock() a.setState(closed) for _, s := range a.streams { a.unregisterStream(s, closeErr) } a.lock.Unlock() close(a.acceptCh) close(a.readLoopCloseCh) a.log.Debugf("[%s] association closed", a.name) a.log.Debugf("[%s] stats nDATAs (in) : %d", a.name, a.stats.getNumDATAs()) a.log.Debugf("[%s] stats nSACKs (in) : %d", a.name, a.stats.getNumSACKsReceived()) a.log.Debugf("[%s] stats nT3Timeouts : %d", a.name, a.stats.getNumT3Timeouts()) a.log.Debugf("[%s] stats nAckTimeouts: %d", a.name, a.stats.getNumAckTimeouts()) a.log.Debugf("[%s] stats nFastRetrans: %d", a.name, a.stats.getNumFastRetrans()) }() a.log.Debugf("[%s] readLoop entered", a.name) buffer := make([]byte, receiveMTU) for { n, err := a.netConn.Read(buffer) if err != nil { closeErr = err break } // Make a buffer sized to what we read, then copy the data we // read from the underlying transport. We do this because the // user data is passed to the reassembly queue without // copying. 
inbound := make([]byte, n) copy(inbound, buffer[:n]) atomic.AddUint64(&a.bytesReceived, uint64(n)) //nolint:gosec // G115 if err = a.handleInbound(inbound); err != nil { closeErr = err break } } a.log.Debugf("[%s] readLoop exited %s", a.name, closeErr) } func (a *Association) writeLoop() { a.log.Debugf("[%s] writeLoop entered", a.name) defer a.log.Debugf("[%s] writeLoop exited", a.name) loop: for { rawPackets, ok := a.gatherOutbound() for _, raw := range rawPackets { _, err := a.netConn.Write(raw) if err != nil { if !errors.Is(err, io.EOF) { a.log.Warnf("[%s] failed to write packets on netConn: %v", a.name, err) } a.log.Debugf("[%s] writeLoop ended", a.name) break loop } atomic.AddUint64(&a.bytesSent, uint64(len(raw))) a.stats.incPacketsSent() } if !ok { if err := a.close(); err != nil { a.log.Warnf("[%s] failed to close association: %v", a.name, err) } return } select { case <-a.awakeWriteLoopCh: case <-a.closeWriteLoopCh: break loop } } a.setState(closed) a.closeAllTimers() } func (a *Association) awakeWriteLoop() { select { case a.awakeWriteLoopCh <- struct{}{}: default: } } func (a *Association) isBlockWrite() bool { return a.blockWrite } // Mark the association is writable and unblock the waiting write, // the caller should hold the association write lock. func (a *Association) notifyBlockWritable() { a.writePending = false select { case a.writeNotify <- struct{}{}: default: } } // unregisterStream un-registers a stream from the association // The caller should hold the association write lock. func (a *Association) unregisterStream(s *Stream, err error) { s.lock.Lock() defer s.lock.Unlock() delete(a.streams, s.streamIdentifier) s.readErr = err s.readNotifier.Broadcast() } func chunkMandatoryChecksum(cc []chunk) bool { for _, c := range cc { switch c.(type) { case *chunkInit, *chunkCookieEcho: return true } } return false } func (a *Association) marshalPacket(p *packet) ([]byte, error) { return p.marshal(!a.sendZeroChecksum || chunkMandatoryChecksum(p.chunks)) } func (a *Association) unmarshalPacket(raw []byte) (*packet, error) { p := &packet{} if err := p.unmarshal(!a.recvZeroChecksum, raw); err != nil { return nil, err } return p, nil } // handleInbound parses incoming raw packets. func (a *Association) handleInbound(raw []byte) error { pkt, err := a.unmarshalPacket(raw) if err != nil { a.log.Warnf("[%s] unable to parse SCTP packet %s", a.name, err) return nil } if err := checkPacket(pkt); err != nil { a.log.Warnf("[%s] failed validating packet %s", a.name, err) return nil } a.handleChunksStart() for _, c := range pkt.chunks { if err := a.handleChunk(pkt, c); err != nil { return err } } a.handleChunksEnd() return nil } // The caller should hold the lock. func (a *Association) gatherDataPacketsToRetransmit(rawPackets [][]byte, budgetUnits *int64, consumed *bool) [][]byte { for _, p := range a.getDataPacketsToRetransmit(budgetUnits, consumed) { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be retransmitted", a.name) continue } rawPackets = append(rawPackets, raw) } return rawPackets } // The caller should hold the lock. // //nolint:cyclop func (a *Association) gatherOutboundDataAndReconfigPackets( rawPackets [][]byte, budgetUnits *int64, consumed *bool, ) [][]byte { // Pop unsent data chunks from the pending queue to send as much as // cwnd and rwnd allow. chunks, sisToReset := a.popPendingDataChunksToSend(budgetUnits, consumed) if len(chunks) > 0 { // Start timer. 
(noop if already started) a.log.Tracef("[%s] T3-rtx timer start (pt1)", a.name) a.t3RTX.start(a.rtoMgr.getRTO()) for _, p := range a.bundleDataChunksIntoPackets(chunks) { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet", a.name) continue } rawPackets = append(rawPackets, raw) } // RFC 8985 (RACK) schedule PTO on new data transmission a.schedulePTOAfterSendLocked() } if len(sisToReset) > 0 || a.willRetransmitReconfig { //nolint:nestif if a.willRetransmitReconfig { a.willRetransmitReconfig = false a.log.Debugf("[%s] retransmit %d RECONFIG chunk(s)", a.name, len(a.reconfigs)) for _, c := range a.reconfigs { p := a.createPacket([]chunk{c}) raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be retransmitted", a.name) } else { rawPackets = append(rawPackets, raw) } } } if len(sisToReset) > 0 { rsn := a.generateNextRSN() tsn := a.myNextTSN - 1 c := &chunkReconfig{ paramA: ¶mOutgoingResetRequest{ reconfigRequestSequenceNumber: rsn, senderLastTSN: tsn, streamIdentifiers: sisToReset, }, } a.reconfigs[rsn] = c // store in the map for retransmission a.log.Debugf("[%s] sending RECONFIG: rsn=%d tsn=%d streams=%v", a.name, rsn, a.myNextTSN-1, sisToReset) p := a.createPacket([]chunk{c}) raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a RECONFIG packet to be transmitted", a.name) } else { rawPackets = append(rawPackets, raw) } } if len(a.reconfigs) > 0 { a.tReconfig.start(a.rtoMgr.getRTO()) } } return rawPackets } // The caller should hold the lock. // //nolint:cyclop func (a *Association) gatherOutboundFastRetransmissionPackets( //nolint:gocognit rawPackets [][]byte, budgetScaled *int64, consumed *bool, ) [][]byte { if !a.willRetransmitFast { return rawPackets } a.willRetransmitFast = false toFastRetrans := []*chunkPayloadData{} fastRetransSize := int(commonHeaderSize) fastRetransWnd := int(max(a.MTU(), a.fastRtxWnd)) now := time.Now() // MTU bundling + burst budgeting tracker bytesInPacket := 0 stopBundling := false for i := 0; ; i++ { chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 if !ok { break // end of pending data } if chunkPayload.acked || chunkPayload.abandoned() { continue } if chunkPayload.nSent > 1 || chunkPayload.missIndicator < 3 { continue } // include padding for sizing. 
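		// Chunks are padded to a 4-byte boundary on the wire, so a DATA chunk with,
		// for example, 10 bytes of user data is counted as 16 (header) + 10 + 2
		// (padding) = 28 bytes here.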
chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) chunkBytes += getPadding(chunkBytes) // fast retransmit window cap if fastRetransWnd < fastRetransSize+chunkBytes { break } // MTU bundling + burst budget before mutating for { addBytes := chunkBytes if bytesInPacket == 0 { addBytes += int(commonHeaderSize) if addBytes > int(a.MTU()) { stopBundling = true break } } else if bytesInPacket+chunkBytes > int(a.MTU()) { // start a new packet and retry this same chunk as first in packet bytesInPacket = 0 continue } if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { // budget exhausted, stop selecting any more fast-rtx chunks stopBundling = true break } if bytesInPacket == 0 { bytesInPacket = int(commonHeaderSize) } bytesInPacket += chunkBytes break } if stopBundling { break } fastRetransSize += chunkBytes a.stats.incFastRetrans() // Update for retransmission chunkPayload.nSent++ chunkPayload.since = now a.rackRemove(chunkPayload) a.rackInsert(chunkPayload) a.checkPartialReliabilityStatus(chunkPayload) toFastRetrans = append(toFastRetrans, chunkPayload) a.log.Tracef("[%s] fast-retransmit: tsn=%d sent=%d htna=%d", a.name, chunkPayload.tsn, chunkPayload.nSent, a.fastRecoverExitPoint) } if len(toFastRetrans) == 0 { return rawPackets } for _, p := range a.bundleDataChunksIntoPackets(toFastRetrans) { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a DATA packet to be fast-retransmitted", a.name) continue } rawPackets = append(rawPackets, raw) } return rawPackets } // The caller should hold the lock. func (a *Association) gatherOutboundSackPackets(rawPackets [][]byte) [][]byte { if a.ackState == ackStateImmediate { a.ackState = ackStateIdle sack := a.createSelectiveAckChunk() a.stats.incSACKsSent() a.log.Debugf("[%s] sending SACK: %s", a.name, sack) raw, err := a.marshalPacket(a.createPacket([]chunk{sack})) if err != nil { a.log.Warnf("[%s] failed to serialize a SACK packet", a.name) } else { rawPackets = append(rawPackets, raw) } } return rawPackets } // The caller should hold the lock. 
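// A FORWARD TSN chunk is only emitted when the advanced peer ack point has
// moved past the cumulative TSN ack point, i.e. when abandoned chunks would
// otherwise leave the peer waiting on TSNs that will never be retransmitted.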
func (a *Association) gatherOutboundForwardTSNPackets(rawPackets [][]byte) [][]byte { if a.willSendForwardTSN { a.willSendForwardTSN = false if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { fwdtsn := a.createForwardTSN() raw, err := a.marshalPacket(a.createPacket([]chunk{fwdtsn})) if err != nil { a.log.Warnf("[%s] failed to serialize a Forward TSN packet", a.name) } else { rawPackets = append(rawPackets, raw) } } } return rawPackets } func (a *Association) gatherOutboundShutdownPackets(rawPackets [][]byte) ([][]byte, bool) { ok := true switch { case a.willSendShutdown: a.willSendShutdown = false shutdown := &chunkShutdown{ cumulativeTSNAck: a.cumulativeTSNAckPoint, } raw, err := a.marshalPacket(a.createPacket([]chunk{shutdown})) if err != nil { a.log.Warnf("[%s] failed to serialize a Shutdown packet", a.name) } else { a.t2Shutdown.start(a.rtoMgr.getRTO()) rawPackets = append(rawPackets, raw) } case a.willSendShutdownAck: a.willSendShutdownAck = false shutdownAck := &chunkShutdownAck{} raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownAck})) if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownAck packet", a.name) } else { a.t2Shutdown.start(a.rtoMgr.getRTO()) rawPackets = append(rawPackets, raw) } case a.willSendShutdownComplete: a.willSendShutdownComplete = false shutdownComplete := &chunkShutdownComplete{} raw, err := a.marshalPacket(a.createPacket([]chunk{shutdownComplete})) if err != nil { a.log.Warnf("[%s] failed to serialize a ShutdownComplete packet", a.name) } else { rawPackets = append(rawPackets, raw) ok = false } } return rawPackets, ok } func (a *Association) gatherAbortPacket() ([]byte, error) { cause := a.willSendAbortCause a.willSendAbort = false a.willSendAbortCause = nil abort := &chunkAbort{} if cause != nil { abort.errorCauses = []errorCause{cause} } raw, err := a.marshalPacket(a.createPacket([]chunk{abort})) return raw, err } // gatherOutbound gathers outgoing packets. The returned bool value set to // false means the association should be closed down after the final send. func (a *Association) gatherOutbound() ([][]byte, bool) { a.lock.Lock() defer a.lock.Unlock() if a.willSendAbort { pkt, err := a.gatherAbortPacket() if err != nil { a.log.Warnf("[%s] failed to serialize an abort packet", a.name) return nil, false } return [][]byte{pkt}, false } rawPackets := [][]byte{} if a.controlQueue.size() > 0 { for _, p := range a.controlQueue.popAll() { raw, err := a.marshalPacket(p) if err != nil { a.log.Warnf("[%s] failed to serialize a control packet", a.name) continue } rawPackets = append(rawPackets, raw) } } state := a.getState() ok := true switch state { case established: budgetUnits := a.tlrCurrentBurstBudgetScaledLocked() consumed := false rawPackets = a.gatherDataPacketsToRetransmit(rawPackets, &budgetUnits, &consumed) rawPackets = a.gatherOutboundDataAndReconfigPackets(rawPackets, &budgetUnits, &consumed) rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets, &budgetUnits, &consumed) // control traffic shouldn't be limited. 
rawPackets = a.gatherOutboundSackPackets(rawPackets) rawPackets = a.gatherOutboundForwardTSNPackets(rawPackets) case shutdownPending, shutdownSent, shutdownReceived: budgetUnits := a.tlrCurrentBurstBudgetScaledLocked() consumed := false rawPackets = a.gatherDataPacketsToRetransmit(rawPackets, &budgetUnits, &consumed) rawPackets = a.gatherOutboundFastRetransmissionPackets(rawPackets, &budgetUnits, &consumed) rawPackets = a.gatherOutboundSackPackets(rawPackets) rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) case shutdownAckSent: rawPackets, ok = a.gatherOutboundShutdownPackets(rawPackets) } return rawPackets, ok } func checkPacket(pkt *packet) error { // All packets must adhere to these rules // This is the SCTP sender's port number. It can be used by the // receiver in combination with the source IP address, the SCTP // destination port, and possibly the destination IP address to // identify the association to which this packet belongs. The port // number 0 MUST NOT be used. if pkt.sourcePort == 0 { return ErrSCTPPacketSourcePortZero } // This is the SCTP port number to which this packet is destined. // The receiving host will use this port number to de-multiplex the // SCTP packet to the correct receiving endpoint/application. The // port number 0 MUST NOT be used. if pkt.destinationPort == 0 { return ErrSCTPPacketDestinationPortZero } // Check values on the packet that are specific to a particular chunk type for _, c := range pkt.chunks { switch c.(type) { // nolint:gocritic case *chunkInit: // An INIT or INIT ACK chunk MUST NOT be bundled with any other chunk. // They MUST be the only chunks present in the SCTP packets that carry // them. if len(pkt.chunks) != 1 { return ErrInitChunkBundled } // A packet containing an INIT chunk MUST have a zero Verification // Tag. if pkt.verificationTag != 0 { return ErrInitChunkVerifyTagNotZero } } } return nil } func min16(a, b uint16) uint16 { if a < b { return a } return b } func max32(a, b uint32) uint32 { if a > b { return a } return b } func min32(a, b uint32) uint32 { if a < b { return a } return b } // peerLastTSN return last received cumulative TSN. func (a *Association) peerLastTSN() uint32 { return a.payloadQueue.getcumulativeTSN() } // setState atomically sets the state of the Association. // The caller should hold the lock. func (a *Association) setState(newState uint32) { oldState := atomic.SwapUint32(&a.state, newState) if newState != oldState { a.log.Debugf("[%s] state change: '%s' => '%s'", a.name, getAssociationStateString(oldState), getAssociationStateString(newState)) } } // getState atomically returns the state of the Association. func (a *Association) getState() uint32 { return atomic.LoadUint32(&a.state) } // BytesSent returns the number of bytes sent. func (a *Association) BytesSent() uint64 { return atomic.LoadUint64(&a.bytesSent) } // BytesReceived returns the number of bytes received. func (a *Association) BytesReceived() uint64 { return atomic.LoadUint64(&a.bytesReceived) } // MTU returns the association's current MTU. func (a *Association) MTU() uint32 { return atomic.LoadUint32(&a.mtu) } // CWND returns the association's current congestion window (cwnd). func (a *Association) CWND() uint32 { return atomic.LoadUint32(&a.cwnd) } func (a *Association) setCWND(cwnd uint32) { if cwnd < a.minCwnd { cwnd = a.minCwnd } atomic.StoreUint32(&a.cwnd, cwnd) } // RWND returns the association's current receiver window (rwnd). 
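// The value reflects the peer's most recently advertised a_rwnd minus the
// bytes still outstanding when that SACK was processed (see handleSack).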
func (a *Association) RWND() uint32 { return atomic.LoadUint32(&a.rwnd) } func (a *Association) setRWND(rwnd uint32) { atomic.StoreUint32(&a.rwnd, rwnd) } // SRTT returns the latest smoothed round-trip time (srrt). func (a *Association) SRTT() float64 { return a.srtt.Load().(float64) //nolint:forcetypeassert } // getMaxTSNOffset returns the maximum offset over the current cummulative TSN that // we are willing to enqueue. This ensures that we keep the bytes utilized in the receive // buffer within a small multiple of the user provided max receive buffer size. func getMaxTSNOffset(maxReceiveBufferSize uint32) uint32 { // 4 is a magic number here. There is no theory behind this. offset := max((maxReceiveBufferSize*4)/avgChunkSize, minTSNOffset) if offset > maxTSNOffset { offset = maxTSNOffset } return offset } func setSupportedExtensions(init *chunkInitCommon) { // nolint:godox // TODO RFC5061 https://tools.ietf.org/html/rfc6525#section-5.2 // An implementation supporting this (Supported Extensions Parameter) // extension MUST list the ASCONF, the ASCONF-ACK, and the AUTH chunks // in its INIT and INIT-ACK parameters. init.params = append(init.params, ¶mSupportedExtensions{ ChunkTypes: []chunkType{ctReconfig, ctForwardTSN}, }) } // The caller should hold the lock. // //nolint:cyclop func (a *Association) handleInit(pkt *packet, initChunk *chunkInit) ([]*packet, error) { state := a.getState() a.log.Debugf("[%s] chunkInit received in state '%s'", a.name, getAssociationStateString(state)) // https://tools.ietf.org/html/rfc4960#section-5.2.1 // Upon receipt of an INIT in the COOKIE-WAIT state, an endpoint MUST // respond with an INIT ACK using the same parameters it sent in its // original INIT chunk (including its Initiate Tag, unchanged). When // responding, the endpoint MUST send the INIT ACK back to the same // address that the original INIT (sent by this endpoint) was sent. if state != closed && state != cookieWait && state != cookieEchoed { // 5.2.2. Unexpected INIT in States Other than CLOSED, COOKIE-ECHOED, // COOKIE-WAIT, and SHUTDOWN-ACK-SENT return nil, fmt.Errorf("%w: %s", ErrHandleInitState, getAssociationStateString(state)) } // NOTE: Setting these prior to a reception of a COOKIE ECHO chunk containing // our cookie is not compliant with https://www.rfc-editor.org/rfc/rfc9260#section-5.1-2.2.3. // It makes us more vulnerable to resource attacks, albeit minimally so. // https://www.rfc-editor.org/rfc/rfc9260#sec_handle_stream_parameters a.myMaxNumInboundStreams = min16(initChunk.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(initChunk.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = initChunk.initiateTag a.sourcePort = pkt.destinationPort a.destinationPort = pkt.sourcePort // 13.2 This is the last TSN received in sequence. This value // is set initially by taking the peer's initial TSN, // received in the INIT or INIT ACK chunk, and // subtracting one from it. 
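// Illustrative example: if the peer's initial TSN is 1000, the payload
// queue starts with a cumulative TSN of 999, so the first DATA chunk
// (tsn=1000) advances the cumulative TSN point as expected.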
a.payloadQueue.init(initChunk.initialTSN - 1) a.setRWND(initChunk.advertisedReceiverWindowCredit) a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) for _, param := range initChunk.params { switch v := param.(type) { // nolint:gocritic case *paramSupportedExtensions: for _, t := range v.ChunkTypes { if t == ctForwardTSN { a.log.Debugf("[%s] use ForwardTSN (on init)", a.name) a.useForwardTSN = true } } case *paramZeroChecksumAcceptable: a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on init)", a.name) } outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort initAck := &chunkInitAck{} a.log.Debug("sending INIT ACK") initAck.initialTSN = a.myNextTSN initAck.numOutboundStreams = a.myMaxNumOutboundStreams initAck.numInboundStreams = a.myMaxNumInboundStreams initAck.initiateTag = a.myVerificationTag initAck.advertisedReceiverWindowCredit = a.maxReceiveBufferSize if a.myCookie == nil { var err error // NOTE: This generation process is not compliant with // 5.1.3. Generating State Cookie (https://www.rfc-editor.org/rfc/rfc4960#section-5.1.3) if a.myCookie, err = newRandomStateCookie(); err != nil { return nil, err } } initAck.params = []param{a.myCookie} if a.recvZeroChecksum { initAck.params = append(initAck.params, ¶mZeroChecksumAcceptable{edmid: dtlsErrorDetectionMethod}) } a.log.Debugf("[%s] sendZeroChecksum=%t (on init)", a.name, a.sendZeroChecksum) setSupportedExtensions(&initAck.chunkInitCommon) outbound.chunks = []chunk{initAck} return pack(outbound), nil } // The caller should hold the lock. func (a *Association) handleInitAck(pkt *packet, initChunkAck *chunkInitAck) error { //nolint:cyclop state := a.getState() a.log.Debugf("[%s] chunkInitAck received in state '%s'", a.name, getAssociationStateString(state)) if state != cookieWait { // RFC 4960 // 5.2.3. Unexpected INIT ACK // If an INIT ACK is received by an endpoint in any state other than the // COOKIE-WAIT state, the endpoint should discard the INIT ACK chunk. // An unexpected INIT ACK usually indicates the processing of an old or // duplicated INIT chunk. return nil } a.myMaxNumInboundStreams = min16(initChunkAck.numInboundStreams, a.myMaxNumInboundStreams) a.myMaxNumOutboundStreams = min16(initChunkAck.numOutboundStreams, a.myMaxNumOutboundStreams) a.peerVerificationTag = initChunkAck.initiateTag a.payloadQueue.init(initChunkAck.initialTSN - 1) if a.sourcePort != pkt.destinationPort || a.destinationPort != pkt.sourcePort { a.log.Warnf("[%s] handleInitAck: port mismatch", a.name) return nil } a.setRWND(initChunkAck.advertisedReceiverWindowCredit) a.log.Debugf("[%s] initial rwnd=%d", a.name, a.RWND()) // RFC 4690 Sec 7.2.1 // o The initial value of ssthresh MAY be arbitrarily high (for // example, implementations MAY use the size of the receiver // advertised window). 
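// Starting ssthresh at the peer's advertised window keeps the sender in
// slow start until cwnd grows to that value or a loss event lowers ssthresh.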
a.ssthresh = a.RWND() a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (INI)", a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) a.t1Init.stop() a.storedInit = nil var cookieParam *paramStateCookie for _, param := range initChunkAck.params { switch v := param.(type) { case *paramStateCookie: cookieParam = v case *paramSupportedExtensions: for _, t := range v.ChunkTypes { if t == ctForwardTSN { a.log.Debugf("[%s] use ForwardTSN (on initAck)", a.name) a.useForwardTSN = true } } case *paramZeroChecksumAcceptable: a.sendZeroChecksum = v.edmid == dtlsErrorDetectionMethod } } a.log.Debugf("[%s] sendZeroChecksum=%t (on initAck)", a.name, a.sendZeroChecksum) if !a.useForwardTSN { a.log.Warnf("[%s] not using ForwardTSN (on initAck)", a.name) } if cookieParam == nil { return ErrInitAckNoCookie } a.storedCookieEcho = &chunkCookieEcho{} a.storedCookieEcho.cookie = cookieParam.cookie err := a.sendCookieEcho() if err != nil { a.log.Errorf("[%s] failed to send init: %s", a.name, err.Error()) } a.t1Cookie.start(a.rtoMgr.getRTO()) a.setState(cookieEchoed) return nil } // The caller should hold the lock. func (a *Association) handleHeartbeat(c *chunkHeartbeat) []*packet { a.log.Tracef("[%s] chunkHeartbeat", a.name) if len(c.params) == 0 { a.log.Warnf("[%s] Heartbeat without ParamHeartbeatInfo (no params)", a.name) return nil } info, ok := c.params[0].(*paramHeartbeatInfo) if !ok { a.log.Warnf("[%s] Heartbeat without ParamHeartbeatInfo (got %T)", a.name, c.params[0]) return nil } return pack(&packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: []chunk{&chunkHeartbeatAck{ params: []param{ ¶mHeartbeatInfo{ heartbeatInformation: info.heartbeatInformation, }, }, }}, }) } // The caller should hold the lock. func (a *Association) handleHeartbeatAck(c *chunkHeartbeatAck) { a.log.Tracef("[%s] chunkHeartbeatAck", a.name) if len(c.params) == 0 { return } info, ok := c.params[0].(*paramHeartbeatInfo) if !ok { a.log.Warnf("[%s] HeartbeatAck without ParamHeartbeatInfo", a.name) return } // active RTT probe: if heartbeatInformation is exactly 8 bytes, treat it // as a big-endian unix nano timestamp. if len(info.heartbeatInformation) == 8 { ns := binary.BigEndian.Uint64(info.heartbeatInformation) if ns > math.MaxInt64 { // Malformed or future-unsafe value; ignore this heartbeat-ack. a.log.Warnf("[%s] HB RTT: timestamp overflows int64, ignoring", a.name) return } sentNanos := int64(ns) sent := time.Unix(0, sentNanos) now := time.Now() if !sent.IsZero() && !now.Before(sent) { rttMs := now.Sub(sent).Seconds() * 1000.0 srtt := a.rtoMgr.setNewRTT(rttMs) a.srtt.Store(srtt) a.rackMinRTTWnd.Push(now, now.Sub(sent)) a.log.Tracef("[%s] HB RTT: measured=%.3fms srtt=%.3fms rto=%.3fms", a.name, rttMs, srtt, a.rtoMgr.getRTO()) } } } // The caller should hold the lock. 
func (a *Association) handleCookieEcho(cookieEcho *chunkCookieEcho) []*packet { state := a.getState() a.log.Debugf("[%s] COOKIE-ECHO received in state '%s'", a.name, getAssociationStateString(state)) if a.myCookie == nil { a.log.Debugf("[%s] COOKIE-ECHO received before initialization", a.name) return nil } switch state { default: return nil case established: if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { return nil } case closed, cookieWait, cookieEchoed: if !bytes.Equal(a.myCookie.cookie, cookieEcho.cookie) { return nil } // RFC wise, these do not seem to belong here, but removing them // causes TestCookieEchoRetransmission to break a.t1Init.stop() a.storedInit = nil a.t1Cookie.stop() a.storedCookieEcho = nil a.setState(established) if !a.completeHandshake(nil) { return nil } } p := &packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: []chunk{&chunkCookieAck{}}, } return pack(p) } // The caller should hold the lock. func (a *Association) handleCookieAck() { state := a.getState() a.log.Debugf("[%s] COOKIE-ACK received in state '%s'", a.name, getAssociationStateString(state)) if state != cookieEchoed { // RFC 4960 // 5.2.5. Handle Duplicate COOKIE-ACK. // At any state other than COOKIE-ECHOED, an endpoint should silently // discard a received COOKIE ACK chunk. return } a.t1Cookie.stop() a.storedCookieEcho = nil a.setState(established) a.completeHandshake(nil) } // The caller should hold the lock. func (a *Association) handleData(chunkPayload *chunkPayloadData) []*packet { a.log.Tracef("[%s] DATA: tsn=%d immediateSack=%v len=%d", a.name, chunkPayload.tsn, chunkPayload.immediateSack, len(chunkPayload.userData)) a.stats.incDATAs() canPush := a.payloadQueue.canPush(chunkPayload.tsn) if canPush { //nolint:nestif stream := a.getOrCreateStream(chunkPayload.streamIdentifier, true, PayloadTypeUnknown) if stream == nil { // silently discard the data. (sender will retry on T3-rtx timeout) // see pion/sctp#30 a.log.Debugf("[%s] discard %d", a.name, chunkPayload.streamSequenceNumber) return nil } if a.getMyReceiverWindowCredit() > 0 { // Pass the new chunk to stream level as soon as it arrives a.payloadQueue.push(chunkPayload.tsn) stream.handleData(chunkPayload) } else { // Receive buffer is full lastTSN, ok := a.payloadQueue.getLastTSNReceived() if ok && sna32LT(chunkPayload.tsn, lastTSN) { a.log.Debugf( "[%s] receive buffer full, but accepted as this is a missing chunk with tsn=%d ssn=%d", a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, ) a.payloadQueue.push(chunkPayload.tsn) stream.handleData(chunkPayload) } else { a.log.Debugf( "[%s] receive buffer full. dropping DATA with tsn=%d ssn=%d", a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, ) } } } // Upon the reception of a new DATA chunk, an endpoint shall examine the // continuity of the TSNs received. If the endpoint detects a gap in // the received DATA chunk sequence, it SHOULD send a SACK with Gap Ack // Blocks immediately. The data receiver continues sending a SACK after // receipt of each SCTP packet that doesn't fill the gap. // https://datatracker.ietf.org/doc/html/rfc4960#section-6.7 expectedTSN := a.peerLastTSN() + 1 gapDetected := sna32GT(chunkPayload.tsn, expectedTSN) sackNow := chunkPayload.immediateSack || gapDetected return a.handlePeerLastTSNAndAcknowledgement(sackNow) } // A common routine for handleData and handleForwardTSN routines // The caller should hold the lock. 
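// It advances the cumulative TSN point over contiguously received TSNs,
// answers any pending stream reset requests that become actionable, and
// decides whether to SACK immediately or via the delayed-ack timer.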
func (a *Association) handlePeerLastTSNAndAcknowledgement(sackImmediately bool) []*packet { //nolint:cyclop var reply []*packet // Try to advance peerLastTSN // From RFC 3758 Sec 3.6: // .. and then MUST further advance its cumulative TSN point locally // if possible // Meaning, if peerLastTSN+1 points to a chunk that is received, // advance peerLastTSN until peerLastTSN+1 points to unreceived chunk. for { if popOk := a.payloadQueue.pop(false); !popOk { break } for _, rstReq := range a.reconfigRequests { resp := a.resetStreamsIfAny(rstReq) if resp != nil { a.log.Debugf("[%s] RESET RESPONSE: %+v", a.name, resp) reply = append(reply, resp) } } } hasPacketLoss := (a.payloadQueue.size() > 0) if hasPacketLoss { a.log.Tracef("[%s] packetloss: %s", a.name, a.payloadQueue.getGapAckBlocksString()) } // RFC 4960 $6.7: SHOULD ack immediately when detecting a gap. if sackImmediately || hasPacketLoss || a.ackMode == ackModeNoDelay { a.immediateAckTriggered = true return reply } if a.ackMode == ackModeAlwaysDelay || (a.ackMode == ackModeNormal && a.ackState != ackStateImmediate) { if a.ackState == ackStateIdle { a.delayedAckTriggered = true } else { a.immediateAckTriggered = true } return reply } a.immediateAckTriggered = true return reply } // The caller should hold the lock. func (a *Association) getMyReceiverWindowCredit() uint32 { var bytesQueued uint32 for _, s := range a.streams { bytesQueued += uint32(s.getNumBytesInReassemblyQueue()) //nolint:gosec // G115 } if bytesQueued >= a.maxReceiveBufferSize { return 0 } return a.maxReceiveBufferSize - bytesQueued } // OpenStream opens a stream. func (a *Association) OpenStream( streamIdentifier uint16, defaultPayloadType PayloadProtocolIdentifier, ) (*Stream, error) { a.lock.Lock() defer a.lock.Unlock() switch a.getState() { case shutdownAckSent, shutdownPending, shutdownReceived, shutdownSent, closed: return nil, ErrAssociationClosed } return a.getOrCreateStream(streamIdentifier, false, defaultPayloadType), nil } // AcceptStream accepts a stream. func (a *Association) AcceptStream() (*Stream, error) { s, ok := <-a.acceptCh if !ok { return nil, io.EOF // no more incoming streams } return s, nil } // createStream creates a stream. The caller should hold the lock and check no stream exists for this id. func (a *Association) createStream(streamIdentifier uint16, accept bool) *Stream { stream := &Stream{ association: a, streamIdentifier: streamIdentifier, reassemblyQueue: newReassemblyQueue(streamIdentifier), log: a.log, name: fmt.Sprintf("%d:%s", streamIdentifier, a.name), writeDeadline: deadline.New(), } stream.readNotifier = sync.NewCond(&stream.lock) if accept { select { case a.acceptCh <- stream: a.streams[streamIdentifier] = stream a.log.Debugf("[%s] accepted a new stream (streamIdentifier: %d)", a.name, streamIdentifier) default: a.log.Debugf("[%s] dropped a new stream (acceptCh size: %d)", a.name, len(a.acceptCh)) return nil } } else { a.streams[streamIdentifier] = stream } return stream } // getOrCreateStream gets or creates a stream. The caller should hold the lock. func (a *Association) getOrCreateStream( streamIdentifier uint16, accept bool, defaultPayloadType PayloadProtocolIdentifier, ) *Stream { if s, ok := a.streams[streamIdentifier]; ok { s.SetDefaultPayloadType(defaultPayloadType) return s } s := a.createStream(streamIdentifier, accept) if s != nil { s.SetDefaultPayloadType(defaultPayloadType) } return s } // The caller should hold the lock. 
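// processSelectiveAck pops chunks covered by the new cumulative TSN ack,
// marks gap-acked chunks, takes RTT samples from original transmissions
// only (Karn's rule), and returns per-stream acked byte counts, the highest
// TSN newly acked (htna), and the newest delivered send time used by RACK.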
// //nolint:gocognit,cyclop func (a *Association) processSelectiveAck(selectiveAckChunk *chunkSelectiveAck) ( bytesAckedPerStream map[uint16]int, htna uint32, newestDeliveredSendTime time.Time, newestDeliveredOrigTSN uint32, deliveredFound bool, err error, ) { bytesAckedPerStream = map[uint16]int{} now := time.Now() // capture the time for this SACK // New ack point, so pop all ACKed packets from inflightQueue // We add 1 because the "currentAckPoint" has already been popped from the inflight queue // For the first SACK we take care of this by setting the ackpoint to cumAck - 1 for idx := a.cumulativeTSNAckPoint + 1; sna32LTE(idx, selectiveAckChunk.cumulativeTSNAck); idx++ { chunkPayload, ok := a.inflightQueue.pop(idx) if !ok { return nil, 0, time.Time{}, 0, false, fmt.Errorf("%w: %v", ErrInflightQueueTSNPop, idx) } // RACK: remove from xmit-time list since it's delivered a.rackRemove(chunkPayload) if !chunkPayload.acked { //nolint:nestif // RFC 4960 sec 6.3.2. Retransmission Timer Rules // R3) Whenever a SACK is received that acknowledges the DATA chunk // with the earliest outstanding TSN for that address, restart the // T3-rtx timer for that address with its current RTO (if there is // still outstanding data on that address). if idx == a.cumulativeTSNAckPoint+1 { // T3 timer needs to be reset. Stop it for now. a.t3RTX.stop() } nBytesAcked := len(chunkPayload.userData) // Sum the number of bytes acknowledged per stream if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked } else { bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked } // RFC 4960 sec 6.3.1. RTO Calculation // C4) When data is in flight and when allowed by rule C5 below, a new // RTT measurement MUST be made each round trip. Furthermore, new // RTT measurements SHOULD be made no more than once per round trip // for a given destination transport address. // C5) Karn's algorithm: RTT measurements MUST NOT be made using // packets that were retransmitted (and thus for which it is // ambiguous whether the reply was for the first instance of the // chunk or for a later instance) if sna32GTE(chunkPayload.tsn, a.minTSN2MeasureRTT) { // Only original transmissions for classic RTT measurement (Karn's rule) if chunkPayload.nSent == 1 { a.minTSN2MeasureRTT = a.myNextTSN rtt := now.Sub(chunkPayload.since).Seconds() * 1000.0 srtt := a.rtoMgr.setNewRTT(rtt) a.srtt.Store(srtt) // use a window to determine minRtt instead of a global min // as the RTT can fluctuate, which can cause problems if going from a // high RTT to a low RTT. a.rackMinRTTWnd.Push(now, now.Sub(chunkPayload.since)) a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", a.name, rtt, srtt, a.rtoMgr.getRTO()) } } // RFC 8985 (RACK) sec 5.2: RACK.segment is the most recently sent // segment that has been delivered, including retransmissions. 
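// Track the latest send time among newly delivered chunks; RACK later uses
// it, plus a reordering window derived from the min-RTT estimate, to declare
// earlier-sent but still-missing chunks lost without waiting for T3-rtx.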
if chunkPayload.since.After(newestDeliveredSendTime) { newestDeliveredSendTime = chunkPayload.since newestDeliveredOrigTSN = chunkPayload.tsn deliveredFound = true } if a.inFastRecovery && chunkPayload.tsn == a.fastRecoverExitPoint { a.log.Debugf("[%s] exit fast-recovery", a.name) a.inFastRecovery = false } } } htna = selectiveAckChunk.cumulativeTSNAck // Mark selectively acknowledged chunks as "acked" for _, g := range selectiveAckChunk.gapAckBlocks { for i := g.start; i <= g.end; i++ { tsn := selectiveAckChunk.cumulativeTSNAck + uint32(i) chunkPayload, ok := a.inflightQueue.get(tsn) if !ok { return nil, 0, time.Time{}, 0, false, fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn) } // RACK: remove from xmit-time list since it's delivered a.rackRemove(chunkPayload) if !chunkPayload.acked { //nolint:nestif nBytesAcked := a.inflightQueue.markAsAcked(tsn) // Sum the number of bytes acknowledged per stream if amount, ok := bytesAckedPerStream[chunkPayload.streamIdentifier]; ok { bytesAckedPerStream[chunkPayload.streamIdentifier] = amount + nBytesAcked } else { bytesAckedPerStream[chunkPayload.streamIdentifier] = nBytesAcked } a.log.Tracef("[%s] tsn=%d has been sacked", a.name, chunkPayload.tsn) // RTT / RTO and RACK updates if sna32GTE(chunkPayload.tsn, a.minTSN2MeasureRTT) { // Only original transmissions for classic RTT measurement if chunkPayload.nSent == 1 { a.minTSN2MeasureRTT = a.myNextTSN rtt := now.Sub(chunkPayload.since).Seconds() * 1000.0 srtt := a.rtoMgr.setNewRTT(rtt) a.srtt.Store(srtt) a.rackMinRTTWnd.Push(now, now.Sub(chunkPayload.since)) a.log.Tracef("[%s] SACK: measured-rtt=%f srtt=%f new-rto=%f", a.name, rtt, srtt, a.rtoMgr.getRTO()) } } if chunkPayload.since.After(newestDeliveredSendTime) { newestDeliveredSendTime = chunkPayload.since newestDeliveredOrigTSN = chunkPayload.tsn deliveredFound = true } } if sna32LT(htna, tsn) { htna = tsn } } } return bytesAckedPerStream, htna, newestDeliveredSendTime, newestDeliveredOrigTSN, deliveredFound, nil } // The caller should hold the lock. func (a *Association) onCumulativeTSNAckPointAdvanced(totalBytesAcked int) { // RFC 4960, sec 6.3.2. Retransmission Timer Rules // R2) Whenever all outstanding data sent to an address have been // acknowledged, turn off the T3-rtx timer of that address. if a.inflightQueue.size() == 0 { a.log.Tracef("[%s] SACK: no more packet in-flight (pending=%d)", a.name, a.pendingQueue.size()) a.t3RTX.stop() a.stopPTOTimer() a.stopRackTimer() } else { a.log.Tracef("[%s] T3-rtx timer start (pt2)", a.name) a.t3RTX.start(a.rtoMgr.getRTO()) } // Update congestion control parameters if a.CWND() <= a.ssthresh { //nolint:nestif // RFC 4960, sec 7.2.1. Slow-Start // o When cwnd is less than or equal to ssthresh, an SCTP endpoint MUST // use the slow-start algorithm to increase cwnd only if the current // congestion window is being fully utilized, an incoming SACK // advances the Cumulative TSN Ack Point, and the data sender is not // in Fast Recovery. Only when these three conditions are met can // the cwnd be increased; otherwise, the cwnd MUST not be increased. // If these conditions are met, then cwnd MUST be increased by, at // most, the lesser of 1) the total size of the previously // outstanding DATA chunk(s) acknowledged, and 2) the destination's // path MTU. 
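// Illustrative numbers: with cwnd=8000 and 12000 newly acked bytes, this
// implementation grows cwnd by min(12000, cwnd)=8000 to 16000, whereas the
// strict RFC 4960 rule would cap the increase at one path MTU.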
if !a.inFastRecovery && a.pendingQueue.size() > 0 { a.setCWND(a.CWND() + min32(uint32(totalBytesAcked), a.CWND())) //nolint:gosec // G115 // a.cwnd += min32(uint32(totalBytesAcked), a.MTU()) // SCTP way (slow) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (SS)", a.name, a.CWND(), a.ssthresh, totalBytesAcked) } else { a.log.Tracef("[%s] cwnd did not grow: cwnd=%d ssthresh=%d acked=%d FR=%v pending=%d", a.name, a.CWND(), a.ssthresh, totalBytesAcked, a.inFastRecovery, a.pendingQueue.size()) } } else { // RFC 4960, sec 7.2.2. Congestion Avoidance // o Whenever cwnd is greater than ssthresh, upon each SACK arrival // that advances the Cumulative TSN Ack Point, increase // partial_bytes_acked by the total number of bytes of all new chunks // acknowledged in that SACK including chunks acknowledged by the new // Cumulative TSN Ack and by Gap Ack Blocks. a.partialBytesAcked += uint32(totalBytesAcked) //nolint:gosec // G115 // o When partial_bytes_acked is equal to or greater than cwnd and // before the arrival of the SACK the sender had cwnd or more bytes // of data outstanding (i.e., before arrival of the SACK, flight size // was greater than or equal to cwnd), increase cwnd by MTU, and // reset partial_bytes_acked to (partial_bytes_acked - cwnd). if a.partialBytesAcked >= a.CWND() && a.pendingQueue.size() > 0 { a.partialBytesAcked -= a.CWND() step := max(a.MTU(), a.cwndCAStep) a.setCWND(a.CWND() + step) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d acked=%d (CA)", a.name, a.CWND(), a.ssthresh, totalBytesAcked) } } } // The caller should hold the lock. // //nolint:cyclop func (a *Association) processFastRetransmission( //nolint:gocognit cumTSNAckPoint uint32, gapAckBlocks []gapAckBlock, htna uint32, cumTSNAckPointAdvanced bool, ) error { // HTNA algorithm - RFC 4960 Sec 7.2.4 // Increment missIndicator of each chunks that the SACK reported missing // when either of the following is met: // a) Not in fast-recovery // miss indications are incremented only for missing TSNs prior to the // highest TSN newly acknowledged in the SACK. // b) In fast-recovery AND the Cumulative TSN Ack Point advanced // the miss indications are incremented for all TSNs reported missing // in the SACK. //nolint:nestif if !a.inFastRecovery || (a.inFastRecovery && cumTSNAckPointAdvanced) { var maxTSN uint32 if !a.inFastRecovery { // a) increment only for missing TSNs prior to the HTNA maxTSN = htna } else { // b) increment for all TSNs reported missing maxTSN = cumTSNAckPoint if len(gapAckBlocks) > 0 { maxTSN += uint32(gapAckBlocks[len(gapAckBlocks)-1].end) } } for tsn := cumTSNAckPoint + 1; sna32LT(tsn, maxTSN); tsn++ { c, ok := a.inflightQueue.get(tsn) if !ok { return fmt.Errorf("%w: %v", ErrTSNRequestNotExist, tsn) } if !c.acked && !c.abandoned() && c.missIndicator < 3 { c.missIndicator++ if c.missIndicator == 3 { if a.tlrActive { a.tlrApplyAdditionalLossLocked(time.Now()) } if !a.inFastRecovery { // 2) If not in Fast Recovery, adjust the ssthresh and cwnd of the // destination address(es) to which the missing DATA chunks were // last sent, according to the formula described in Section 7.2.3. 
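// Illustrative numbers: cwnd=64000 and MTU=1200 give
// ssthresh = max(64000/2, 4*1200) = 32000; cwnd is then cut to ssthresh on
// entering fast recovery and the HTNA becomes the recovery exit point.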
a.inFastRecovery = true a.fastRecoverExitPoint = htna a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) a.setCWND(a.ssthresh) a.partialBytesAcked = 0 a.willRetransmitFast = true a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (FR)", a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) } } } } } if a.inFastRecovery && cumTSNAckPointAdvanced { a.willRetransmitFast = true } return nil } // The caller should hold the lock. // //nolint:cyclop func (a *Association) handleSack(selectiveAckChunk *chunkSelectiveAck) error { a.log.Tracef( "[%s] SACK: cumTSN=%d a_rwnd=%d", a.name, selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.advertisedReceiverWindowCredit, ) state := a.getState() if state != established && state != shutdownPending && state != shutdownReceived { return nil } a.stats.incSACKsReceived() if sna32GT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { // RFC 4960 sec 6.2.1. Processing a Received SACK // D) // i) If Cumulative TSN Ack is less than the Cumulative TSN Ack // Point, then drop the SACK. Since Cumulative TSN Ack is // monotonically increasing, a SACK whose Cumulative TSN Ack is // less than the Cumulative TSN Ack Point indicates an out-of- // order SACK. a.log.Debugf("[%s] SACK Cumulative ACK %v is older than ACK point %v", a.name, selectiveAckChunk.cumulativeTSNAck, a.cumulativeTSNAckPoint) return nil } // Process selective ack bytesAckedPerStream, htna, newestDeliveredSendTime, newestDeliveredOrigTSN, deliveredFound, err := a.processSelectiveAck(selectiveAckChunk) if err != nil { return err } var totalBytesAcked int for _, nBytesAcked := range bytesAckedPerStream { totalBytesAcked += nBytesAcked } cumTSNAckPointAdvanced := false if sna32LT(a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) { a.log.Tracef("[%s] SACK: cumTSN advanced: %d -> %d", a.name, a.cumulativeTSNAckPoint, selectiveAckChunk.cumulativeTSNAck) a.cumulativeTSNAckPoint = selectiveAckChunk.cumulativeTSNAck cumTSNAckPointAdvanced = true a.onCumulativeTSNAckPointAdvanced(totalBytesAcked) } for si, nBytesAcked := range bytesAckedPerStream { if s, ok := a.streams[si]; ok { a.lock.Unlock() s.onBufferReleased(nBytesAcked) a.lock.Lock() } } // New rwnd value // RFC 4960 sec 6.2.1. Processing a Received SACK // D) // ii) Set rwnd equal to the newly received a_rwnd minus the number // of bytes still outstanding after processing the Cumulative // TSN Ack and the Gap Ack Blocks. 
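// Illustrative numbers: a_rwnd=128000 with 16000 bytes still outstanding
// yields rwnd=112000; if outstanding >= a_rwnd, the window is clamped to 0.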
// bytes acked were already subtracted by markAsAcked() method bytesOutstanding := uint32(a.inflightQueue.getNumBytes()) //nolint:gosec // G115 if bytesOutstanding >= selectiveAckChunk.advertisedReceiverWindowCredit { a.setRWND(0) } else { a.setRWND(selectiveAckChunk.advertisedReceiverWindowCredit - bytesOutstanding) } err = a.processFastRetransmission( selectiveAckChunk.cumulativeTSNAck, selectiveAckChunk.gapAckBlocks, htna, cumTSNAckPointAdvanced, ) if err != nil { return err } if a.useForwardTSN { // RFC 3758 Sec 3.5 C1 if sna32LT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { a.advancedPeerTSNAckPoint = a.cumulativeTSNAckPoint } // RFC 3758 Sec 3.5 C2 for i := a.advancedPeerTSNAckPoint + 1; ; i++ { c, ok := a.inflightQueue.get(i) if !ok { break } if !c.abandoned() { break } a.advancedPeerTSNAckPoint = i } // RFC 3758 Sec 3.5 C3 if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { a.willSendForwardTSN = true } a.awakeWriteLoop() } a.postprocessSack(state, cumTSNAckPointAdvanced) // RACK a.onRackAfterSACK(deliveredFound, newestDeliveredSendTime, newestDeliveredOrigTSN, selectiveAckChunk) // adaptive burst mitigation ackProgress := cumTSNAckPointAdvanced || deliveredFound a.tlrMaybeFinishLocked(ackProgress) return nil } // The caller must hold the lock. This method was only added because the // linter was complaining about the "cognitive complexity" of handleSack. func (a *Association) postprocessSack(state uint32, shouldAwakeWriteLoop bool) { switch { case a.inflightQueue.size() > 0: // Start timer. (noop if already started) a.log.Tracef("[%s] T3-rtx timer start (pt3)", a.name) a.t3RTX.start(a.rtoMgr.getRTO()) case state == shutdownPending: // No more outstanding, send shutdown. shouldAwakeWriteLoop = true a.willSendShutdown = true a.setState(shutdownSent) case state == shutdownReceived: // No more outstanding, send shutdown ack. shouldAwakeWriteLoop = true a.willSendShutdownAck = true a.setState(shutdownAckSent) } if shouldAwakeWriteLoop { a.awakeWriteLoop() } } // The caller should hold the lock. func (a *Association) handleShutdown(_ *chunkShutdown) { state := a.getState() switch state { case established: if a.inflightQueue.size() > 0 { a.setState(shutdownReceived) } else { // No more outstanding, send shutdown ack. a.willSendShutdownAck = true a.setState(shutdownAckSent) a.awakeWriteLoop() } // a.cumulativeTSNAckPoint = c.cumulativeTSNAck case shutdownSent: a.willSendShutdownAck = true a.setState(shutdownAckSent) a.awakeWriteLoop() } } // The caller should hold the lock. func (a *Association) handleShutdownAck(_ *chunkShutdownAck) { state := a.getState() if state == shutdownSent || state == shutdownAckSent { a.t2Shutdown.stop() a.willSendShutdownComplete = true a.awakeWriteLoop() } } func (a *Association) handleShutdownComplete(_ *chunkShutdownComplete) error { state := a.getState() if state == shutdownAckSent { a.t2Shutdown.stop() return a.close() } return nil } func (a *Association) handleAbort(c *chunkAbort) error { var errStr string for _, e := range c.errorCauses { errStr += fmt.Sprintf("(%s)", e) } _ = a.close() return fmt.Errorf("[%s] %w: %s", a.name, ErrChunk, errStr) } // createForwardTSN generates ForwardTSN chunk. // This method will be be called if useForwardTSN is set to false. // The caller should hold the lock. 
func (a *Association) createForwardTSN() *chunkForwardTSN { // RFC 3758 Sec 3.5 C4 streamMap := map[uint16]uint16{} // to report only once per SI for i := a.cumulativeTSNAckPoint + 1; sna32LTE(i, a.advancedPeerTSNAckPoint); i++ { c, ok := a.inflightQueue.get(i) if !ok { break } ssn, ok := streamMap[c.streamIdentifier] if !ok { streamMap[c.streamIdentifier] = c.streamSequenceNumber } else if sna16LT(ssn, c.streamSequenceNumber) { // to report only once with greatest SSN streamMap[c.streamIdentifier] = c.streamSequenceNumber } } fwdtsn := &chunkForwardTSN{ newCumulativeTSN: a.advancedPeerTSNAckPoint, streams: []chunkForwardTSNStream{}, } var streamStr string for si, ssn := range streamMap { streamStr += fmt.Sprintf("(si=%d ssn=%d)", si, ssn) fwdtsn.streams = append(fwdtsn.streams, chunkForwardTSNStream{ identifier: si, sequence: ssn, }) } a.log.Tracef( "[%s] building fwdtsn: newCumulativeTSN=%d cumTSN=%d - %s", a.name, fwdtsn.newCumulativeTSN, a.cumulativeTSNAckPoint, streamStr, ) return fwdtsn } // createPacket wraps chunks in a packet. // The caller should hold the read lock. func (a *Association) createPacket(cs []chunk) *packet { return &packet{ verificationTag: a.peerVerificationTag, sourcePort: a.sourcePort, destinationPort: a.destinationPort, chunks: cs, } } // The caller should hold the lock. func (a *Association) handleReconfig(reconfigChunk *chunkReconfig) ([]*packet, error) { a.log.Tracef("[%s] handleReconfig", a.name) pp := make([]*packet, 0) pkt, err := a.handleReconfigParam(reconfigChunk.paramA) if err != nil { return nil, err } if pkt != nil { pp = append(pp, pkt) } if reconfigChunk.paramB != nil { pkt, err = a.handleReconfigParam(reconfigChunk.paramB) if err != nil { return nil, err } if pkt != nil { pp = append(pp, pkt) } } return pp, nil } // The caller should hold the lock. func (a *Association) handleForwardTSN(chunkTSN *chunkForwardTSN) []*packet { a.log.Tracef("[%s] FwdTSN: %s", a.name, chunkTSN.String()) if !a.useForwardTSN { a.log.Warnf("[%s] received FwdTSN but not enabled", a.name) // Return an error chunk cerr := &chunkError{ errorCauses: []errorCause{&errorCauseUnrecognizedChunkType{}}, } outbound := &packet{} outbound.verificationTag = a.peerVerificationTag outbound.sourcePort = a.sourcePort outbound.destinationPort = a.destinationPort outbound.chunks = []chunk{cerr} return []*packet{outbound} } // From RFC 3758 Sec 3.6: // Note, if the "New Cumulative TSN" value carried in the arrived // FORWARD TSN chunk is found to be behind or at the current cumulative // TSN point, the data receiver MUST treat this FORWARD TSN as out-of- // date and MUST NOT update its Cumulative TSN. The receiver SHOULD // send a SACK to its peer (the sender of the FORWARD TSN) since such a // duplicate may indicate the previous SACK was lost in the network. a.log.Tracef("[%s] should send ack? newCumTSN=%d peerLastTSN=%d", a.name, chunkTSN.newCumulativeTSN, a.peerLastTSN()) if sna32LTE(chunkTSN.newCumulativeTSN, a.peerLastTSN()) { a.log.Tracef("[%s] sending ack on Forward TSN", a.name) a.ackState = ackStateImmediate a.ackTimer.stop() a.awakeWriteLoop() return nil } // From RFC 3758 Sec 3.6: // the receiver MUST perform the same TSN handling, including duplicate // detection, gap detection, SACK generation, cumulative TSN // advancement, etc. as defined in RFC 2960 [2]---with the following // exceptions and additions.
// When a FORWARD TSN chunk arrives, the data receiver MUST first update // its cumulative TSN point to the value carried in the FORWARD TSN // chunk, // Advance peerLastTSN for sna32LT(a.peerLastTSN(), chunkTSN.newCumulativeTSN) { a.payloadQueue.pop(true) // may not exist } // Report new peerLastTSN value and abandoned largest SSN value to // corresponding streams so that the abandoned chunks can be removed // from the reassemblyQueue. for _, forwarded := range chunkTSN.streams { if s, ok := a.streams[forwarded.identifier]; ok { s.handleForwardTSNForOrdered(forwarded.sequence) } } // TSN may be forewared for unordered chunks. ForwardTSN chunk does not // report which stream identifier it skipped for unordered chunks. // Therefore, we need to broadcast this event to all existing streams for // unordered chunks. // See https://github.com/pion/sctp/issues/106 for _, s := range a.streams { s.handleForwardTSNForUnordered(chunkTSN.newCumulativeTSN) } return a.handlePeerLastTSNAndAcknowledgement(false) } func (a *Association) sendResetRequest(streamIdentifier uint16) error { a.lock.Lock() defer a.lock.Unlock() state := a.getState() if state != established { return fmt.Errorf("%w: state=%s", ErrResetPacketInStateNotExist, getAssociationStateString(state)) } // Create DATA chunk which only contains valid stream identifier with // nil userData and use it as a EOS from the stream. c := &chunkPayloadData{ streamIdentifier: streamIdentifier, beginningFragment: true, endingFragment: true, userData: nil, } a.pendingQueue.push(c) a.awakeWriteLoop() return nil } // The caller should hold the lock. func (a *Association) handleReconfigParam(raw param) (*packet, error) { switch par := raw.(type) { case *paramOutgoingResetRequest: a.log.Tracef("[%s] handleReconfigParam (OutgoingResetRequest)", a.name) if a.peerLastTSN() < par.senderLastTSN && len(a.reconfigRequests) >= maxReconfigRequests { // We have too many reconfig requests outstanding. Drop the request and let // the peer retransmit. A well behaved peer should only have 1 outstanding // reconfig request. // // RFC 6525: https://www.rfc-editor.org/rfc/rfc6525.html#section-5.1.1 // At any given time, there MUST NOT be more than one request in flight. // So, if the Re-configuration Timer is running and the RE-CONFIG chunk // contains at least one request parameter, the chunk MUST be buffered. // chrome: // https://chromium.googlesource.com/external/webrtc/+/refs/heads/main/net/dcsctp/socket/stream_reset_handler.cc#271 return nil, fmt.Errorf("%w: %d", ErrTooManyReconfigRequests, len(a.reconfigRequests)) } a.reconfigRequests[par.reconfigRequestSequenceNumber] = par resp := a.resetStreamsIfAny(par) if resp != nil { return resp, nil } return nil, nil //nolint:nilnil case *paramReconfigResponse: a.log.Tracef("[%s] handleReconfigParam (ReconfigResponse)", a.name) if par.result == reconfigResultInProgress { // RFC 6525: https://www.rfc-editor.org/rfc/rfc6525.html#section-5.2.7 // // If the Result field indicates "In progress", the timer for the // Re-configuration Request Sequence Number is started again. If // the timer runs out, the RE-CONFIG chunk MUST be retransmitted // but the corresponding error counters MUST NOT be incremented. 
if _, ok := a.reconfigs[par.reconfigResponseSequenceNumber]; ok { a.tReconfig.stop() a.tReconfig.start(a.rtoMgr.getRTO()) } return nil, nil //nolint:nilnil } delete(a.reconfigs, par.reconfigResponseSequenceNumber) if len(a.reconfigs) == 0 { a.tReconfig.stop() } return nil, nil //nolint:nilnil default: return nil, fmt.Errorf("%w: %t", ErrParamterType, par) } } // The caller should hold the lock. func (a *Association) resetStreamsIfAny(resetRequest *paramOutgoingResetRequest) *packet { result := reconfigResultSuccessPerformed if sna32LTE(resetRequest.senderLastTSN, a.peerLastTSN()) { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d <= peerLastTSN=%d", a.name, resetRequest.senderLastTSN, a.peerLastTSN()) for _, id := range resetRequest.streamIdentifiers { s, ok := a.streams[id] if !ok { continue } a.lock.Unlock() s.onInboundStreamReset() a.lock.Lock() a.log.Debugf("[%s] deleting stream %d", a.name, id) delete(a.streams, s.streamIdentifier) } delete(a.reconfigRequests, resetRequest.reconfigRequestSequenceNumber) } else { a.log.Debugf("[%s] resetStream(): senderLastTSN=%d > peerLastTSN=%d", a.name, resetRequest.senderLastTSN, a.peerLastTSN()) result = reconfigResultInProgress } return a.createPacket([]chunk{&chunkReconfig{ paramA: ¶mReconfigResponse{ reconfigResponseSequenceNumber: resetRequest.reconfigRequestSequenceNumber, result: result, }, }}) } // Move the chunk peeked with a.pendingQueue.peek() to the inflightQueue. // The caller should hold the lock. func (a *Association) movePendingDataChunkToInflightQueue(chunkPayload *chunkPayloadData) { if err := a.pendingQueue.pop(chunkPayload); err != nil { a.log.Errorf("[%s] failed to pop from pending queue: %s", a.name, err.Error()) } if chunkPayload.endingFragment { chunkPayload.setAllInflight() } // Assign TSN and original send time chunkPayload.tsn = a.generateNextTSN() chunkPayload.since = time.Now() chunkPayload.nSent = 1 a.checkPartialReliabilityStatus(chunkPayload) a.log.Tracef( "[%s] sending ppi=%d tsn=%d ssn=%d sent=%d len=%d (%v,%v)", a.name, chunkPayload.payloadType, chunkPayload.tsn, chunkPayload.streamSequenceNumber, chunkPayload.nSent, len(chunkPayload.userData), chunkPayload.beginningFragment, chunkPayload.endingFragment, ) a.inflightQueue.pushNoCheck(chunkPayload) // RACK: track outstanding original transmissions by send time. a.rackInsert(chunkPayload) } // popPendingDataChunksToSend pops chunks from the pending queues as many as // the cwnd and rwnd allows to send. // The caller should hold the lock. // //nolint:cyclop func (a *Association) popPendingDataChunksToSend( //nolint:cyclop,gocognit budgetScaled *int64, consumed *bool, ) ([]*chunkPayloadData, []uint16) { chunks := []*chunkPayloadData{} var sisToReset []uint16 // stream indentifiers to reset // track current packet size for MTU bundling so budgeting is accurate. bytesInPacket := 0 if a.pendingQueue.size() > 0 { //nolint:nestif // RFC 4960 sec 6.1. Transmission of DATA Chunks // A) At any given time, the data sender MUST NOT transmit new data to // any destination transport address if its peer's rwnd indicates // that the peer has no buffer space (i.e., rwnd is 0; see Section // 6.2.1). However, regardless of the value of rwnd (including if it // is 0), the data sender can always have one DATA chunk in flight to // the receiver if allowed by cwnd (see rule B, below). 
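// Loop sketch: peek pending chunks in order, stop once cwnd or rwnd would be
// exceeded, and account for MTU bundling (the common header is charged once
// per packet) plus the burst budget before committing each chunk to the
// inflight queue; a zero window probe is allowed when nothing is in flight.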
for { chunkPayload := a.pendingQueue.peek() if chunkPayload == nil { break // no more pending data } dataLen := uint32(len(chunkPayload.userData)) //nolint:gosec // G115 if dataLen == 0 { sisToReset = append(sisToReset, chunkPayload.streamIdentifier) err := a.pendingQueue.pop(chunkPayload) if err != nil { a.log.Errorf("failed to pop from pending queue: %s", err.Error()) } continue } if uint32(a.inflightQueue.getNumBytes())+dataLen > a.CWND() { //nolint:gosec // G115 break // would exceeds cwnd } if dataLen > a.RWND() { break // no more rwnd } // compute current DATA chunk size including padding. chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) chunkBytes += getPadding(chunkBytes) // ensure MTU bundling matches bundleDataChunksIntoPackets(). addBytes := chunkBytes if bytesInPacket == 0 { addBytes += int(commonHeaderSize) if addBytes > int(a.MTU()) { break } // reserve budget for common header + first chunk. if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { break } bytesInPacket = int(commonHeaderSize) } else { // if it doesn't fit, start a new packet and retry same chunk. if bytesInPacket+chunkBytes > int(a.MTU()) { bytesInPacket = 0 continue } // reserve budget for the additional chunk bytes. if !a.tlrAllowSendLocked(budgetScaled, consumed, chunkBytes) { break } } a.setRWND(a.RWND() - dataLen) a.movePendingDataChunkToInflightQueue(chunkPayload) chunks = append(chunks, chunkPayload) bytesInPacket += chunkBytes } // allow one DATA chunk if nothing is inflight to the receiver. if len(chunks) == 0 && a.inflightQueue.size() == 0 { // Send zero window probe c := a.pendingQueue.peek() if c != nil && len(c.userData) > 0 { // probe is a new packet: common header + chunk bytes. chunkBytes := int(dataChunkHeaderSize) + len(c.userData) chunkBytes += getPadding(chunkBytes) addBytes := int(commonHeaderSize) + chunkBytes if addBytes <= int(a.MTU()) && a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { a.movePendingDataChunkToInflightQueue(c) chunks = append(chunks, c) } } } } if a.blockWrite && len(chunks) > 0 && a.pendingQueue.size() == 0 { a.log.Tracef("[%s] all pending data have been sent, notify writable", a.name) a.notifyBlockWritable() } return chunks, sisToReset } // bundleDataChunksIntoPackets packs DATA chunks into packets. It tries to bundle // DATA chunks into a packet so long as the resulting packet size does not exceed // the path MTU. // The caller should hold the lock. func (a *Association) bundleDataChunksIntoPackets(chunks []*chunkPayloadData) []*packet { packets := []*packet{} chunksToSend := []chunk{} bytesInPacket := int(commonHeaderSize) for _, chunkPayload := range chunks { // RFC 4960 sec 6.1. Transmission of DATA Chunks // Multiple DATA chunks committed for transmission MAY be bundled in a // single packet. Furthermore, DATA chunks being retransmitted MAY be // bundled with new DATA chunks, as long as the resulting packet size // does not exceed the path MTU. chunkSizeInPacket := int(dataChunkHeaderSize) + len(chunkPayload.userData) chunkSizeInPacket += getPadding(chunkSizeInPacket) if bytesInPacket+chunkSizeInPacket > int(a.MTU()) { packets = append(packets, a.createPacket(chunksToSend)) chunksToSend = []chunk{} bytesInPacket = int(commonHeaderSize) } chunksToSend = append(chunksToSend, chunkPayload) bytesInPacket += chunkSizeInPacket } if len(chunksToSend) > 0 { packets = append(packets, a.createPacket(chunksToSend)) } return packets } // sendPayloadData sends the data chunks. 
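// With blockWrite enabled, it waits until the previous write has been
// flushed (writePending cleared) or ctx is done before queueing more data.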
func (a *Association) sendPayloadData(ctx context.Context, chunks []*chunkPayloadData) error { a.lock.Lock() state := a.getState() if state != established { a.lock.Unlock() return fmt.Errorf("%w: state=%s", ErrPayloadDataStateNotExist, getAssociationStateString(state)) } if a.blockWrite { for a.writePending { a.lock.Unlock() select { case <-ctx.Done(): return ctx.Err() case <-a.writeNotify: a.lock.Lock() } } a.writePending = true } // Push the chunks into the pending queue first. for _, c := range chunks { a.pendingQueue.push(c) } a.lock.Unlock() a.awakeWriteLoop() return nil } // The caller should hold the lock. func (a *Association) checkPartialReliabilityStatus(chunkPayload *chunkPayloadData) { if !a.useForwardTSN { return } // draft-ietf-rtcweb-data-protocol-09.txt section 6 // 6. Procedures // All Data Channel Establishment Protocol messages MUST be sent using // ordered delivery and reliable transmission. // if chunkPayload.payloadType == PayloadTypeWebRTCDCEP { return } // PR-SCTP if stream, ok := a.streams[chunkPayload.streamIdentifier]; ok { //nolint:nestif stream.lock.RLock() if stream.reliabilityType == ReliabilityTypeRexmit { if chunkPayload.nSent >= stream.reliabilityValue { chunkPayload.setAbandoned(true) a.rackRemove(chunkPayload) a.log.Tracef( "[%s] marked as abandoned: tsn=%d ppi=%d (remix: %d)", a.name, chunkPayload.tsn, chunkPayload.payloadType, chunkPayload.nSent, ) } } else if stream.reliabilityType == ReliabilityTypeTimed { elapsed := int64(time.Since(chunkPayload.since).Seconds() * 1000) if elapsed >= int64(stream.reliabilityValue) { chunkPayload.setAbandoned(true) a.rackRemove(chunkPayload) a.log.Tracef( "[%s] marked as abandoned: tsn=%d ppi=%d (timed: %d)", a.name, chunkPayload.tsn, chunkPayload.payloadType, elapsed, ) } } stream.lock.RUnlock() } else { // Remote has reset its send side of the stream, we can still send data. a.log.Tracef("[%s] stream %d not found, remote reset", a.name, chunkPayload.streamIdentifier) } } // getDataPacketsToRetransmit is called when T3-rtx is timed out and retransmit outstanding data chunks // that are not acked or abandoned yet. // The caller should hold the lock. func (a *Association) getDataPacketsToRetransmit(budgetScaled *int64, consumed *bool) []*packet { //nolint:cyclop awnd := min32(a.CWND(), a.RWND()) chunks := []*chunkPayloadData{} var bytesToSend int currRtxTimestamp := time.Now() bytesInPacket := 0 for i := 0; ; i++ { chunkPayload, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + uint32(i) + 1) //nolint:gosec // G115 if !ok { break // end of pending data } if !chunkPayload.retransmit { continue } if i == 0 && int(a.RWND()) < len(chunkPayload.userData) { // allow as zero window probe } else if bytesToSend+len(chunkPayload.userData) > int(awnd) { break } chunkBytes := int(dataChunkHeaderSize) + len(chunkPayload.userData) chunkBytes += getPadding(chunkBytes) // retry as first chunk in a new packet if needed. for { addBytes := chunkBytes if bytesInPacket == 0 { addBytes += int(commonHeaderSize) if addBytes > int(a.MTU()) { return a.bundleDataChunksIntoPackets(chunks) } } else if bytesInPacket+chunkBytes > int(a.MTU()) { bytesInPacket = 0 continue } // burst budget gate before mutating the chunk. 
if !a.tlrAllowSendLocked(budgetScaled, consumed, addBytes) { return a.bundleDataChunksIntoPackets(chunks) } if bytesInPacket == 0 { bytesInPacket = int(commonHeaderSize) } bytesInPacket += chunkBytes break } chunkPayload.retransmit = false bytesToSend += len(chunkPayload.userData) // Update for retransmission chunkPayload.nSent++ chunkPayload.since = currRtxTimestamp a.rackRemove(chunkPayload) a.rackInsert(chunkPayload) a.checkPartialReliabilityStatus(chunkPayload) a.log.Tracef( "[%s] retransmitting tsn=%d ssn=%d sent=%d", a.name, chunkPayload.tsn, chunkPayload.streamSequenceNumber, chunkPayload.nSent, ) chunks = append(chunks, chunkPayload) } return a.bundleDataChunksIntoPackets(chunks) } // generateNextTSN returns the myNextTSN and increases it. The caller should hold the lock. // The caller should hold the lock. func (a *Association) generateNextTSN() uint32 { tsn := a.myNextTSN a.myNextTSN++ return tsn } // generateNextRSN returns the myNextRSN and increases it. The caller should hold the lock. // The caller should hold the lock. func (a *Association) generateNextRSN() uint32 { rsn := a.myNextRSN a.myNextRSN++ return rsn } func (a *Association) createSelectiveAckChunk() *chunkSelectiveAck { sack := &chunkSelectiveAck{} sack.cumulativeTSNAck = a.peerLastTSN() sack.advertisedReceiverWindowCredit = a.getMyReceiverWindowCredit() sack.duplicateTSN = a.payloadQueue.popDuplicates() sack.gapAckBlocks = a.payloadQueue.getGapAckBlocks() return sack } func pack(p *packet) []*packet { return []*packet{p} } func (a *Association) handleChunksStart() { a.lock.Lock() defer a.lock.Unlock() a.stats.incPacketsReceived() a.delayedAckTriggered = false a.immediateAckTriggered = false } func (a *Association) handleChunksEnd() { a.lock.Lock() defer a.lock.Unlock() if a.immediateAckTriggered { a.ackState = ackStateImmediate a.ackTimer.stop() a.awakeWriteLoop() } else if a.delayedAckTriggered { // Will send delayed ack in the next ack timeout a.ackState = ackStateDelay a.ackTimer.start() } } func (a *Association) handleChunk(receivedPacket *packet, receivedChunk chunk) error { //nolint:cyclop a.lock.Lock() defer a.lock.Unlock() var packets []*packet var err error if _, err = receivedChunk.check(); err != nil { a.log.Errorf("[%s] failed validating chunk: %s ", a.name, err) return nil } isAbort := false switch receivedChunk := receivedChunk.(type) { // Note: We do not do the following for chunkInit, chunkInitAck, and chunkCookieEcho: // If an endpoint receives an INIT, INIT ACK, or COOKIE ECHO chunk but decides not to establish the // new association due to missing mandatory parameters in the received INIT or INIT ACK chunk, invalid // parameter values, or lack of local resources, it SHOULD respond with an ABORT chunk. 
case *chunkInit: packets, err = a.handleInit(receivedPacket, receivedChunk) case *chunkInitAck: err = a.handleInitAck(receivedPacket, receivedChunk) case *chunkAbort: isAbort = true err = a.handleAbort(receivedChunk) case *chunkError: var errStr string for _, e := range receivedChunk.errorCauses { errStr += fmt.Sprintf("(%s)", e) } a.log.Debugf("[%s] Error chunk, with following errors: %s", a.name, errStr) case *chunkHeartbeat: packets = a.handleHeartbeat(receivedChunk) case *chunkHeartbeatAck: a.handleHeartbeatAck(receivedChunk) case *chunkCookieEcho: packets = a.handleCookieEcho(receivedChunk) case *chunkCookieAck: a.handleCookieAck() case *chunkPayloadData: packets = a.handleData(receivedChunk) case *chunkSelectiveAck: err = a.handleSack(receivedChunk) case *chunkReconfig: packets, err = a.handleReconfig(receivedChunk) case *chunkForwardTSN: packets = a.handleForwardTSN(receivedChunk) case *chunkShutdown: a.handleShutdown(receivedChunk) case *chunkShutdownAck: a.handleShutdownAck(receivedChunk) case *chunkShutdownComplete: err = a.handleShutdownComplete(receivedChunk) default: err = ErrChunkTypeUnhandled } // Log and return, the only condition that is fatal is a ABORT chunk if err != nil { if isAbort { return err } a.log.Errorf("Failed to handle chunk: %v", err) return nil } if len(packets) > 0 { a.controlQueue.pushAll(packets) a.awakeWriteLoop() } return nil } func (a *Association) onRetransmissionTimeout(id int, nRtos uint) { //nolint:cyclop a.lock.Lock() defer a.lock.Unlock() if id == timerT1Init { err := a.sendInit() if err != nil { a.log.Debugf("[%s] failed to retransmit init (nRtos=%d): %v", a.name, nRtos, err) } return } if id == timerT1Cookie { err := a.sendCookieEcho() if err != nil { a.log.Debugf("[%s] failed to retransmit cookie-echo (nRtos=%d): %v", a.name, nRtos, err) } return } if id == timerT2Shutdown { a.log.Debugf("[%s] retransmission of shutdown timeout (nRtos=%d): %v", a.name, nRtos) state := a.getState() switch state { case shutdownSent: a.willSendShutdown = true a.awakeWriteLoop() case shutdownAckSent: a.willSendShutdownAck = true a.awakeWriteLoop() } } if id == timerT3RTX { //nolint:nestif a.stats.incT3Timeouts() // RFC 4960 sec 6.3.3 // E1) For the destination address for which the timer expires, adjust // its ssthresh with rules defined in Section 7.2.3 and set the // cwnd <- MTU. // RFC 4960 sec 7.2.3 // When the T3-rtx timer expires on an address, SCTP should perform slow // start by: // ssthresh = max(cwnd/2, 4*MTU) // cwnd = 1*MTU a.ssthresh = max32(a.CWND()/2, 4*a.MTU()) a.setCWND(a.MTU()) a.log.Tracef("[%s] updated cwnd=%d ssthresh=%d inflight=%d (RTO)", a.name, a.CWND(), a.ssthresh, a.inflightQueue.getNumBytes()) // RFC 3758 sec 3.5 // A5) Any time the T3-rtx timer expires, on any destination, the sender // SHOULD try to advance the "Advanced.Peer.Ack.Point" by following // the procedures outlined in C2 - C5. 
if a.useForwardTSN { // RFC 3758 Sec 3.5 C2 for i := a.advancedPeerTSNAckPoint + 1; ; i++ { c, ok := a.inflightQueue.get(i) if !ok { break } if !c.abandoned() { break } a.advancedPeerTSNAckPoint = i } // RFC 3758 Sec 3.5 C3 if sna32GT(a.advancedPeerTSNAckPoint, a.cumulativeTSNAckPoint) { a.willSendForwardTSN = true } } a.log.Debugf("[%s] T3-rtx timed out: nRtos=%d cwnd=%d ssthresh=%d", a.name, nRtos, a.CWND(), a.ssthresh) /* a.log.Debugf(" - advancedPeerTSNAckPoint=%d", a.advancedPeerTSNAckPoint) a.log.Debugf(" - cumulativeTSNAckPoint=%d", a.cumulativeTSNAckPoint) a.inflightQueue.updateSortedKeys() for i, tsn := range a.inflightQueue.sorted { if c, ok := a.inflightQueue.get(tsn); ok { a.log.Debugf(" - [%d] tsn=%d acked=%v abandoned=%v (%v,%v) len=%d", i, c.tsn, c.acked, c.abandoned(), c.beginningFragment, c.endingFragment, len(c.userData)) } } */ a.inflightQueue.markAllToRetrasmit() a.awakeWriteLoop() return } if id == timerReconfig { a.willRetransmitReconfig = true a.awakeWriteLoop() } } func (a *Association) onRetransmissionFailure(id int) { a.lock.Lock() defer a.lock.Unlock() if id == timerT1Init { a.log.Errorf("[%s] retransmission failure: T1-init", a.name) a.completeHandshake(ErrHandshakeInitAck) return } if id == timerT1Cookie { a.log.Errorf("[%s] retransmission failure: T1-cookie", a.name) a.completeHandshake(ErrHandshakeCookieEcho) return } if id == timerT2Shutdown { a.log.Errorf("[%s] retransmission failure: T2-shutdown", a.name) return } if id == timerT3RTX { // T3-rtx timer will not fail by design // Justifications: // * ICE would fail if the connectivity is lost // * WebRTC spec is not clear how this incident should be reported to ULP a.log.Errorf("[%s] retransmission failure: T3-rtx (DATA)", a.name) return } } func (a *Association) onAckTimeout() { a.lock.Lock() defer a.lock.Unlock() a.log.Tracef("[%s] ack timed out (ackState: %d)", a.name, a.ackState) a.stats.incAckTimeouts() a.ackState = ackStateImmediate a.awakeWriteLoop() } // BufferedAmount returns total amount (in bytes) of currently buffered user data. func (a *Association) BufferedAmount() int { a.lock.RLock() defer a.lock.RUnlock() return a.pendingQueue.getNumBytes() + a.inflightQueue.getNumBytes() } // MaxMessageSize returns the maximum message size you can send. func (a *Association) MaxMessageSize() uint32 { return atomic.LoadUint32(&a.maxMessageSize) } // SetMaxMessageSize sets the maximum message size you can send. func (a *Association) SetMaxMessageSize(maxMsgSize uint32) { atomic.StoreUint32(&a.maxMessageSize, maxMsgSize) } // completeHandshake sends the given error to handshakeCompletedCh unless the read/write // side of the association closes before that can happen. It returns whether it was able // to send on the channel or not. func (a *Association) completeHandshake(handshakeErr error) bool { select { // Note: This is a future place where the user could be notified (COMMUNICATION UP) case a.handshakeCompletedCh <- handshakeErr: return true case <-a.closeWriteLoopCh: // check the read/write sides for closure case <-a.readLoopCloseCh: } return false } func (a *Association) pokeTimerLoop() { // enqueue a single wake-up without blocking. 
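	// Non-blocking send: if timerUpdateCh cannot accept the signal right away, the
	// default case simply drops this poke. That is safe because the timer loop
	// re-reads both deadlines on every iteration, so back-to-back pokes coalesce.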
select { case a.timerUpdateCh <- struct{}{}: default: } } func (a *Association) startRackTimer(dur time.Duration) { a.timerMu.Lock() if dur <= 0 { a.rackDeadline = time.Time{} } else { a.rackDeadline = time.Now().Add(dur) } a.timerMu.Unlock() a.pokeTimerLoop() } func (a *Association) stopRackTimer() { a.timerMu.Lock() a.rackDeadline = time.Time{} a.timerMu.Unlock() a.pokeTimerLoop() } func (a *Association) startPTOTimer(dur time.Duration) { a.timerMu.Lock() if dur <= 0 { a.ptoDeadline = time.Time{} } else { a.ptoDeadline = time.Now().Add(dur) } a.timerMu.Unlock() a.pokeTimerLoop() } func (a *Association) stopPTOTimer() { a.timerMu.Lock() a.ptoDeadline = time.Time{} a.timerMu.Unlock() a.pokeTimerLoop() } // drainTimer safely stops a timer and drains its channel if needed. func drainTimer(t *time.Timer) { if !t.Stop() { select { case <-t.C: default: } } } // timerLoop runs one goroutine per association for RACK and PTO deadlines. // this only runs if RACK is enabled. func (a *Association) timerLoop() { //nolint:gocognit,cyclop // begin with a disarmed timer. timer := time.NewTimer(time.Hour) drainTimer(timer) armed := false for { // compute the earliest non-zero deadline. a.timerMu.Lock() rackDeadline := a.rackDeadline ptoDeadline := a.ptoDeadline a.timerMu.Unlock() var next time.Time switch { case rackDeadline.IsZero(): next = ptoDeadline case ptoDeadline.IsZero(): next = rackDeadline default: if rackDeadline.Before(ptoDeadline) { next = rackDeadline } else { next = ptoDeadline } } if next.IsZero() { if armed { drainTimer(timer) armed = false } } else { d := time.Until(next) if d <= 0 { d = time.Nanosecond } if armed { drainTimer(timer) } timer.Reset(d) armed = true } select { case <-a.closeWriteLoopCh: if armed { drainTimer(timer) } return case <-a.timerUpdateCh: // re-compute deadlines and (re)arm in next loop iteration. case <-timer.C: armed = false // snapshot & clear due deadlines before firing to avoid races with re-arms. currTime := time.Now() var fireRack, firePTO bool a.timerMu.Lock() if !a.rackDeadline.IsZero() && !currTime.Before(a.rackDeadline) { fireRack = true a.rackDeadline = time.Time{} } if !a.ptoDeadline.IsZero() && !currTime.Before(a.ptoDeadline) { firePTO = true a.ptoDeadline = time.Time{} } a.timerMu.Unlock() // fire callbacks without holding timerMu. if fireRack { a.onRackTimeout() } if firePTO { a.onPTOTimer() } } } } // onRackAfterSACK implements the RACK logic (RACK for SCTP section 2A/B, section 3) and TLP scheduling (section 2C). 
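// In outline, mirroring the numbered steps in the body: (1) raise the highest delivered
// original TSN and note reordering when an older original TSN is newly acked, (2) maintain
// the reordering window reoWnd from the minimum RTT, duplicate-TSN reports and the SRTT
// bound, (3) mark outstanding chunks sent more than reoWnd before the newest delivered
// send time as lost, (4) re-arm the RACK timer for chunks that are not yet overdue, and
// (5) (re)schedule the tail-loss probe (PTO) while data remains in flight.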
func (a *Association) onRackAfterSACK( // nolint:gocognit,cyclop,gocyclo deliveredFound bool, newestDeliveredSendTime time.Time, newestDeliveredOrigTSN uint32, sack *chunkSelectiveAck, ) { // store the current time for when we check if it's needed in step 2 (whether we should maintain ReoWND) currTime := time.Now() // 1) Update highest delivered original TSN for reordering detection (section 2B) if deliveredFound { if sna32LT(a.rackHighestDeliveredOrigTSN, newestDeliveredOrigTSN) { a.rackHighestDeliveredOrigTSN = newestDeliveredOrigTSN } else { // ACK of an original TSN below the high-watermark -> reordering observed a.rackReorderingSeen = true } if newestDeliveredSendTime.After(a.rackDeliveredTime) { a.rackDeliveredTime = newestDeliveredSendTime } } // 2) Maintain ReoWND (RACK for SCTP section 2B) if minRTT := a.rackMinRTTWnd.Min(currTime); minRTT > 0 { a.rackMinRTT = minRTT } var base time.Duration if a.rackMinRTT > 0 { base = max(a.rackMinRTT/4, a.rackReoWndFloor) } // Suppress during recovery if no reordering ever seen; else (re)initialize from base if zero. if !a.rackReorderingSeen && (a.inFastRecovery || a.t3RTX.isRunning()) { a.rackReoWnd = 0 } else if a.rackReoWnd == 0 && base > 0 { a.rackReoWnd = base } // DSACK-style inflation using SCTP duplicate TSNs (RACK for SCTP section 3 noting SCTP // natively reports duplicates + RACK for SCTP section 2B policy) if len(sack.duplicateTSN) > 0 && a.rackMinRTT > 0 { a.rackReoWnd += max(a.rackMinRTT/4, a.rackReoWndFloor) // keep inflated for 16 loss recoveries before reset a.rackKeepInflatedRecoveries = 16 a.log.Tracef("[%s] RACK: DSACK/dupTSN seen, inflate reoWnd to %v", a.name, a.rackReoWnd) } // decrement the keep inflated counter when we leave recovery if !a.inFastRecovery && a.rackKeepInflatedRecoveries > 0 { a.rackKeepInflatedRecoveries-- if a.rackKeepInflatedRecoveries == 0 && a.rackMinRTT > 0 { a.rackReoWnd = a.rackMinRTT / 4 } } // RFC 8985: the reordering window MUST be bounded by SRTT. if srttMs := a.SRTT(); srttMs > 0 { if srttDur := time.Duration(srttMs * 1e6); a.rackReoWnd > srttDur { a.rackReoWnd = srttDur } } // 3) Loss marking on ACK: any outstanding chunk whose (send_time + reoWnd) < newestDeliveredSendTime // is lost (RACK for SCTP section 2A) if !a.rackDeliveredTime.IsZero() { //nolint:nestif marked := false for chunk := a.rackHead; chunk != nil; { next := chunk.rackNext // save in case we remove c // but clean up if they exist. if chunk.acked || chunk.abandoned() { a.rackRemove(chunk) chunk = next continue } if chunk.retransmit || chunk.nSent > 1 { // Either already scheduled for retransmit or not an original send: // skip but keep in list in case it's still outstanding. chunk = next continue } // Ordered by original send time. If this one is too new, // all later ones are even newer -> short-circuit. if !chunk.since.Add(a.rackReoWnd).Before(a.rackDeliveredTime) { break } // Mark as lost by RACK. chunk.retransmit = true marked = true // Remove from xmit-time list: we no longer need RACK for this TSN. 
a.rackRemove(chunk) a.log.Tracef("[%s] RACK: mark lost tsn=%d (sent=%v, delivered=%v, reoWnd=%v)", a.name, chunk.tsn, chunk.since, a.rackDeliveredTime, a.rackReoWnd) chunk = next } if marked { // loss detected during active TLR so we must reduce burst if a.tlrActive { a.tlrApplyAdditionalLossLocked(currTime) } a.awakeWriteLoop() } } // 4) Arm the RACK timer if there are still outstanding but not-yet-overdue chunks (RACK for SCTP section 2A) if a.rackHead != nil && !a.rackDeliveredTime.IsZero() { // RackRTT = RTT of the most recently delivered packet rackRTT := max(time.Since(a.rackDeliveredTime), time.Duration(0)) a.startRackTimer(rackRTT + a.rackReoWnd) // RACK for SCTP section 2A } else { a.stopRackTimer() } // 5) Re/schedule Tail Loss Probe (PTO) (RACK for SCTP section 2C) // Triggered when new data is sent or cum-ack advances; we approximate by scheduling on every SACK that advanced if a.inflightQueue.size() == 0 { a.stopPTOTimer() return } var pto time.Duration srttMs := a.SRTT() if srttMs > 0 { srtt := time.Duration(srttMs * 1e6) extra := 2 * time.Millisecond if a.inflightQueue.size() == 1 { extra = a.rackWCDelAck // 200ms for single outstanding, else 2ms } pto = 2*srtt + extra } else { pto = time.Second // no RTT yet } a.startPTOTimer(pto) } // schedulePTOAfterSendLocked starts/restarts the PTO timer when new data is transmitted. // Caller must hold a.lock. func (a *Association) schedulePTOAfterSendLocked() { if a.inflightQueue.size() == 0 { a.stopPTOTimer() return } var pto time.Duration if srttMs := a.SRTT(); srttMs > 0 { srtt := time.Duration(srttMs * 1e6) extra := 2 * time.Millisecond if a.inflightQueue.size() == 1 { extra = a.rackWCDelAck } pto = 2*srtt + extra } else { pto = time.Second } a.startPTOTimer(pto) } // onRackTimeout is fired to avoid waiting for the next ACK. func (a *Association) onRackTimeout() { a.lock.Lock() defer a.lock.Unlock() a.onRackTimeoutLocked() } func (a *Association) onRackTimeoutLocked() { //nolint:cyclop if a.rackDeliveredTime.IsZero() { return } marked := false for chunk := a.rackHead; chunk != nil; { next := chunk.rackNext if chunk.acked || chunk.abandoned() { a.rackRemove(chunk) chunk = next continue } if chunk.retransmit || chunk.nSent > 1 { chunk = next continue } if !chunk.since.Add(a.rackReoWnd).Before(a.rackDeliveredTime) { // too new, later ones are newer so we can skip. break } chunk.retransmit = true marked = true a.rackRemove(chunk) a.log.Tracef("[%s] RACK timer: mark lost tsn=%d", a.name, chunk.tsn) chunk = next } if marked { // loss detected during active TLR so we must reduce burst if a.tlrActive { a.tlrApplyAdditionalLossLocked(time.Now()) } a.awakeWriteLoop() } } func (a *Association) onPTOTimer() { a.lock.Lock() defer a.lock.Unlock() a.onPTOTimerLocked() } func (a *Association) onPTOTimerLocked() { // if nothing is inflight, PTO should not drive TLR. // use PTO as a chance to probe RTT via HEARTBEAT instead of retransmitting DATA. if a.inflightQueue.size() == 0 { a.stopPTOTimer() a.log.Tracef("[%s] PTO idle: sending active HEARTBEAT for RTT probe", a.name) a.sendActiveHeartbeatLocked() return } currTime := time.Now() if !a.tlrActive { a.tlrBeginLocked() } else { a.tlrApplyAdditionalLossLocked(currTime) } // If we have unsent data, PTO should just wake the writer. if a.pendingQueue.size() > 0 { a.awakeWriteLoop() return } // otherwise retransmit most recently sent in-flight DATA. 
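	// Scan the in-flight queue upward from the cumulative TSN ack point and keep the
	// last chunk that is neither acked nor abandoned, i.e. the highest outstanding TSN
	// (the most recently sent DATA). That chunk is marked for retransmission below and
	// serves as the tail-loss probe.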
var latest *chunkPayloadData for i := uint32(0); ; i++ { c, ok := a.inflightQueue.get(a.cumulativeTSNAckPoint + i + 1) if !ok { break } if c.acked || c.abandoned() { continue } latest = c } if latest != nil && !latest.retransmit { latest.retransmit = true a.log.Tracef("[%s] PTO fired: probe tsn=%d", a.name, latest.tsn) a.awakeWriteLoop() } } func (a *Association) rackInsert(c *chunkPayloadData) { if c == nil || c.rackInList { return } if a.rackTail != nil { a.rackTail.rackNext = c c.rackPrev = a.rackTail } else { a.rackHead = c } a.rackTail = c c.rackInList = true } func (a *Association) rackRemove(chunk *chunkPayloadData) { if chunk == nil || !chunk.rackInList { return } if prev := chunk.rackPrev; prev != nil { prev.rackNext = chunk.rackNext } else { a.rackHead = chunk.rackNext } if next := chunk.rackNext; next != nil { next.rackPrev = chunk.rackPrev } else { a.rackTail = chunk.rackPrev } chunk.rackPrev = nil chunk.rackNext = nil chunk.rackInList = false } // caller must hold a.lock. func (a *Association) tlrFirstRTTDurationLocked() time.Duration { // Use SRTT when available; fall back to a safe default. if srttMs := a.SRTT(); srttMs > 0 { return time.Duration(srttMs * 1e6) } return time.Second } // caller must hold a.lock. func (a *Association) tlrUpdatePhaseLocked(currTime time.Time) { if !a.tlrActive || !a.tlrFirstRTT { return } if a.tlrStartTime.IsZero() { return } if currTime.Sub(a.tlrStartTime) >= a.tlrFirstRTTDurationLocked() { a.tlrFirstRTT = false } } // caller must hold a.lock. func (a *Association) tlrCurrentBurstUnitsLocked() int64 { if !a.tlrActive { return 0 } a.tlrUpdatePhaseLocked(time.Now()) if a.tlrFirstRTT { return a.tlrBurstFirstRTTUnits } return a.tlrBurstLaterRTTUnits } // caller must hold a.lock. // Returns remaining burst budget in "scaled bytes": bytes * 4 (quarter-MTU precision). func (a *Association) tlrCurrentBurstBudgetScaledLocked() int64 { if !a.tlrActive { return 0 } units := a.tlrCurrentBurstUnitsLocked() return units * int64(a.MTU()) } // caller must hold a.lock. func (a *Association) tlrHighestOutstandingTSNLocked() (uint32, bool) { var last uint32 found := false for i := uint32(0); ; i++ { tsn := a.cumulativeTSNAckPoint + i + 1 _, ok := a.inflightQueue.get(tsn) if !ok { break } last = tsn found = true } return last, found } // caller must hold a.lock. func (a *Association) tlrBeginLocked() { currTime := time.Now() a.tlrActive = true a.tlrFirstRTT = true a.tlrHadAdditionalLoss = false a.tlrStartTime = currTime if endTSN, ok := a.tlrHighestOutstandingTSNLocked(); ok { a.tlrEndTSN = endTSN } else { a.tlrEndTSN = a.cumulativeTSNAckPoint } } // caller must hold a.lock. func (a *Association) tlrApplyAdditionalLossLocked(currTime time.Time) { if !a.tlrActive { return } // Decide whether we're still within the first recovery RTT window. a.tlrUpdatePhaseLocked(currTime) a.tlrHadAdditionalLoss = true a.tlrGoodOps = 0 if a.tlrFirstRTT { // Loss during first recovery RTT => initial burst too high. a.tlrBurstFirstRTTUnits -= tlrBurstStepDownFirstRTT if a.tlrBurstFirstRTTUnits < tlrBurstMinFirstRTT { a.tlrBurstFirstRTTUnits = tlrBurstMinFirstRTT } } else { // Loss during later RTTs => increasing rate too high. a.tlrBurstLaterRTTUnits -= tlrBurstStepDownLaterRTT if a.tlrBurstLaterRTTUnits < tlrBurstMinLaterRTT { a.tlrBurstLaterRTTUnits = tlrBurstMinLaterRTT } } } // caller must hold a.lock. func (a *Association) tlrMaybeFinishLocked(ackProgress bool) { if !a.tlrActive { return } // determine if we should move from the first RTT burst to later RTT burst. 
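	// Any cumulative-ack progress ends the first-RTT phase. Recovery finishes once the
	// cumulative TSN ack point reaches the tail (tlrEndTSN) captured when TLR began;
	// clean recoveries are counted and, after tlrGoodOpsResetThreshold of them in a row,
	// the burst units are restored to their defaults.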
	if a.tlrFirstRTT && ackProgress {
		a.tlrFirstRTT = false
	}
	// finish once cumulatively ACKed through the tail we were recovering.
	if sna32GTE(a.cumulativeTSNAckPoint, a.tlrEndTSN) {
		if !a.tlrHadAdditionalLoss {
			a.tlrGoodOps++
			if a.tlrGoodOps >= tlrGoodOpsResetThreshold {
				a.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT
				a.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT
				a.tlrGoodOps = 0
			}
		} else {
			a.tlrGoodOps = 0
		}
		a.tlrActive = false
		a.tlrFirstRTT = false
		a.tlrHadAdditionalLoss = false
		a.tlrEndTSN = 0
	}
}

// caller must hold a.lock.
// "budgetScaled" is remaining burst budget in (bytes*4) scale.
// "consumed" allows the first send in a burst.
func (a *Association) tlrAllowSendLocked(budgetScaled *int64, consumed *bool, estBytes int) bool {
	if !a.tlrActive || budgetScaled == nil || consumed == nil {
		return true
	}
	if estBytes <= 0 {
		return true
	}
	needScaled := int64(estBytes) * tlrUnitsPerMTU // bytes*4
	if *consumed && *budgetScaled < needScaled {
		return false
	}
	*budgetScaled -= needScaled
	if *budgetScaled < 0 {
		*budgetScaled = 0
	}
	*consumed = true
	return true
}

// ActiveHeartbeat sends a HEARTBEAT chunk on the association to perform an
// on-demand RTT measurement without application payload.
//
// It is safe to call from outside; it will take the association lock and
// be a no-op if the association is not established.
func (a *Association) ActiveHeartbeat() {
	a.lock.Lock()
	defer a.lock.Unlock()
	if a.getState() != established {
		return
	}
	a.sendActiveHeartbeatLocked()
}

// caller must hold a.lock.
func (a *Association) sendActiveHeartbeatLocked() {
	now := time.Now().UnixNano()
	buf := make([]byte, 8)
	binary.BigEndian.PutUint64(buf, uint64(now)) //nolint:gosec // time.Now() will never be negative
	info := &paramHeartbeatInfo{heartbeatInformation: buf}
	hb := &chunkHeartbeat{
		chunkHeader: chunkHeader{
			typ:   ctHeartbeat,
			flags: 0,
		},
		params: []param{info},
	}
	a.controlQueue.push(&packet{
		verificationTag: a.peerVerificationTag,
		sourcePort:      a.sourcePort,
		destinationPort: a.destinationPort,
		chunks:          []chunk{hb},
	})
	a.awakeWriteLoop()
}
sctp-1.9.0/association_stats.go000066400000000000000000000042711512256410600165610ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community
// SPDX-License-Identifier: MIT

package sctp

import (
	"sync/atomic"
)

type associationStats struct {
	nPacketsReceived uint64
	nPacketsSent     uint64
	nDATAs           uint64
	nSACKsReceived   uint64
	nSACKsSent       uint64
	nT3Timeouts      uint64
	nAckTimeouts     uint64
	nFastRetrans     uint64
}

func (s *associationStats) incPacketsReceived() {
	atomic.AddUint64(&s.nPacketsReceived, 1)
}

func (s *associationStats) getNumPacketsReceived() uint64 {
	return atomic.LoadUint64(&s.nPacketsReceived)
}

func (s *associationStats) incPacketsSent() {
	atomic.AddUint64(&s.nPacketsSent, 1)
}

func (s *associationStats) getNumPacketsSent() uint64 {
	return atomic.LoadUint64(&s.nPacketsSent)
}

func (s *associationStats) incDATAs() {
	atomic.AddUint64(&s.nDATAs, 1)
}

func (s *associationStats) getNumDATAs() uint64 {
	return atomic.LoadUint64(&s.nDATAs)
}

func (s *associationStats) incSACKsReceived() {
	atomic.AddUint64(&s.nSACKsReceived, 1)
}

func (s *associationStats) getNumSACKsReceived() uint64 {
	return atomic.LoadUint64(&s.nSACKsReceived)
}

func (s *associationStats) incSACKsSent() {
	atomic.AddUint64(&s.nSACKsSent, 1)
}

func (s *associationStats) getNumSACKsSent() uint64 {
	return atomic.LoadUint64(&s.nSACKsSent)
}

func (s *associationStats) incT3Timeouts() {
	atomic.AddUint64(&s.nT3Timeouts, 1)
}

func (s *associationStats) getNumT3Timeouts() uint64 { return
atomic.LoadUint64(&s.nT3Timeouts) } func (s *associationStats) incAckTimeouts() { atomic.AddUint64(&s.nAckTimeouts, 1) } func (s *associationStats) getNumAckTimeouts() uint64 { return atomic.LoadUint64(&s.nAckTimeouts) } func (s *associationStats) incFastRetrans() { atomic.AddUint64(&s.nFastRetrans, 1) } func (s *associationStats) getNumFastRetrans() uint64 { return atomic.LoadUint64(&s.nFastRetrans) } func (s *associationStats) reset() { atomic.StoreUint64(&s.nPacketsReceived, 0) atomic.StoreUint64(&s.nPacketsSent, 0) atomic.StoreUint64(&s.nDATAs, 0) atomic.StoreUint64(&s.nSACKsReceived, 0) atomic.StoreUint64(&s.nSACKsSent, 0) atomic.StoreUint64(&s.nT3Timeouts, 0) atomic.StoreUint64(&s.nAckTimeouts, 0) atomic.StoreUint64(&s.nFastRetrans, 0) } sctp-1.9.0/association_test.go000066400000000000000000003700251512256410600164050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !js // +build !js package sctp import ( "context" cryptoRand "crypto/rand" "encoding/binary" "errors" "io" "math" "math/rand" "net" "os" "runtime" "strings" "sync" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) var ( errHandshakeFailed = errors.New("handshake failed") errSINotMatch = errors.New("SI should match") errReadData = errors.New("failed to read data") errReceivedDataNot3Bytes = errors.New("received data must by 3 bytes") errPPIUnexpected = errors.New("unexpected ppi") errReceivedDataMismatch = errors.New("received data mismatch") ) func TestAssocStressDuplex(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 20) defer lim.Stop() // Check for leaking routines report := test.CheckRoutines(t) defer report() stressDuplex(t) } func stressDuplex(t *testing.T) { t.Helper() ca, cb, stop, err := pipe(t, pipeDump) assert.NoError(t, err) defer stop(t) // Need to Increase once SCTP is more reliable in case of slow reader opt := test.Options{ MsgSize: 2048, // 65535, MsgCount: 10, // 1000, } err = test.StressDuplex(ca, cb, opt) assert.NoError(t, err) } func pipe(t *testing.T, piper piperFunc) (*Stream, *Stream, func(*testing.T), error) { t.Helper() var err error var aa, ab *Association aa, ab, err = association(t, piper) if err != nil { return nil, nil, nil, err } var sa, sb *Stream sa, err = aa.OpenStream(0, 0) if err != nil { return nil, nil, nil, err } sb, err = ab.OpenStream(0, 0) if err != nil { return nil, nil, nil, err } stop := func(t *testing.T) { t.Helper() err = sa.Close() assert.NoError(t, err) err = sb.Close() assert.NoError(t, err) err = aa.Close() assert.NoError(t, err) err = ab.Close() assert.NoError(t, err) } return sa, sb, stop, nil } func association(t *testing.T, piper piperFunc) (*Association, *Association, error) { t.Helper() ca, cb := piper(t) type result struct { a *Association err error } resultCh := make(chan result) loggerFactory := logging.NewDefaultLoggerFactory() // Setup client go func() { client, err := Client(Config{ NetConn: ca, LoggerFactory: loggerFactory, }) resultCh <- result{client, err} }() // Setup server server, err := Server(Config{ NetConn: cb, LoggerFactory: loggerFactory, }) if err != nil { return nil, nil, err } // Receive client res := <-resultCh if res.err != nil { return nil, nil, res.err } return res.a, server, nil } type piperFunc func(t *testing.T) (net.Conn, net.Conn) func pipeDump(t *testing.T) (net.Conn, net.Conn) { t.Helper() aConn := 
acceptDumbConn(t) addr, ok := aConn.LocalAddr().(*net.UDPAddr) assert.True(t, ok) bConn, err := net.DialUDP("udp4", nil, addr) assert.NoError(t, err) // Dumb handshake mgs := "Test" _, err = bConn.Write([]byte(mgs)) assert.NoError(t, err) b := make([]byte, 4) _, err = aConn.Read(b) assert.NoError(t, err) assert.Equal(t, string(b), mgs) return aConn, bConn } type dumbConn struct { mu sync.RWMutex rAddr net.Addr pConn net.PacketConn } func acceptDumbConn(t *testing.T) *dumbConn { t.Helper() pConn, err := net.ListenUDP("udp4", nil) assert.NoError(t, err) return &dumbConn{ pConn: pConn, } } // Read. func (c *dumbConn) Read(p []byte) (int, error) { i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { return 0, err } c.mu.Lock() c.rAddr = rAddr c.mu.Unlock() return i, err } // Write writes len(p) bytes from p to the DTLS connection. func (c *dumbConn) Write(p []byte) (n int, err error) { return c.pConn.WriteTo(p, c.RemoteAddr()) } // Close closes the conn and releases any Read calls. func (c *dumbConn) Close() error { return c.pConn.Close() } // LocalAddr is a stub. func (c *dumbConn) LocalAddr() net.Addr { if c.pConn != nil { return c.pConn.LocalAddr() } return nil } // RemoteAddr is a stub. func (c *dumbConn) RemoteAddr() net.Addr { c.mu.RLock() defer c.mu.RUnlock() return c.rAddr } // SetDeadline is a stub. func (c *dumbConn) SetDeadline(time.Time) error { return nil } // SetReadDeadline is a stub. func (c *dumbConn) SetReadDeadline(time.Time) error { return nil } // SetWriteDeadline is a stub. func (c *dumbConn) SetWriteDeadline(time.Time) error { return nil } //nolint:cyclop func createNewAssociationPair( br *test.Bridge, ackMode int, recvBufSize uint32, ) (*Association, *Association, error) { var a0, a1 *Association var err0, err1 error loggerFactory := logging.NewDefaultLoggerFactory() handshake0Ch := make(chan bool) handshake1Ch := make(chan bool) go func() { a0, err0 = Client(Config{ Name: "a0", NetConn: br.GetConn0(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake0Ch <- true }() go func() { // we could have two "client"s here but it's more // standard to have one peer starting initialization and // another waiting for the initialization to be requested (INIT). 
a1, err1 = Server(Config{ Name: "a1", NetConn: br.GetConn1(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake1Ch <- true }() a0handshakeDone := false a1handshakeDone := false loop1: for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) br.Tick() select { case a0handshakeDone = <-handshake0Ch: if a1handshakeDone { break loop1 } case a1handshakeDone = <-handshake1Ch: if a0handshakeDone { break loop1 } default: } } if !a0handshakeDone || !a1handshakeDone { return nil, nil, errHandshakeFailed } if err0 != nil { return nil, nil, err0 } if err1 != nil { return nil, nil, err1 } a0.ackMode = ackMode a1.ackMode = ackMode return a0, a1, nil } func closeAssociationPair(br *test.Bridge, a0, a1 *Association) { close0Ch := make(chan bool) close1Ch := make(chan bool) go func() { // nolint:errcheck,gosec a0.Close() close0Ch <- true }() go func() { // nolint:errcheck,gosec a1.Close() close1Ch <- true }() a0closed := false a1closed := false loop1: for i := 0; i < 100; i++ { time.Sleep(10 * time.Millisecond) br.Tick() select { case a0closed = <-close0Ch: if a1closed { break loop1 } case a1closed = <-close1Ch: if a0closed { break loop1 } default: } } } func flushBuffers(br *test.Bridge, a0, a1 *Association) { for { for { n := br.Tick() if n == 0 { break } } if a0.BufferedAmount() == 0 && a1.BufferedAmount() == 0 { break } time.Sleep(10 * time.Millisecond) } } func establishSessionPair(br *test.Bridge, a0, a1 *Association, si uint16) (*Stream, *Stream, error) { helloMsg := "Hello" // mimic datachannel.channelOpen s0, err := a0.OpenStream(si, PayloadTypeWebRTCBinary) if err != nil { return nil, nil, err } _, err = s0.WriteSCTP([]byte(helloMsg), PayloadTypeWebRTCDCEP) if err != nil { return nil, nil, err } flushBuffers(br, a0, a1) s1, err := a1.AcceptStream() if err != nil { return nil, nil, err } if s0.streamIdentifier != s1.streamIdentifier { return nil, nil, errSINotMatch } br.Process() buf := make([]byte, 1024) n, ppi, err := s1.ReadSCTP(buf) if err != nil { return nil, nil, errReadData } if n != len(helloMsg) { return nil, nil, errReceivedDataNot3Bytes } if ppi != PayloadTypeWebRTCDCEP { return nil, nil, errPPIUnexpected } if string(buf[:n]) != helloMsg { return nil, nil, errReceivedDataMismatch } flushBuffers(br, a0, a1) return s0, s1, nil } func TestAssocReliable(t *testing.T) { //nolint:maintidx // sbuf - small enough not to be fragmented // large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) // sbufL - large enough to be fragmented into two chunks and each chunks are // large enough not to be bundled sbufL := make([]byte, 2000) for i := 0; i < len(sbufL); i++ { sbufL[i] = byte(i & 0xff) } rand.Shuffle(len(sbufL), func(i, j int) { sbufL[i], sbufL[j] = sbufL[j], sbufL[i] }) t.Run("Simple", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect 
bufferedAmount") flushBuffers(br, a0, a1) buf := make([]byte, 32) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") closeAssociationPair(br, a0, a1) }) t.Run("ReadDeadline", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") assert.NoError(t, s1.SetReadDeadline(time.Now().Add(time.Millisecond)), "failed to set read deadline") buf := make([]byte, 32) // First fails n, ppi, err := s1.ReadSCTP(buf) assert.Equal(t, 0, n) assert.Equal(t, PayloadProtocolIdentifier(0), ppi) assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) // Second too n, ppi, err = s1.ReadSCTP(buf) assert.Equal(t, 0, n) assert.Equal(t, PayloadProtocolIdentifier(0), ppi) assert.True(t, errors.Is(err, os.ErrDeadlineExceeded)) assert.NoError(t, s1.SetReadDeadline(time.Time{}), "failed to disable read deadline") n, err = s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") closeAssociationPair(br, a0, a1) }) t.Run("ordered reordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 2 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") binary.BigEndian.PutUint32(sbuf, 0) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") binary.BigEndian.PutUint32(sbuf, 1) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") time.Sleep(10 * time.Millisecond) err = br.Reorder(0) assert.NoError(t, err, "reorder failed") br.Process() buf := make([]byte, 2000) n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(0), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, 
s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("ordered fragmented then defragmented", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") s0.SetReliabilityParams(false, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(false, ReliabilityTypeReliable, 0) n, err = s0.WriteSCTP(sbufL, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") rbuf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(rbuf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") assert.Equal(t, sbufL, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("unordered fragmented then defragmented", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 4 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") s0.SetReliabilityParams(true, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(true, ReliabilityTypeReliable, 0) n, err = s0.WriteSCTP(sbufL, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") rbuf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(rbuf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(sbufL), "unexpected length of received data") assert.Equal(t, sbufL, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("unordered reordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 5 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") s0.SetReliabilityParams(true, ReliabilityTypeReliable, 0) s1.SetReliabilityParams(true, ReliabilityTypeReliable, 0) br.ReorderNextNWrites(0, 2) binary.BigEndian.PutUint32(sbuf, 0) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") binary.BigEndian.PutUint32(sbuf, 1) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") buf := make([]byte, 2000) flushBuffers(br, a0, a1) n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, 
"ReadSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") assert.Equal(t, uint32(0), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("retransmission", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 const msg1 = "ABC" const msg2 = "DEFG" var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") // lock RTO value at 100 [msec] a0.rtoMgr.setRTO(100.0, true) s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") n, err = s0.WriteSCTP([]byte(msg1), PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") n, err = s0.WriteSCTP([]byte(msg2), PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg2), "unexpected length of received data") br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) // process packets for 200 msec for i := 0; i < 20; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } buf := make([]byte, 32) n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") assert.Equal(t, msg1, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg2), "unexpected length of received data") assert.Equal(t, msg2, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("short buffer", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "Hello" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) buf := make([]byte, 3) n, ppi, err := s1.ReadSCTP(buf) assert.Equal(t, err, io.ErrShortBuffer, "expected error to be io.ErrShortBuffer") assert.Equal(t, n, 5, "unexpected length of received data") assert.Equal(t, ppi, PayloadProtocolIdentifier(0), "unexpected ppi") assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") 
closeAssociationPair(br, a0, a1) }) } func TestAssocUnreliable(t *testing.T) { //nolint:maintidx // sbuf1, sbuf2: // large enough to be fragmented into two chunks and each chunks are // large enough not to be bundled sbuf1 := make([]byte, 2000) sbuf2 := make([]byte, 2000) for i := 0; i < len(sbuf1); i++ { sbuf1[i] = byte(i & 0xff) } rand.Shuffle(len(sbuf1), func(i, j int) { sbuf1[i], sbuf1[j] = sbuf1[j], sbuf1[i] }) for i := 0; i < len(sbuf2); i++ { sbuf2[i] = byte(i & 0xff) } rand.Shuffle(len(sbuf2), func(i, j int) { sbuf2[i], sbuf2[j] = sbuf2[j], sbuf2[i] }) // sbuf - small enough not to be fragmented // large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xff) } rand.Shuffle(len(sbuf), func(i, j int) { sbuf[i], sbuf[j] = sbuf[j], sbuf[i] }) t.Run("Rexmit ordered no fragment", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit ordered fragments", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // lock RTO value at 100 [msec] a0.rtoMgr.setRTO(100.0, true) // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. 
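		// SetReliabilityParams(unordered, reliabilityType, value): ordered delivery with
		// ReliabilityTypeRexmit and a retransmit limit of 0, so a dropped fragment is
		// abandoned instead of being retransmitted; the receiver-side setting is only
		// mirrored here and, as noted, does not affect the sender's behaviour.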
s0.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(false, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first fragment of the first chunk (second chunk should be sacked) var n int n, err = s0.WriteSCTP(sbuf1, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf1), n, "unexpected length of written data") n, err = s0.WriteSCTP(sbuf2, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf2), n, "unexpected length of written data") flushBuffers(br, a0, a1) rbuf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(rbuf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, sbuf2, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.ordered), "should be nothing in the ordered queue") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit unordered no fragment", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 2 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Rexmit unordered fragments", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // When we set the reliability value to 0 [times], then it will cause // the chunk to be abandoned immediately after the first transmission. 
s0.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) s1.SetReliabilityParams(true, ReliabilityTypeRexmit, 0) // doesn't matter var n int n, err = s0.WriteSCTP(sbuf1, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf1), n, "unexpected length of written data") n, err = s0.WriteSCTP(sbuf2, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf2), n, "unexpected length of written data") time.Sleep(10 * time.Millisecond) br.Drop(0, 0, 2) // drop the second fragment of the first chunk (second chunk should be sacked) flushBuffers(br, a0, a1) rbuf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(rbuf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, sbuf2, rbuf[:n], "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.unordered), "should be nothing in the unordered queue") assert.Equal(t, 0, len(s0.reassemblyQueue.unorderedChunks), "should be nothing in the unorderedChunks list") closeAssociationPair(br, a0, a1) }) t.Run("Timed ordered", func(t *testing.T) { // nolint:dupl lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // When we set the reliability value to 0 [msec], then it will cause // the chunk to be abandoned immediately after the first transmission. s0.SetReliabilityParams(false, ReliabilityTypeTimed, 0) s1.SetReliabilityParams(false, ReliabilityTypeTimed, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") // br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") closeAssociationPair(br, a0, a1) }) t.Run("Timed unordered", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 3 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // When we set the reliability value to 0 [msec], then it will cause // the chunk to be abandoned immediately after the first transmission. 
s0.SetReliabilityParams(true, ReliabilityTypeTimed, 0) s1.SetReliabilityParams(true, ReliabilityTypeTimed, 0) // doesn't matter br.DropNextNWrites(0, 1) // drop the first packet (second one should be sacked) var n int binary.BigEndian.PutUint32(sbuf, uint32(0)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") binary.BigEndian.PutUint32(sbuf, uint32(1)) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(sbuf), n, "unexpected length of written data") flushBuffers(br, a0, a1) buf := make([]byte, 2000) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") // should receive the second one only assert.Equal(t, len(sbuf), n, "unexpected length of written data") assert.Equal(t, uint32(1), binary.BigEndian.Uint32(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, len(s0.reassemblyQueue.unordered), "should be nothing in the unordered queue") assert.Equal(t, 0, len(s0.reassemblyQueue.unorderedChunks), "should be nothing in the unorderedChunks list") closeAssociationPair(br, a0, a1) }) } // This test ensures that verification tag is set to 0 for all INIT packets. // A test for this PR https://github.com/pion/sctp/pull/341 // We drop the first INIT ACK, and we expect the verification tag to be 0 on // retransmission. func TestInitVerificationTagIsZero(t *testing.T) { //nolint:cyclop lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() ackCount := 0 recvBufSize := uint32(0) var a0, a1 *Association var err0, err1 error loggerFactory := logging.NewDefaultLoggerFactory() handshake0Ch := make(chan bool) handshake1Ch := make(chan bool) fatalChannel := make(chan error) fitlerFunc := func(pkt []byte) bool { t.Helper() packetData := packet{} assert.NoError(t, packetData.unmarshal(true, pkt)) // Init chunk and Init Ack chunk are never bundled. if len(packetData.chunks) != 1 { return true } switch packetData.chunks[0].(type) { case *chunkInit: if packetData.verificationTag != 0 { // Even without this we will get WARNING: // failed validating packet init chunk expects a verification tag of 0 on the packet when out-of-the-blue // And the connection will fail silently. go func() { fatalChannel <- errors.New("verification tag should be 0 for Init chunk") //nolint:err113 }() return false } // Drop the first two Init Ack chunk. 
case *chunkInitAck: ackCount++ return ackCount > 2 } return true } br.Filter(0, fitlerFunc) br.Filter(1, fitlerFunc) go func() { a0, err0 = Client(Config{ Name: "a0", NetConn: br.GetConn0(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake0Ch <- true }() go func() { a1, err1 = Client(Config{ Name: "a1", NetConn: br.GetConn1(), MaxReceiveBufferSize: recvBufSize, LoggerFactory: loggerFactory, }) handshake1Ch <- true }() a0handshakeDone := false a1handshakeDone := false loop1: for i := 0; i < 1e3; i++ { time.Sleep(10 * time.Millisecond) br.Tick() select { case a0handshakeDone = <-handshake0Ch: if a1handshakeDone { break loop1 } case a1handshakeDone = <-handshake1Ch: if a0handshakeDone { break loop1 } case err := <-fatalChannel: assert.Failf(t, "fatal error during handshake: %v", err.Error()) default: } } assert.Equal(t, a0handshakeDone, true, "handshake failed e0") assert.Equal(t, a1handshakeDone, true, "handshake failed e1") assert.NoError(t, err0, "failed to create association a0") assert.NoError(t, err1, "failed to create association a1") a0.ackMode = ackModeNoDelay a1.ackMode = ackModeNoDelay s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect bufferedAmount") flushBuffers(br, a0, a1) buf := make([]byte, 32) n, ppi, err := s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") closeAssociationPair(br, a0, a1) } func TestCreateForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward one abandoned", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.cumulativeTSNAckPoint = 9 assoc.advancedPeerTSNAckPoint = 10 assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, streamIdentifier: 1, streamSequenceNumber: 2, userData: []byte("ABC"), nSent: 1, _abandoned: true, }) fwdtsn := assoc.createForwardTSN() assert.Equal(t, uint32(10), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 1, len(fwdtsn.streams), "there should be one stream") assert.Equal(t, uint16(1), fwdtsn.streams[0].identifier, "si should be 1") assert.Equal(t, uint16(2), fwdtsn.streams[0].sequence, "ssn should be 2") }) t.Run("forward two abandoned with the same SI", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.cumulativeTSNAckPoint = 9 assoc.advancedPeerTSNAckPoint = 12 assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 10, streamIdentifier: 1, streamSequenceNumber: 2, userData: []byte("ABC"), nSent: 1, _abandoned: true, }) assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: 11, streamIdentifier: 1, streamSequenceNumber: 3, userData: []byte("DEF"), nSent: 1, _abandoned: true, }) assoc.inflightQueue.pushNoCheck(&chunkPayloadData{ beginningFragment: true, 
endingFragment: true, tsn: 12, streamIdentifier: 2, streamSequenceNumber: 1, userData: []byte("123"), nSent: 1, _abandoned: true, }) fwdtsn := assoc.createForwardTSN() assert.Equal(t, uint32(12), fwdtsn.newCumulativeTSN, "should be able to serialize") assert.Equal(t, 2, len(fwdtsn.streams), "there should be two stream") si1OK := false si2OK := false for _, s := range fwdtsn.streams { switch s.identifier { case 1: assert.Equal(t, uint16(3), s.sequence, "ssn should be 3") si1OK = true case 2: assert.Equal(t, uint16(1), s.sequence, "ssn should be 1") si2OK = true default: assert.Fail(t, "unexpected stream indentifier") } } assert.True(t, si1OK, "si=1 should be present") assert.True(t, si2OK, "si=2 should be present") }) } func TestHandleForwardTSN(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("forward 3 unreceived chunks", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.useForwardTSN = true prevTSN := assoc.peerLastTSN() fwdtsn := &chunkForwardTSN{ newCumulativeTSN: prevTSN + 3, streams: []chunkForwardTSNStream{{identifier: 0, sequence: 0}}, } p := assoc.handleForwardTSN(fwdtsn) assoc.lock.Lock() delayedAckTriggered := assoc.delayedAckTriggered immediateAckTriggered := assoc.immediateAckTriggered assoc.lock.Unlock() assert.Equal(t, assoc.peerLastTSN(), prevTSN+3, "peerLastTSN should advance by 3 ") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 1 missing", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.useForwardTSN = true prevTSN := assoc.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 assoc.payloadQueue.push(assoc.peerLastTSN() + 2) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: assoc.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := assoc.handleForwardTSN(fwdtsn) assoc.lock.Lock() delayedAckTriggered := assoc.delayedAckTriggered immediateAckTriggered := assoc.immediateAckTriggered assoc.lock.Unlock() assert.Equal(t, assoc.peerLastTSN(), prevTSN+2, "peerLastTSN should advance by 3") assert.True(t, delayedAckTriggered, "delayed sack should be triggered") assert.False(t, immediateAckTriggered, "immediate sack should NOT be triggered") assert.Nil(t, p, "should return nil") }) t.Run("forward 1 for 2 missing", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.useForwardTSN = true prevTSN := assoc.peerLastTSN() // this chunk is blocked by the missing chunk at tsn=1 assoc.payloadQueue.push(assoc.peerLastTSN() + 3) fwdtsn := &chunkForwardTSN{ newCumulativeTSN: assoc.peerLastTSN() + 1, streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := assoc.handleForwardTSN(fwdtsn) assoc.lock.Lock() immediateAckTriggered := assoc.immediateAckTriggered assoc.lock.Unlock() assert.Equal(t, assoc.peerLastTSN(), prevTSN+1, "peerLastTSN should advance by 1") assert.True(t, immediateAckTriggered, "immediate sack should be triggered") assert.Nil(t, p, "should return nil") }) t.Run("dup forward TSN chunk should generate sack", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.useForwardTSN = true prevTSN := assoc.peerLastTSN() fwdtsn := &chunkForwardTSN{ newCumulativeTSN: 
assoc.peerLastTSN(), // old TSN streams: []chunkForwardTSNStream{ {identifier: 0, sequence: 1}, }, } p := assoc.handleForwardTSN(fwdtsn) assoc.lock.Lock() ackState := assoc.ackState assoc.lock.Unlock() assert.Equal(t, assoc.peerLastTSN(), prevTSN, "peerLastTSN should not advance") assert.Equal(t, ackStateImmediate, ackState, "sack should be requested") assert.Nil(t, p, "should return nil") }) } func TestHandleDataAckTriggering(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() newAssoc := func() *Association { assoc := createAssociation(Config{ LoggerFactory: loggerFactory, }) assoc.payloadQueue.init(0) return assoc } t.Run("ordered data uses delayed ack", func(t *testing.T) { assoc := newAssoc() defer assoc.ackTimer.stop() pd := &chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: assoc.peerLastTSN() + 1, streamIdentifier: 1, streamSequenceNumber: 1, userData: []byte("ordered"), } assoc.handleChunksStart() assoc.handleData(pd) assoc.handleChunksEnd() assoc.lock.RLock() ackState := assoc.ackState delayed := assoc.delayedAckTriggered immediate := assoc.immediateAckTriggered assoc.lock.RUnlock() assert.Equal(t, ackStateDelay, ackState, "ordered DATA should use delayed ack") assert.True(t, delayed, "ordered DATA should trigger delayed ack") assert.False(t, immediate, "ordered DATA should not trigger immediate ack") }) t.Run("immediateSack flag requests immediate ack", func(t *testing.T) { assoc := newAssoc() defer assoc.ackTimer.stop() pd := &chunkPayloadData{ immediateSack: true, beginningFragment: true, endingFragment: true, tsn: assoc.peerLastTSN() + 1, streamIdentifier: 1, streamSequenceNumber: 1, userData: []byte("immediate"), } assoc.handleChunksStart() assoc.handleData(pd) assoc.handleChunksEnd() assoc.lock.RLock() ackState := assoc.ackState delayed := assoc.delayedAckTriggered immediate := assoc.immediateAckTriggered assoc.lock.RUnlock() assert.Equal(t, ackStateImmediate, ackState, "Immediate SACK flag should trigger immediate ack") assert.False(t, delayed, "Immediate SACK flag should not trigger delayed ack") assert.True(t, immediate, "Immediate SACK flag should trigger immediate ack") }) t.Run("gap forces immediate ack", func(t *testing.T) { assoc := newAssoc() defer assoc.ackTimer.stop() pd := &chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: assoc.peerLastTSN() + 2, streamIdentifier: 1, streamSequenceNumber: 1, userData: []byte("gap"), } assoc.handleChunksStart() assoc.handleData(pd) assoc.handleChunksEnd() assoc.lock.RLock() ackState := assoc.ackState delayed := assoc.delayedAckTriggered immediate := assoc.immediateAckTriggered assoc.lock.RUnlock() assert.Equal(t, ackStateImmediate, ackState, "gap should trigger immediate ack") assert.False(t, delayed, "gap should not trigger delayed ack") assert.True(t, immediate, "gap should trigger immediate ack") }) t.Run("gap forces immediate ack even in always-delay mode", func(t *testing.T) { assoc := newAssoc() defer assoc.ackTimer.stop() assoc.ackMode = ackModeAlwaysDelay pd := &chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: assoc.peerLastTSN() + 3, // leave gaps streamIdentifier: 1, streamSequenceNumber: 1, userData: []byte("gap-delay"), } assoc.handleChunksStart() assoc.handleData(pd) assoc.handleChunksEnd() assoc.lock.RLock() ackState := assoc.ackState delayed := assoc.delayedAckTriggered immediate := assoc.immediateAckTriggered assoc.lock.RUnlock() assert.Equal(t, ackStateImmediate, ackState, "gap should override always-delay mode") assert.False(t, delayed, 
"gap should not trigger delayed ack") assert.True(t, immediate, "gap should trigger immediate ack") }) } func TestAssocT1InitTimer(t *testing.T) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh := make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Drop the first write br.DropNextNWrites(0, 1) // Start the handlshake a0.init(true) a1.init(true) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } flushBuffers(br, a0, a1) assert.NoError(t, err0, "should be nil") assert.NoError(t, err1, "should be nil") closeAssociationPair(br, a0, a1) }) t.Run("Retransmission failure", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh := make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) a1.rtoMgr.setRTO(20, false) // fail after 4 retransmission a0.t1Init.maxRetrans = 4 a1.t1Init.maxRetrans = 4 go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Drop all INIT br.DropNextNWrites(0, 99) br.DropNextNWrites(1, 99) // Start the handlshake a0.init(true) a1.init(true) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } assert.Error(t, err0) assert.Error(t, err1) closeAssociationPair(br, a0, a1) }) } func TestAssocT1CookieTimer(t *testing.T) { //nolint:cyclop loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0, err1 error a0ReadyCh := make(chan bool) a1ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() go func() { err1 = <-a1.handshakeCompletedCh a1ReadyCh <- true }() // Start the handlshake a0.init(true) a1.init(true) // Let the INIT go. 
br.Tick() // Drop COOKIE-ECHO br.DropNextNWrites(0, 1) a0Ready := false a1Ready := false for !a0Ready || !a1Ready { br.Process() select { case a0Ready = <-a0ReadyCh: case a1Ready = <-a1ReadyCh: default: } } assert.NoError(t, err0, "should be nil") assert.NoError(t, err1, "should be nil") closeAssociationPair(br, a0, a1) }) t.Run("Retransmission failure", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() br := test.NewBridge() a0 := createAssociation(Config{ NetConn: br.GetConn0(), LoggerFactory: loggerFactory, }) a1 := createAssociation(Config{ NetConn: br.GetConn1(), LoggerFactory: loggerFactory, }) var err0 error a0ReadyCh := make(chan bool) assert.Equal(t, rtoInitial, a0.rtoMgr.getRTO()) assert.Equal(t, rtoInitial, a1.rtoMgr.getRTO()) // modified rto for fast test a0.rtoMgr.setRTO(20, false) // fail after 4 retransmissions a0.t1Cookie.maxRetrans = 4 go func() { err0 = <-a0.handshakeCompletedCh a0ReadyCh <- true }() // Drop all COOKIE-ECHO br.Filter(0, func(raw []byte) bool { p := &packet{} err := p.unmarshal(true, raw) if !assert.NoError(t, err, "failed to parse packet") { return false // drop } for _, c := range p.chunks { switch c.(type) { case *chunkCookieEcho: return false // drop default: return true } } return true }) // Start the handshake a0.init(true) a1.init(false) a0Ready := false for !a0Ready { br.Process() select { case a0Ready = <-a0ReadyCh: default: } } assert.Error(t, err0) time.Sleep(1000 * time.Millisecond) br.Process() closeAssociationPair(br, a0, a1) }) } func TestAssocCreateNewStream(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("acceptChSize", func(t *testing.T) { assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) for i := 0; i < acceptChSize; i++ { s := assoc.createStream(uint16(i), true) //nolint:gosec _, ok := assoc.streams[s.streamIdentifier] assert.True(t, ok, "should be in a.streams map") } newSI := uint16(acceptChSize) s := assoc.createStream(newSI, true) assert.Nil(t, s, "should be nil") _, ok := assoc.streams[newSI] assert.False(t, ok, "should NOT be in a.streams map") toBeIgnored := &chunkPayloadData{ beginningFragment: true, endingFragment: true, tsn: assoc.peerLastTSN() + 1, streamIdentifier: newSI, userData: []byte("ABC"), } p := assoc.handleData(toBeIgnored) assert.Nil(t, p, "should be nil") }) } func TestAssocT3RtxTimer(t *testing.T) { // Send one packet, drop it, then have it retransmitted successfully.
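// Hedged sketch (not used by the test): a DATA chunk sitting in the sender's
// inflight queue looks roughly like this; if no SACK covers its TSN before the
// RTO expires, the T3-rtx timer fires and the chunk is retransmitted.
_ = &chunkPayloadData{
	beginningFragment: true,
	endingFragment:    true,
	tsn:               42, // hypothetical TSN
	userData:          []byte("ABC"),
	nSent:             1, // sent once, awaiting SACK
}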
t.Run("Retransmission success", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 const msg1 = "ABC" var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") // lock RTO value at 20 [msec] a0.rtoMgr.setRTO(20.0, false) a0.rtoMgr.noUpdate = true s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") n, err = s0.WriteSCTP([]byte(msg1), PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") br.Drop(0, 0, 1) // drop the first packet (second one should be sacked) // process packets for 100 msec for i := 0; i < 10; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } buf := make([]byte, 32) n, ppi, err = s1.ReadSCTP(buf) assert.NoError(t, err, "ReadSCTP failed") assert.Equal(t, n, len(msg1), "unexpected length of received data") assert.Equal(t, msg1, string(buf[:n]), "unexpected received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") br.Process() assert.False(t, s0.reassemblyQueue.isReadable(), "should no longer be readable") a0.lock.RLock() assert.Equal(t, 0, a0.pendingQueue.size(), "should be no packet pending") assert.Equal(t, 0, a0.inflightQueue.size(), "should be no packet inflight") a0.lock.RUnlock() closeAssociationPair(br, a0, a1) }) } func TestAssocCongestionControl(t *testing.T) { //nolint:cyclop,maintidx // sbuf - large enough not to be bundled sbuf := make([]byte, 1000) for i := 0; i < len(sbuf); i++ { sbuf[i] = byte(i & 0xcc) } // 1) Send 4 packets. drop the first one. // 2) Last 3 packets will be received, which triggers loss recovery RACK/TLP. // 3) The first one is retransmitted, which makes s1 readable. // Above should be done before RTO occurs. t.Run("Fast retransmission", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 var n int var ppi PayloadProtocolIdentifier br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNormal, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") // 1) Send 4 packets, drop the first one. // 2) Last 3 packets will be received, which triggers loss recovery // (either classic Fast Retransmit or RACK/TLP). // 3) The first one is retransmitted, and s1 should see all 4 in order. 
br.DropNextNWrites(0, 1) // drop the next write from a0 for i := 0; i < 4; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) //nolint:gosec // G115 n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, len(sbuf), n, "unexpected length of sent data") } // process packets for 500 msec; recovery should complete without relying on RTO for i := 0; i < 50; i++ { br.Tick() time.Sleep(10 * time.Millisecond) } rbuf := make([]byte, 3000) // Try to read all 4 packets for i := 0; i < 4; i++ { // The receiver (s1) should be readable s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !assert.True(t, readable, "should be readable") { return } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.NoError(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, i, int(binary.BigEndian.Uint32(rbuf)), "unexpected payload sequence") assert.Equal(t, PayloadTypeWebRTCBinary, ppi, "unexpected ppi") } // Log stats for debugging / sanity t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) t.Logf("nFastRetrans: %d\n", a0.stats.getNumFastRetrans()) t.Logf("nT3Timeouts : %d\n", a0.stats.getNumT3Timeouts()) // With RACK enabled, recovery may happen without classic fast retransmit. // Require recovery before RTO; allow FR count to be 0 or 1. assert.Zero(t, a0.stats.getNumT3Timeouts(), "recovery should complete before any T3 RTO") fr := a0.stats.getNumFastRetrans() assert.Truef(t, fr == 0 || fr == 1, "expected fast retrans 0 or 1 (RACK may bypass FR), got %d", fr) closeAssociationPair(br, a0, a1) }) t.Run("Congestion Avoidance", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 const nPacketsToSend = 2000 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier rbuf := make([]byte, 3000) br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNormal, maxReceiveBufferSize) a0.cwndCAStep = 2800 // 2 mtu assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") a0.stats.reset() a1.stats.reset() for i := 0; i < nPacketsToSend; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) //nolint:gosec // G115 uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") } // Repeat calling br.Tick() until the buffered amount becomes 0 for s0.BufferedAmount() > 0 && nPacketsReceived < nPacketsToSend { for { n = br.Tick() if n == 0 { break } } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.NoError(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, nPacketsReceived, int(binary.BigEndian.Uint32(rbuf)), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } } br.Process() a0.lock.RLock() inFastRecovery := a0.inFastRecovery cwnd := a0.cwnd ssthresh := a0.ssthresh a0.lock.RUnlock() assert.False(t, inFastRecovery, "should not be in fast-recovery") assert.True(t, cwnd > ssthresh, "should be in congestion 
avoidance mode") assert.True(t, ssthresh >= maxReceiveBufferSize, "should not be less than the initial size of 128KB") assert.Equal(t, nPacketsReceived, nPacketsToSend, "unexpected num of packets received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nT3Timeouts : %d\n", a0.stats.getNumT3Timeouts()) assert.Equal(t, uint64(nPacketsToSend), a1.stats.getNumDATAs(), "packet count mismatch") assert.True(t, a0.stats.getNumSACKsReceived() <= nPacketsToSend/2, "too many sacks") assert.Equal(t, uint64(0), a0.stats.getNumT3Timeouts(), "should be no retransmit") closeAssociationPair(br, a0, a1) }) // This is to test even rwnd becomes 0, sender should be able to send a zero window probe // on T3-rtx retramission timeout to complete receiving all the packets. t.Run("Slow reader", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const maxReceiveBufferSize uint32 = 64 * 1024 const si uint16 = 6 nPacketsToSend := int(math.Floor(float64(maxReceiveBufferSize)/1000.0)) * 2 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier rbuf := make([]byte, 3000) br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, maxReceiveBufferSize) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") for i := 0; i < nPacketsToSend; i++ { binary.BigEndian.PutUint32(sbuf, uint32(i)) // nolint:gosec // G115 uint32 sequence number n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") } // 1. First forward packets to receiver until rwnd becomes 0 // 2. Wait until the sender's cwnd becomes 1*MTU (RTO occurred) // 3. 
Start reading a1's data var hasRTOed bool for s0.BufferedAmount() > 0 && nPacketsReceived < nPacketsToSend { for { n = br.Tick() if n == 0 { break } } if !hasRTOed { a1.lock.RLock() rwnd := a1.getMyReceiverWindowCredit() a1.lock.RUnlock() a0.lock.RLock() cwnd := a0.cwnd a0.lock.RUnlock() if cwnd > a0.mtu || rwnd > 0 { // Do not read until a1.getMyReceiverWindowCredit() becomes zero continue } hasRTOed = true } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.NoError(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, nPacketsReceived, int(binary.BigEndian.Uint32(rbuf)), "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } time.Sleep(4 * time.Millisecond) } br.Process() assert.Equal(t, nPacketsReceived, nPacketsToSend, "unexpected num of packets received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) closeAssociationPair(br, a0, a1) }) } func TestAssocDelayedAck(t *testing.T) { t.Run("First DATA chunk gets acked with delay", func(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 6 var n int var nPacketsReceived int var ppi PayloadProtocolIdentifier sbuf := make([]byte, 1000) // size should be less than initial cwnd (4380) rbuf := make([]byte, 1500) _, err := cryptoRand.Read(sbuf) if !assert.NoError(t, err, "failed to read random data") { return } br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeAlwaysDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") a0.stats.reset() a1.stats.reset() // Write data (will be fragmented) n, err = s0.WriteSCTP(sbuf, PayloadTypeWebRTCBinary) assert.NoError(t, err, "WriteSCTP failed") assert.Equal(t, n, len(sbuf), "unexpected length of received data") // Repeat calling br.Tick() until the buffered amount becomes 0 since := time.Now() for s0.BufferedAmount() > 0 { for { n = br.Tick() if n == 0 { break } } for { s1.lock.RLock() readable := s1.reassemblyQueue.isReadable() s1.lock.RUnlock() if !readable { break } n, ppi, err = s1.ReadSCTP(rbuf) if !assert.NoError(t, err, "ReadSCTP failed") { return } assert.Equal(t, len(sbuf), n, "unexpected length of received data") assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") nPacketsReceived++ } } delay := time.Since(since).Seconds() t.Logf("received in %.03f seconds", delay) assert.True(t, delay >= 0.2, "should be >= 200msec") br.Process() assert.Equal(t, 1, nPacketsReceived, "should be one packet received") assert.Equal(t, 0, s1.getNumBytesInReassemblyQueue(), "reassembly queue should be empty") t.Logf("nDATAs : %d\n", a1.stats.getNumDATAs()) t.Logf("nSACKs : %d\n", a0.stats.getNumSACKsReceived()) t.Logf("nAckTimeouts: %d\n", a1.stats.getNumAckTimeouts()) assert.Equal(t, uint64(1), a1.stats.getNumDATAs(), "DATA chunk count mismatch") assert.Equal( t, a0.stats.getNumSACKsReceived(), a1.stats.getNumDATAs(), "sack count should be equal to the number of data chunks", ) assert.Equal(t, uint64(1), a1.stats.getNumAckTimeouts(), "ackTimeout count mismatch") assert.Equal(t, uint64(0),
a0.stats.getNumT3Timeouts(), "should be no retransmit") closeAssociationPair(br, a0, a1) }) } func checkGoroutineLeaks(t *testing.T) { t.Helper() // Get the count of goroutines at the start of the test. initialGoroutines := runtime.NumGoroutine() // Register a cleanup function to run after the test completes. t.Cleanup(func() { // Allow for up to 1 second for all goroutines to finish. for i := 0; i < 10; i++ { time.Sleep(100 * time.Millisecond) if goroutines := runtime.NumGoroutine(); goroutines <= initialGoroutines { return } } // If we've gotten this far, not all goroutines have finished. assert.Failf(t, "goroutine leak", "leaked: %d", runtime.NumGoroutine()-initialGoroutines) }) } func TestAssocReset(t *testing.T) { //nolint:cyclop t.Run("Close one way", func(t *testing.T) { checkGoroutineLeaks(t) lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect bufferedAmount") err = s0.Close() // send reset assert.NoError(t, err) doneCh := make(chan error) buf := make([]byte, 32) go func() { for { var ppi PayloadProtocolIdentifier n, ppi, err = s1.ReadSCTP(buf) if err != nil { doneCh <- err return } assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.Equal(t, n, len(msg), "unexpected length of received data") } }() loop: for { br.Process() select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") break loop default: } } closeAssociationPair(br, a0, a1) }) t.Run("Close both ways", func(t *testing.T) { checkGoroutineLeaks(t) lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 const msg = "ABC" br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err, "failed to create associations") s0, s1, err := establishSessionPair(br, a0, a1, si) assert.NoError(t, err, "failed to establish session pair") assert.Equal(t, 0, a0.BufferedAmount(), "incorrect bufferedAmount") // send a message from s0 to s1 n, err := s0.WriteSCTP([]byte(msg), PayloadTypeWebRTCBinary) assert.NoError(t, err) assert.Equal(t, len(msg), n, "unexpected length of received data") assert.Equal(t, len(msg), a0.BufferedAmount(), "incorrect bufferedAmount") // close s0 as soon as the message is sent err = s0.Close() assert.NoError(t, err) doneCh := make(chan error) buf := make([]byte, 32) go func() { for { var ppi PayloadProtocolIdentifier n, ppi, err = s1.ReadSCTP(buf) if err != nil { doneCh <- err return } assert.Equal(t, ppi, PayloadTypeWebRTCBinary, "unexpected ppi") assert.Equal(t, n, len(msg), "unexpected length of received data") } }() loop0: for { br.Process() select { case err = <-doneCh: assert.Equal(t, io.EOF, err, "should end with EOF") break loop0 default: } } // send reset from s1 err = s1.Close() assert.NoError(t, err) go func() { for { _, _, err = s0.ReadSCTP(buf) assert.Equal(t, io.EOF, err, "should be EOF") if err != nil { doneCh <- err return } } }() loop1: for { br.Process() select { case <-doneCh: break loop1 default: } } time.Sleep(2 * time.Second) 
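// Hedged aside (illustrative only, nothing is sent here): each s.Close() above
// results in a RE-CONFIG chunk carrying an outgoing SSN reset request for the
// stream, which is why the remote side's Read eventually returns io.EOF.
// Roughly:
_ = &chunkReconfig{
	paramA: &paramOutgoingResetRequest{streamIdentifiers: []uint16{si}},
}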
closeAssociationPair(br, a0, a1) }) } func TestAssocAbort(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() const si uint16 = 1 br := test.NewBridge() a0, a1, err := createNewAssociationPair(br, ackModeNoDelay, 0) assert.NoError(t, err) abort := &chunkAbort{ errorCauses: []errorCause{&errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }}, } packet, err := a0.marshalPacket(a0.createPacket([]chunk{abort})) assert.NoError(t, err) _, _, err = establishSessionPair(br, a0, a1, si) assert.NoError(t, err) // Both associations are established assert.Equal(t, established, a0.getState()) assert.Equal(t, established, a1.getState()) _, err = a0.netConn.Write(packet) assert.NoError(t, err) flushBuffers(br, a0, a1) // There is a little delay before changing the state to closed time.Sleep(10 * time.Millisecond) // The receiving association should be closed because it got an ABORT assert.Equal(t, established, a0.getState()) assert.Equal(t, closed, a1.getState()) closeAssociationPair(br, a0, a1) } type fakeEchoConn struct { echo chan []byte done chan struct{} closed chan struct{} once sync.Once errClose error mu sync.Mutex bytesSent uint64 bytesReceived uint64 mtu uint32 cwnd uint32 rwnd uint32 srtt float64 } func newFakeEchoConn(errClose error) *fakeEchoConn { return &fakeEchoConn{ echo: make(chan []byte, 1), done: make(chan struct{}), closed: make(chan struct{}), mtu: initialMTU, cwnd: min32(4*initialMTU, max32(2*initialMTU, 4380)), rwnd: initialRecvBufSize, errClose: errClose, } } func (c *fakeEchoConn) Read(b []byte) (int, error) { r, ok := <-c.echo if ok { copy(b, r) c.once.Do(func() { close(c.done) }) c.mu.Lock() c.bytesReceived += uint64(len(r)) c.mu.Unlock() return len(r), nil } return 0, io.EOF } func (c *fakeEchoConn) Write(b []byte) (int, error) { c.mu.Lock() defer c.mu.Unlock() select { case <-c.closed: return 0, io.EOF default: } c.echo <- b c.bytesSent += uint64(len(b)) return len(b), nil } func (c *fakeEchoConn) Close() error { c.mu.Lock() defer c.mu.Unlock() close(c.echo) close(c.closed) return c.errClose } func (c *fakeEchoConn) LocalAddr() net.Addr { return nil } func (c *fakeEchoConn) RemoteAddr() net.Addr { return nil } func (c *fakeEchoConn) SetDeadline(time.Time) error { return nil } func (c *fakeEchoConn) SetReadDeadline(time.Time) error { return nil } func (c *fakeEchoConn) SetWriteDeadline(time.Time) error { return nil } func TestRoutineLeak(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() t.Run("Close failed", func(t *testing.T) { checkGoroutineLeaks(t) conn := newFakeEchoConn(io.EOF) assoc, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done err = assoc.Close() assert.Equal(t, io.EOF, err, "Close() should fail with EOF") select { case _, ok := <-assoc.closeWriteLoopCh: assert.False(t, ok, "closeWriteLoopCh should be closed") default: assert.Fail(t, "closeWriteLoopCh is expected to be closed, but not") } _ = assoc }) t.Run("Connection closed by remote host", func(t *testing.T) { checkGoroutineLeaks(t) conn := newFakeEchoConn(nil) a, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done err = conn.Close() // close connection assert.Equal(t, nil, err, "fake connection returned unexpected error") <-conn.closed <-time.After(10 * time.Millisecond) // switch context to make read/write loops finished select { case _, ok := <-a.closeWriteLoopCh: 
assert.False(t, ok, "closeWriteLoopCh should be closed") default: assert.Fail(t, "closeWriteLoopCh is expected to be closed, but not") } }) } func TestStats(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() conn := newFakeEchoConn(nil) assoc, err := Client(Config{NetConn: conn, LoggerFactory: loggerFactory}) assert.Equal(t, nil, err, "errored to initialize Client") <-conn.done assert.NoError(t, assoc.Close()) conn.mu.Lock() defer conn.mu.Unlock() assert.Equal(t, conn.bytesReceived, assoc.BytesReceived()) assert.Equal(t, conn.bytesSent, assoc.BytesSent()) assert.Equal(t, conn.mtu, assoc.MTU()) assert.Equal(t, conn.cwnd, assoc.CWND()) assert.Equal(t, conn.rwnd, assoc.RWND()) assert.Equal(t, conn.srtt, assoc.SRTT()) } func TestAssocHandleInit(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() handleInitTest := func(t *testing.T, initialState uint32, expectErr bool) { t.Helper() assoc := createAssociation(Config{ NetConn: &dumbConn{}, LoggerFactory: loggerFactory, }) assoc.setState(initialState) pkt := &packet{ sourcePort: 5001, destinationPort: 5002, } init := &chunkInit{} init.initialTSN = 1234 init.numOutboundStreams = 1001 init.numInboundStreams = 1002 init.initiateTag = 5678 init.advertisedReceiverWindowCredit = 512 * 1024 setSupportedExtensions(&init.chunkInitCommon) _, err := assoc.handleInit(pkt, init) if expectErr { assert.Error(t, err, "should fail") return } assert.NoError(t, err, "should succeed") assert.Equal(t, init.initialTSN-1, assoc.peerLastTSN(), "should match") assert.Equal(t, uint16(1001), assoc.myMaxNumOutboundStreams, "should match") assert.Equal(t, uint16(1002), assoc.myMaxNumInboundStreams, "should match") assert.Equal(t, uint32(5678), assoc.peerVerificationTag, "should match") assert.Equal(t, pkt.sourcePort, assoc.destinationPort, "should match") assert.Equal(t, pkt.destinationPort, assoc.sourcePort, "should match") assert.True(t, assoc.useForwardTSN, "should be set to true") } t.Run("normal", func(t *testing.T) { handleInitTest(t, closed, false) }) t.Run("unexpected state established", func(t *testing.T) { handleInitTest(t, established, true) }) t.Run("unexpected state shutdownAckSent", func(t *testing.T) { handleInitTest(t, shutdownAckSent, true) }) t.Run("unexpected state shutdownPending", func(t *testing.T) { handleInitTest(t, shutdownPending, true) }) t.Run("unexpected state shutdownReceived", func(t *testing.T) { handleInitTest(t, shutdownReceived, true) }) t.Run("unexpected state shutdownSent", func(t *testing.T) { handleInitTest(t, shutdownSent, true) }) } func TestAssocMaxMessageSize(t *testing.T) { t.Run("default", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() a := createAssociation(Config{ LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(65536), a.MaxMessageSize(), "should match") s := a.createStream(1, false) assert.NotNil(t, s, "should succeed") p := make([]byte, 65537) var err error _, err = s.WriteSCTP(p[:65536], s.defaultPayloadType) assert.False(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") _, err = s.WriteSCTP(p[:65537], s.defaultPayloadType) assert.True(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") }) t.Run("explicit", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() a := createAssociation(Config{ MaxMessageSize: 30000, LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(30000), a.MaxMessageSize(), "should match") s := 
a.createStream(1, false) assert.NotNil(t, s, "should succeed") p := make([]byte, 30001) var err error _, err = s.WriteSCTP(p[:30000], s.defaultPayloadType) assert.False(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") _, err = s.WriteSCTP(p[:30001], s.defaultPayloadType) assert.True(t, strings.Contains(err.Error(), "larger than maximum"), "should be false") }) t.Run("set value", func(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() a := createAssociation(Config{ LoggerFactory: loggerFactory, }) assert.NotNil(t, a, "should succeed") assert.Equal(t, uint32(65536), a.MaxMessageSize(), "should match") a.SetMaxMessageSize(20000) assert.Equal(t, uint32(20000), a.MaxMessageSize(), "should match") }) } type dumbConnInboundHandler func([]byte) type dumbConn2 struct { net.Conn packets [][]byte closed bool localAddr net.Addr remoteAddr net.Addr remoteInboundHandler dumbConnInboundHandler mutex sync.Mutex cond *sync.Cond } func newDumbConn2(localAddr, remoteAddr net.Addr) *dumbConn2 { c := &dumbConn2{ packets: [][]byte{}, localAddr: localAddr, remoteAddr: remoteAddr, } c.cond = sync.NewCond(&c.mutex) return c } func (c *dumbConn2) setRemoteHandler(handler dumbConnInboundHandler) { c.mutex.Lock() c.remoteInboundHandler = handler c.mutex.Unlock() } // Implement the net.Conn interface methods. func (c *dumbConn2) Read(b []byte) (n int, err error) { c.mutex.Lock() defer c.mutex.Unlock() for { if len(c.packets) > 0 { packet := c.packets[0] c.packets = c.packets[1:] n := copy(b, packet) return n, nil } if c.closed { return 0, io.EOF } c.cond.Wait() } } func (c *dumbConn2) Write(b []byte) (int, error) { c.mutex.Lock() closed := c.closed c.mutex.Unlock() if closed { return 0, &net.OpError{Op: "write", Net: "udp", Addr: c.remoteAddr, Err: net.ErrClosed} } c.remoteInboundHandler(b) return len(b), nil } func (c *dumbConn2) Close() error { c.mutex.Lock() defer c.mutex.Unlock() c.closed = true c.cond.Signal() return nil } func (c *dumbConn2) LocalAddr() net.Addr { // Unused by Association return c.localAddr } func (c *dumbConn2) RemoteAddr() net.Addr { // Unused by Association return c.remoteAddr } func (c *dumbConn2) SetDeadline(time.Time) error { return nil } func (c *dumbConn2) SetReadDeadline(time.Time) error { return nil } func (c *dumbConn2) SetWriteDeadline(time.Time) error { return nil } func (c *dumbConn2) inboundHandler(packet []byte) { c.mutex.Lock() defer c.mutex.Unlock() if !c.closed { c.packets = append(c.packets, packet) c.cond.Signal() } } // crateUDPConnPair creates a pair of net.UDPConn objects that are connected with each other. 
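// (In practice the pair is two in-memory dumbConn2 pipes rather than real
// net.UDPConn sockets.) A minimal usage sketch, assuming only this file:
//
//	c1, c2 := createUDPConnPair()
//	_, _ = c1.Write([]byte("ping")) // delivered to c2 and returned by c2.Read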
func createUDPConnPair() (net.Conn, net.Conn) { addr1 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 1234} addr2 := &net.UDPAddr{IP: net.ParseIP("127.0.0.1"), Port: 5678} conn1 := newDumbConn2(addr1, addr2) conn2 := newDumbConn2(addr2, addr1) conn1.setRemoteHandler(conn2.inboundHandler) conn2.setRemoteHandler(conn1.inboundHandler) return conn1, conn2 } func createAssocs() (*Association, *Association, error) { //nolint:cyclop udp1, udp2 := createUDPConnPair() loggerFactory := logging.NewDefaultLoggerFactory() a1Chan := make(chan any) a2Chan := make(chan any) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() go func() { a, err2 := createClientWithContext(ctx, Config{ NetConn: udp1, LoggerFactory: loggerFactory, }) if err2 != nil { a1Chan <- err2 } else { a1Chan <- a } }() go func() { a, err2 := createClientWithContext(ctx, Config{ NetConn: udp2, LoggerFactory: loggerFactory, }) if err2 != nil { a2Chan <- err2 } else { a2Chan <- a } }() var a1 *Association var a2 *Association loop: for { select { case v1 := <-a1Chan: switch v := v1.(type) { case *Association: a1 = v if a2 != nil { break loop } case error: return nil, nil, v } case v2 := <-a2Chan: switch v := v2.(type) { case *Association: a2 = v if a1 != nil { break loop } case error: return nil, nil, v } } } return a1, a2, nil } // udpDiscardReader blocks all reads after block is set to true. // This allows us to send arbitrary packets on a stream and block the packets received in response. type udpDiscardReader struct { net.Conn ctx context.Context //nolint:containedctx block atomic.Bool } func (d *udpDiscardReader) Read(b []byte) (n int, err error) { if d.block.Load() { <-d.ctx.Done() return 0, d.ctx.Err() } return d.Conn.Read(b) } func createAssociationPair(udpConn1 net.Conn, udpConn2 net.Conn) (*Association, *Association, error) { return createAssociationPairWithConfig(udpConn1, udpConn2, Config{}) } //nolint:cyclop func createAssociationPairWithConfig( udpConn1 net.Conn, udpConn2 net.Conn, config Config, ) (*Association, *Association, error) { loggerFactory := logging.NewDefaultLoggerFactory() a1Chan := make(chan any) a2Chan := make(chan any) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) defer cancel() go func() { cfg := config cfg.NetConn = udpConn1 cfg.LoggerFactory = loggerFactory a, err2 := createClientWithContext(ctx, cfg) if err2 != nil { a1Chan <- err2 } else { a1Chan <- a } }() go func() { cfg := config cfg.NetConn = udpConn2 cfg.LoggerFactory = loggerFactory if cfg.MaxReceiveBufferSize == 0 { cfg.MaxReceiveBufferSize = 100_000 } a, err2 := createClientWithContext(ctx, cfg) if err2 != nil { a2Chan <- err2 } else { a2Chan <- a } }() var a1 *Association var a2 *Association loop: for { select { case v1 := <-a1Chan: switch v := v1.(type) { case *Association: a1 = v if a2 != nil { break loop } case error: return nil, nil, v } case v2 := <-a2Chan: switch v := v2.(type) { case *Association: a2 = v if a1 != nil { break loop } case error: return nil, nil, v } } } return a1, a2, nil } func noErrorClose(t *testing.T, closeF func() error) { t.Helper() require.NoError(t, closeF()) } // blockingCloseConn simulates a TCP/TLS connection where Close blocks, and Read // only unblocks when a past read deadline is set. 
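// A rough usage sketch (illustrative; not run by the tests):
//
//	c := newBlockingCloseConn()
//	go func() { _, _ = c.Read(nil) }() // blocks
//	_ = c.SetReadDeadline(time.Now())  // past deadline -> Read returns os.ErrDeadlineExceeded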
type blockingCloseConn struct { readBlocked chan struct{} closeBlocked chan struct{} once sync.Once } func newBlockingCloseConn() *blockingCloseConn { return &blockingCloseConn{ readBlocked: make(chan struct{}), closeBlocked: make(chan struct{}), } } func (c *blockingCloseConn) unblockRead() { c.once.Do(func() { close(c.readBlocked) }) } func (c *blockingCloseConn) Read(_ []byte) (int, error) { <-c.readBlocked return 0, os.ErrDeadlineExceeded } func (c *blockingCloseConn) Write(p []byte) (int, error) { return len(p), nil } func (c *blockingCloseConn) Close() error { <-c.closeBlocked c.unblockRead() return nil } func (c *blockingCloseConn) LocalAddr() net.Addr { return &net.IPAddr{} } func (c *blockingCloseConn) RemoteAddr() net.Addr { return &net.IPAddr{} } func (c *blockingCloseConn) SetDeadline(_ time.Time) error { return nil } func (c *blockingCloseConn) SetWriteDeadline(_ time.Time) error { return nil } func (c *blockingCloseConn) SetReadDeadline(t time.Time) error { if !t.IsZero() && !t.After(time.Now()) { c.unblockRead() } return nil } func TestAssociationAbortUnblocksStuckRead(t *testing.T) { conn := newBlockingCloseConn() assoc := createAssociation(Config{ NetConn: conn, LoggerFactory: logging.NewDefaultLoggerFactory(), }) assoc.init(false) done := make(chan struct{}) go func() { assoc.Abort("abort read") close(done) }() select { case <-done: case <-time.After(200 * time.Millisecond): require.FailNow(t, "Abort did not return while read loop was blocked") } close(conn.closeBlocked) } // blockingWriteConn simulates a connection whose Write blocks until a write // deadline is set, SetWriteDeadline unblocks the pending Write immediately. type blockingWriteConn struct { readBlocked chan struct{} writeBlocked chan struct{} writeDeadlineCalled chan struct{} unblockReadOnce sync.Once unblockWriteOnce sync.Once } func newBlockingWriteConn() *blockingWriteConn { return &blockingWriteConn{ readBlocked: make(chan struct{}), writeBlocked: make(chan struct{}), writeDeadlineCalled: make(chan struct{}, 1), } } func (c *blockingWriteConn) Read(_ []byte) (int, error) { <-c.readBlocked return 0, os.ErrDeadlineExceeded } func (c *blockingWriteConn) Write(p []byte) (int, error) { <-c.writeBlocked return len(p), nil } func (c *blockingWriteConn) Close() error { return nil } func (c *blockingWriteConn) LocalAddr() net.Addr { return &net.IPAddr{} } func (c *blockingWriteConn) RemoteAddr() net.Addr { return &net.IPAddr{} } func (c *blockingWriteConn) SetDeadline(_ time.Time) error { return nil } func (c *blockingWriteConn) SetWriteDeadline(_ time.Time) error { c.unblockWriteOnce.Do(func() { close(c.writeBlocked) }) select { case c.writeDeadlineCalled <- struct{}{}: default: } return nil } func (c *blockingWriteConn) SetReadDeadline(t time.Time) error { if !t.IsZero() && !t.After(time.Now()) { c.unblockReadOnce.Do(func() { close(c.readBlocked) }) } return nil } func TestAssociationAbortSetsWriteDeadline(t *testing.T) { conn := newBlockingWriteConn() assoc := createAssociation(Config{ NetConn: conn, LoggerFactory: logging.NewDefaultLoggerFactory(), }) assoc.init(false) done := make(chan struct{}) go func() { assoc.Abort("abort write deadline") close(done) }() select { case <-conn.writeDeadlineCalled: case <-time.After(200 * time.Millisecond): require.FailNow(t, "Abort did not call SetWriteDeadline") } select { case <-done: case <-time.After(300 * time.Millisecond): require.FailNow(t, "Abort did not return promptly") } } // readMyNextTSN uses a lock to read the myNextTSN field of the association. 
// Avoids a data race. func readMyNextTSN(a *Association) uint32 { a.lock.Lock() defer a.lock.Unlock() return a.myNextTSN } func TestAssociationReceiveWindow(t *testing.T) { udp1, udp2 := createUDPConnPair() ctx, cancel := context.WithCancel(context.Background()) dudp1 := &udpDiscardReader{Conn: udp1, ctx: ctx} // a1 is the association used for sending data // a2 is the association with receive window of 100kB which we will // try to bypass a1, a2, err := createAssociationPair(dudp1, udp2) require.NoError(t, err) defer noErrorClose(t, a2.Close) defer noErrorClose(t, a1.Close) s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) require.NoError(t, err) defer s1.Close() // nolint:errcheck,gosec _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) dudp1.block.Store(true) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) s2, err := a2.AcceptStream() require.NoError(t, err) require.Equal(t, uint16(1), s2.streamIdentifier) done := make(chan bool) go func() { chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) chunks = chunks[:1] chunk := chunks[0] // Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue chunk.tsn = readMyNextTSN(a1) + 1e9 for chunk.tsn > readMyNextTSN(a1) { select { case <-done: return default: } chunk.tsn-- pp := a1.bundleDataChunksIntoPackets(chunks) for _, p := range pp { raw, err := p.marshal(true) if err != nil { return } _, err = a1.netConn.Write(raw) if err != nil { return } } if chunk.tsn%10 == 0 { time.Sleep(10 * time.Millisecond) } } }() for cnt := 0; cnt < 15; cnt++ { bytesQueued := s2.getNumBytesInReassemblyQueue() if assert.Less(t, bytesQueued, 5_000_000, "too many bytes enqueued with receive window of 10kb") { break } t.Log("bytes queued", bytesQueued) time.Sleep(1 * time.Second) } close(done) cancel() } func TestAssociationFastRtxWnd(t *testing.T) { udp1, udp2 := createUDPConnPair() a1, a2, err := createAssociationPairWithConfig(udp1, udp2, Config{MinCwnd: 14000, FastRtxWnd: 14000}) require.NoError(t, err) defer noErrorClose(t, a2.Close) defer noErrorClose(t, a1.Close) s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) require.NoError(t, err) defer noErrorClose(t, s1.Close) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) _, err = a2.AcceptStream() require.NoError(t, err) a1.rtoMgr.setRTO(1000, true) // ack the hello packet time.Sleep(1 * time.Second) require.Equal(t, a1.minCwnd, a1.CWND()) var shouldDrop atomic.Bool var dropCounter atomic.Uint32 dbConn1, ok := udp1.(*dumbConn2) require.True(t, ok) dbConn2, ok := udp2.(*dumbConn2) require.True(t, ok) dbConn1.setRemoteHandler(func(packet []byte) { if !shouldDrop.Load() { dbConn2.inboundHandler(packet) } else { dropCounter.Add(1) } }) // intercept SACK var lastSACK atomic.Pointer[chunkSelectiveAck] dbConn2.setRemoteHandler(func(buf []byte) { p := &packet{} require.NoError(t, p.unmarshal(true, buf)) for _, c := range p.chunks { if ack, aok := c.(*chunkSelectiveAck); aok { lastSACK.Store(ack) } } dbConn1.inboundHandler(buf) }) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) require.Eventually(t, func() bool { return lastSACK.Load() != nil }, 1*time.Second, 10*time.Millisecond) shouldDrop.Store(true) // send packets and dropped buf := make([]byte, 700) for i := 0; i < 20; i++ { _, err = s1.WriteSCTP(buf, PayloadTypeWebRTCBinary) require.NoError(t, err) } require.Eventuallyf( t, func() bool { return dropCounter.Load() >= 15 }, 
5*time.Second, 10*time.Millisecond, "drop %d", dropCounter.Load(), ) require.Zero(t, a1.stats.getNumFastRetrans()) require.False(t, a1.inFastRecovery) // sack to trigger fast retransmit ack := *(lastSACK.Load()) ack.gapAckBlocks = []gapAckBlock{{start: 11}} for i := 11; i < 14; i++ { ack.gapAckBlocks[0].end = uint16(i) //nolint:gosec // G115 pkt := a1.createPacket([]chunk{&ack}) pktBuf, err1 := pkt.marshal(true) require.NoError(t, err1) dbConn1.inboundHandler(pktBuf) } require.Eventually(t, func() bool { a1.lock.RLock() defer a1.lock.RUnlock() return a1.inFastRecovery }, 5*time.Second, 10*time.Millisecond) require.GreaterOrEqual(t, uint64(10), a1.stats.getNumFastRetrans()) // 7.2.4 b) In fast-recovery AND the Cumulative TSN Ack Point advanced // the miss indications are incremented for all TSNs reported missing // in the SACK. a1.lock.Lock() lastTSN := a1.inflightQueue.chunks.Back().tsn lastTSNMinusTwo := lastTSN - 2 lastChunk := a1.inflightQueue.chunks.Back() lastChunkMinusTwo, ok := a1.inflightQueue.get(lastTSNMinusTwo) a1.lock.Unlock() require.True(t, ok) require.True(t, lastTSN > ack.cumulativeTSNAck+uint32(ack.gapAckBlocks[0].end)+3) // sack with cumAckPoint advanced, lastTSN should not be marked as missing ack.cumulativeTSNAck++ end := lastTSN - 1 - ack.cumulativeTSNAck //nolint:gosec // G115 ack.gapAckBlocks = append(ack.gapAckBlocks, gapAckBlock{start: uint16(end), end: uint16(end)}) pkt := a1.createPacket([]chunk{&ack}) pktBuf, err := pkt.marshal(true) require.NoError(t, err) dbConn1.inboundHandler(pktBuf) require.Eventually(t, func() bool { a1.lock.Lock() defer a1.lock.Unlock() return lastChunkMinusTwo.missIndicator == 1 && lastChunk.missIndicator == 0 }, 5*time.Second, 10*time.Millisecond) } func TestAssociationMaxTSNOffset(t *testing.T) { udp1, udp2 := createUDPConnPair() // a1 is the association used for sending data // a2 is the association with receive window of 100kB which we will // try to bypass a1, a2, err := createAssociationPair(udp1, udp2) require.NoError(t, err) defer noErrorClose(t, a2.Close) defer noErrorClose(t, a1.Close) s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) require.NoError(t, err) defer noErrorClose(t, s1.Close) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) s2, err := a2.AcceptStream() require.NoError(t, err) require.Equal(t, uint16(1), s2.streamIdentifier) chunks, _ := s1.packetize(make([]byte, 1000), PayloadTypeWebRTCBinary) chunks = chunks[:1] sendChunk := func(tsn uint32) { chunk := chunks[0] // Fake the TSN and enqueue 1 chunk with a very high tsn in the payload queue chunk.tsn = tsn pp := a1.bundleDataChunksIntoPackets(chunks) for _, p := range pp { raw, err := p.marshal(true) assert.NoError(t, err) _, err = a1.netConn.Write(raw) assert.NoError(t, err) } } sendChunk(readMyNextTSN(a1) + 100_000) time.Sleep(100 * time.Millisecond) require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000) sendChunk(readMyNextTSN(a1) + 10_000) time.Sleep(100 * time.Millisecond) require.Less(t, s2.getNumBytesInReassemblyQueue(), 1000) sendChunk(readMyNextTSN(a1) + minTSNOffset - 100) time.Sleep(100 * time.Millisecond) require.Greater(t, s2.getNumBytesInReassemblyQueue(), 1000) } func TestAssociation_Shutdown(t *testing.T) { checkGoroutineLeaks(t) a1, a2, err := createAssocs() require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) 
require.NoError(t, err) testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) defer cancel() err = a1.Shutdown(ctx) require.NoError(t, err) // Wait for close read loop channels to prevent flaky tests. select { case <-a2.readLoopCloseCh: case <-time.After(1 * time.Second): assert.Fail(t, "timed out waiting for a2 read loop to close") } } func TestAssociation_ShutdownDuringWrite(t *testing.T) { checkGoroutineLeaks(t) a1, a2, err := createAssocs() require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) writingDone := make(chan struct{}) go func() { defer close(writingDone) var i byte for { i++ if i%100 == 0 { time.Sleep(20 * time.Millisecond) } _, writeErr := s21.Write([]byte{i}) if writeErr != nil { return } } }() testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) // running this test with -race flag is very slow so timeout needs to be high. timeout := 5 * time.Minute ctx, cancel := context.WithTimeout(context.Background(), timeout) defer cancel() err = a1.Shutdown(ctx) require.NoError(t, err, "timed out waiting for a1 shutdown to complete") select { case <-writingDone: case <-time.After(timeout): assert.Fail(t, "timed out waiting writing goroutine to exit") } // Wait for close read loop channels to prevent flaky tests. select { case <-a2.readLoopCloseCh: case <-time.After(timeout): assert.Fail(t, "timed out waiting for a2 read loop to close") } } func TestAssociation_HandlePacketInCookieWaitState(t *testing.T) { loggerFactory := logging.NewDefaultLoggerFactory() testCases := map[string]struct { inputPacket *packet skipClose bool }{ "InitAck": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{ &chunkInitAck{ chunkInitCommon: chunkInitCommon{ initiateTag: 1, numInboundStreams: 1, numOutboundStreams: 1, advertisedReceiverWindowCredit: 1500, }, }, }, }, }, "Abort": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkAbort{}}, }, // Prevent "use of close network connection" error on close. 
skipClose: true, }, "CoockeEcho": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkCookieEcho{}}, }, }, "HeartBeat": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkHeartbeat{}}, }, }, "PayloadData": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkPayloadData{}}, }, }, "Sack": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkSelectiveAck{ cumulativeTSNAck: 1000, advertisedReceiverWindowCredit: 1500, gapAckBlocks: []gapAckBlock{ {start: 100, end: 200}, }, }}, }, }, "Reconfig": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkReconfig{ paramA: ¶mOutgoingResetRequest{}, paramB: ¶mReconfigResponse{}, }}, }, }, "ForwardTSN": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkForwardTSN{ newCumulativeTSN: 100, }}, }, }, "Error": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkError{}}, }, }, "Shutdown": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdown{}}, }, }, "ShutdownAck": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdownAck{}}, }, }, "ShutdownComplete": { inputPacket: &packet{ sourcePort: 1, destinationPort: 1, chunks: []chunk{&chunkShutdownComplete{}}, }, }, } for name, testCase := range testCases { testCase := testCase t.Run(name, func(t *testing.T) { aConn, charlieConn := pipeDump(t) assoc := createAssociation(Config{ NetConn: aConn, MaxReceiveBufferSize: 0, LoggerFactory: loggerFactory, }) assoc.init(true) if !testCase.skipClose { defer func() { assert.NoError(t, assoc.close()) }() } packet, err := assoc.marshalPacket(testCase.inputPacket) assert.NoError(t, err) _, err = charlieConn.Write(packet) assert.NoError(t, err) // Should not panic. time.Sleep(100 * time.Millisecond) }) } } func TestAssociation_Abort(t *testing.T) { checkGoroutineLeaks(t) a1, a2, err := createAssocs() require.NoError(t, err) s11, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) s21, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) testData := []byte("test") i, err := s11.Write(testData) assert.Equal(t, len(testData), i) assert.NoError(t, err) buf := make([]byte, len(testData)) i, err = s21.Read(buf) assert.Equal(t, len(testData), i) assert.NoError(t, err) assert.Equal(t, testData, buf) a1.Abort("1234") // Wait for close read loop channels to prevent flaky tests. select { case <-a2.readLoopCloseCh: case <-time.After(1 * time.Second): assert.Fail(t, "timed out waiting for a2 read loop to close") } i, err = s21.Read(buf) assert.Equal(t, i, 0, "expected no data read") assert.Error(t, err, "User Initiated Abort: 1234", "expected abort reason") } // TestAssociation_createClientWithContext tests that the client is closed when the context is canceled. 
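// A rough sketch of the pattern being exercised (illustrative only; conn
// stands in for any net.Conn):
//
//	ctx, cancel := context.WithCancel(context.Background())
//	cancel() // cancel before the handshake can finish
//	_, err := createClientWithContext(ctx, Config{NetConn: conn, LoggerFactory: loggerFactory})
//	// err is expected to be non-nil because the context was canceled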
func TestAssociation_createClientWithContext(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 5) defer lim.Stop() checkGoroutineLeaks(t) udp1, udp2 := createUDPConnPair() loggerFactory := logging.NewDefaultLoggerFactory() errCh1 := make(chan error) errCh2 := make(chan error) ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second) go func() { _, err2 := createClientWithContext(ctx, Config{ NetConn: udp1, LoggerFactory: loggerFactory, }) if err2 != nil { errCh1 <- err2 } else { errCh1 <- nil } }() go func() { _, err2 := createClientWithContext(ctx, Config{ NetConn: udp2, LoggerFactory: loggerFactory, }) if err2 != nil { errCh2 <- err2 } else { errCh2 <- nil } }() // Cancel the context immediately cancel() var err1 error var err2 error loop: for { select { case err1 = <-errCh1: if err1 != nil && err2 != nil { break loop } case err2 = <-errCh2: if err1 != nil && err2 != nil { break loop } } } assert.Error(t, err1, "context canceled") assert.Error(t, err2, "context canceled") } type customLogger struct { expectZeroChecksum bool t *testing.T } func (c customLogger) Trace(string) {} func (c customLogger) Tracef(string, ...any) {} func (c customLogger) Debug(string) {} func (c customLogger) Debugf(format string, args ...any) { if format == "[%s] sendZeroChecksum=%t (on initAck)" { assert.Equal(c.t, args[1], c.expectZeroChecksum) } } func (c customLogger) Info(string) {} func (c customLogger) Infof(string, ...any) {} func (c customLogger) Warn(string) {} func (c customLogger) Warnf(string, ...any) {} func (c customLogger) Error(string) {} func (c customLogger) Errorf(string, ...any) {} func (c customLogger) NewLogger(string) logging.LeveledLogger { return c } func TestAssociation_ZeroChecksum(t *testing.T) { checkGoroutineLeaks(t) lim := test.TimeOut(time.Second * 10) defer lim.Stop() for _, testCase := range []struct { clientZeroChecksum, serverZeroChecksum, expectChecksumEnabled bool }{ {true, true, true}, {false, false, false}, {true, false, false}, {false, true, true}, } { a1chan, a2chan := make(chan *Association), make(chan *Association) udp1, udp2 := createUDPConnPair() go func() { a1, err := Client(Config{ NetConn: udp1, LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t}, EnableZeroChecksum: testCase.clientZeroChecksum, }) assert.NoError(t, err) a1chan <- a1 }() go func() { a2, err := Server(Config{ NetConn: udp2, LoggerFactory: &customLogger{testCase.expectChecksumEnabled, t}, EnableZeroChecksum: testCase.serverZeroChecksum, }) assert.NoError(t, err) a2chan <- a2 }() a1, a2 := <-a1chan, <-a2chan writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) testData := []byte("test") _, err = writeStream.Write(testData) require.NoError(t, err) buf := make([]byte, len(testData)) _, err = readStream.Read(buf) assert.NoError(t, err) assert.Equal(t, testData, buf) require.NoError(t, a1.Close()) require.NoError(t, a2.Close()) } } func TestDataChunkBundlingIntoPacket(t *testing.T) { a := &Association{mtu: initialMTU} chunks := make([]*chunkPayloadData, 300) for i := 0; i < 300; i++ { chunks[i] = &chunkPayloadData{userData: []byte{1}} } packets := a.bundleDataChunksIntoPackets(chunks) for _, p := range packets { raw, err := p.marshal(false) require.NoError(t, err) assert.Less(t, len(raw), int(initialMTU), "packet too long") } } func TestAssociation_ReconfigRequestsLimited(t *testing.T) { checkGoroutineLeaks(t) lim := 
test.TimeOut(time.Second * 10) defer lim.Stop() a1chan, a2chan := make(chan *Association), make(chan *Association) udp1, udp2 := createUDPConnPair() go func() { a1, err := Client(Config{ NetConn: udp1, LoggerFactory: logging.NewDefaultLoggerFactory(), }) assert.NoError(t, err) a1chan <- a1 }() go func() { a2, err := Server(Config{ NetConn: udp2, LoggerFactory: logging.NewDefaultLoggerFactory(), }) assert.NoError(t, err) a2chan <- a2 }() a1, a2 := <-a1chan, <-a2chan writeStream, err := a1.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) readStream, err := a2.OpenStream(1, PayloadTypeWebRTCString) require.NoError(t, err) // exchange some data testData := []byte("test") _, err = writeStream.Write(testData) require.NoError(t, err) buf := make([]byte, len(testData)) _, err = readStream.Read(buf) assert.NoError(t, err) assert.Equal(t, testData, buf) a1.lock.RLock() tsn := a1.myNextTSN a1.lock.RUnlock() for i := 0; i < maxReconfigRequests+100; i++ { c := &chunkReconfig{ paramA: &paramOutgoingResetRequest{ reconfigRequestSequenceNumber: 10 + uint32(i), //nolint:gosec // G115 senderLastTSN: tsn + 10, // has to be enqueued streamIdentifiers: []uint16{uint16(i)}, //nolint:gosec // G115 }, } p := a1.createPacket([]chunk{c}) buf, err := p.marshal(true) require.NoError(t, err) _, err = a1.netConn.Write(buf) require.NoError(t, err) if i%100 == 0 { time.Sleep(100 * time.Millisecond) } } // Let a2 process the requests time.Sleep(2 * time.Second) a2.lock.RLock() require.LessOrEqual(t, len(a2.reconfigRequests), maxReconfigRequests) a2.lock.RUnlock() require.NoError(t, a1.Close()) require.NoError(t, a2.Close()) } func TestAssociation_OpenStreamAfterClose(t *testing.T) { checkGoroutineLeaks(t) a1, a2, err := createAssocs() require.NoError(t, err) require.NoError(t, a1.Close()) require.NoError(t, a2.Close()) _, err = a1.OpenStream(1, PayloadTypeWebRTCString) require.ErrorIs(t, err, ErrAssociationClosed) _, err = a2.OpenStream(1, PayloadTypeWebRTCString) require.ErrorIs(t, err, ErrAssociationClosed) } // https://github.com/pion/sctp/pull/350 // may need to run with a high test count to reproduce if there // is ever a regression.
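// OpenStream right after the underlying net.Conn is torn down must either succeed or return ErrAssociationClosed (never panic), and after Close() both associations should hold no streams.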
func TestAssociation_OpenStreamAfterInternalClose(t *testing.T) { checkGoroutineLeaks(t) a1, a2, err := createAssocs() require.NoError(t, err) require.NoError(t, a1.netConn.Close()) require.NoError(t, a2.netConn.Close()) _, err = a1.OpenStream(1, PayloadTypeWebRTCString) require.True(t, err == nil || errors.Is(err, ErrAssociationClosed)) _, err = a2.OpenStream(1, PayloadTypeWebRTCString) require.True(t, err == nil || errors.Is(err, ErrAssociationClosed)) require.NoError(t, a1.Close()) require.NoError(t, a2.Close()) require.Equal(t, 0, len(a1.streams)) require.Equal(t, 0, len(a2.streams)) } func TestAssociation_BlockWrite(t *testing.T) { checkGoroutineLeaks(t) conn1, conn2 := createUDPConnPair() a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{BlockWrite: true, MaxReceiveBufferSize: 4000}) require.NoError(t, err) defer noErrorClose(t, a2.Close) defer noErrorClose(t, a1.Close) s1, err := a1.OpenStream(1, PayloadTypeWebRTCBinary) require.NoError(t, err) defer noErrorClose(t, s1.Close) _, err = s1.WriteSCTP([]byte("hello"), PayloadTypeWebRTCBinary) require.NoError(t, err) s2, err := a2.AcceptStream() require.NoError(t, err) data := make([]byte, 4000) n, err := s2.Read(data) require.NoError(t, err) require.Equal(t, "hello", string(data[:n])) // Write should block until data is sent dbConn1, ok := conn1.(*dumbConn2) require.True(t, ok) dbConn2, ok := conn2.(*dumbConn2) require.True(t, ok) dbConn1.setRemoteHandler(dbConn2.inboundHandler) _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) require.NoError(t, err) _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) require.NoError(t, err) // test write deadline // a2's awnd is 0, so write should be blocked require.NoError(t, s1.SetWriteDeadline(time.Now().Add(100*time.Millisecond))) _, err = s1.WriteSCTP(data, PayloadTypeWebRTCBinary) require.ErrorIs(t, err, context.DeadlineExceeded, err) // test write deadline cancel require.NoError(t, s1.SetWriteDeadline(time.Time{})) var deadLineCanceled atomic.Bool writeCanceled := make(chan struct{}, 2) // both write should be blocked and canceled by deadline go func() { _, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary) require.ErrorIs(t, err, context.DeadlineExceeded, err1) require.True(t, deadLineCanceled.Load()) writeCanceled <- struct{}{} }() go func() { _, err1 := s1.WriteSCTP(data, PayloadTypeWebRTCBinary) require.ErrorIs(t, err, context.DeadlineExceeded, err1) require.True(t, deadLineCanceled.Load()) writeCanceled <- struct{}{} }() time.Sleep(100 * time.Millisecond) deadLineCanceled.Store(true) require.NoError(t, s1.SetWriteDeadline(time.Now().Add(-1*time.Second))) <-writeCanceled <-writeCanceled require.NoError(t, s1.SetWriteDeadline(time.Time{})) rn, rerr := s2.Read(data) require.NoError(t, rerr) require.Equal(t, 4000, rn) // slow reader and fast writer, make sure all write is blocked go func() { for { bytes := make([]byte, 4000) rn, rerr = s2.Read(bytes) if errors.Is(rerr, io.EOF) { return } require.NoError(t, rerr) require.Equal(t, 4000, rn) time.Sleep(5 * time.Millisecond) } }() for i := 0; i < 10; i++ { _, err = s1.Write(data) require.NoError(t, err) // bufferedAmount should not exceed RWND+message size (inflight + pending) require.LessOrEqual(t, s1.BufferedAmount(), uint64(4000*2)) } } func TestConfigMTU(t *testing.T) { const expectedMTU = uint32(8765) conn1, conn2 := createUDPConnPair() a1, a2, err := createAssociationPairWithConfig(conn1, conn2, Config{MTU: 8765}) require.NoError(t, err) require.Equal(t, expectedMTU, a1.MTU()) require.Equal(t, expectedMTU, a2.MTU()) 
require.NoError(t, a1.Close()) require.NoError(t, conn1.Close()) require.NoError(t, a2.Close()) require.NoError(t, conn2.Close()) } // makes an Association without starting read/write loops, skips init(), just the minimal state. func newRackTestAssoc(t *testing.T) *Association { t.Helper() lg := logging.NewDefaultLoggerFactory() assoc := createAssociation(Config{ LoggerFactory: lg, }) // Put the association into a sane "established" state with fresh queues. assoc.setState(established) assoc.peerVerificationTag = 1 assoc.sourcePort = defaultSCTPSrcDstPort assoc.destinationPort = defaultSCTPSrcDstPort // Deterministic TSN base. assoc.initialTSN = 100 assoc.myNextTSN = 102 // we'll populate TSN=100,101 manually below assoc.cumulativeTSNAckPoint = 99 assoc.advancedPeerTSNAckPoint = 99 // fresh queues assoc.inflightQueue = newPayloadQueue() assoc.payloadQueue = newReceivePayloadQueue(getMaxTSNOffset(assoc.maxReceiveBufferSize)) // RACK defaults for tests assoc.rackReorderingSeen = false assoc.rackReoWndFloor = 0 // Have a non-zero SRTT so SRTT-bounding code runs deterministically. assoc.srtt.Store(float64(100.0)) // 100 ms return assoc } func mkChunk(tsn uint32, since time.Time) *chunkPayloadData { return &chunkPayloadData{ streamIdentifier: 1, streamSequenceNumber: 1, beginningFragment: true, endingFragment: true, userData: []byte("x"), tsn: tsn, since: since, nSent: 1, // original transmission } } func TestRACK_MarkLossOnACK(t *testing.T) { assoc := newRackTestAssoc(t) assoc.lock.Lock() if assoc.rackMinRTTWnd == nil { assoc.rackMinRTTWnd = newWindowedMin(30 * time.Second) } // MinRTT = 40ms → base reoWnd = 10ms assoc.rackMinRTT = 40 * time.Millisecond assoc.rackReoWnd = 10 * time.Millisecond now := time.Now() olderSend := now.Add(-50 * time.Millisecond) newerSend := now.Add(-1 * time.Millisecond) // Outstanding: TSN=100 (older), TSN=101 (newer, just delivered) cOld := mkChunk(100, olderSend) cNew := mkChunk(101, newerSend) // Treat them as original transmissions. cOld.nSent = 1 cNew.nSent = 1 assoc.inflightQueue.pushNoCheck(cOld) assoc.inflightQueue.pushNoCheck(cNew) // Track them in the RACK xmit-time list (ordered by send time). assoc.rackInsert(cOld) assoc.rackInsert(cNew) assoc.lock.Unlock() // Simulate an ACK that delivers TSN 101 with send-time newerSend. assoc.onRackAfterSACK( true, // deliveredFound newerSend, // newestDeliveredSendTime 101, // newestDeliveredOrigTSN &chunkSelectiveAck{}, // SACK contents don't matter for this test ) got, ok := assoc.inflightQueue.get(100) require.True(t, ok, "TSN 100 should still be in inflightQueue") require.NotNil(t, got) assert.True(t, got.retransmit, "RACK should mark older TSN lost on ACK") } func TestRACK_TimerMarksLost(t *testing.T) { assoc := newRackTestAssoc(t) assoc.lock.Lock() // Reordering window and delivered time such that the chunk is clearly overdue. assoc.rackReoWnd = 10 * time.Millisecond assoc.rackDeliveredTime = time.Now() now := time.Now() olderSend := now.Add(-50 * time.Millisecond) // One outstanding original transmission far in the past. c := mkChunk(100, olderSend) c.nSent = 1 assoc.inflightQueue.pushNoCheck(c) assoc.rackInsert(c) assoc.lock.Unlock() // Simulate the RACK timer firing. 
assoc.onRackTimeout() got, ok := assoc.inflightQueue.get(100) require.True(t, ok, "TSN 100 should still be in inflightQueue") require.NotNil(t, got) assert.True(t, got.retransmit, "RACK timer should mark overdue original as lost") } func TestRACK_DSACKInflatesAndDecays(t *testing.T) { assoc := newRackTestAssoc(t) assoc.rackMinRTT = 100 * time.Millisecond assoc.rackReoWnd = 25 * time.Millisecond // base is 25ms; will inflate by +25ms assoc.rackKeepInflatedRecoveries = 0 // DSACK (duplicate TSN) present -> inflate by max(minRTT/4, floor) and set counter=16 sack := &chunkSelectiveAck{ cumulativeTSNAck: 99, duplicateTSN: []uint32{123}, } // Note that we're checking for 15 and 14 instead of 16 and 15 because it immediately // decrements when not in fast recovery. assoc.onRackAfterSACK(false, time.Time{}, 0, sack) assert.Equal(t, 50*time.Millisecond, assoc.rackReoWnd, "reoWnd should inflate on DSACK") assert.Equal(t, 15, assoc.rackKeepInflatedRecoveries, "keep-inflated counter should be 15") // When not in fast recovery, the counter decays each pass. assoc.inFastRecovery = false assoc.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal(t, 14, assoc.rackKeepInflatedRecoveries) // Drive counter to zero and ensure reoWnd resets to base (minRTT/4). assoc.rackKeepInflatedRecoveries = 1 assoc.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal(t, 0, assoc.rackKeepInflatedRecoveries) assert.Equal(t, 25*time.Millisecond, assoc.rackReoWnd, "reoWnd should reset to base after decay") } func TestRACK_SuppressReoWndDuringRecovery_NoReorderingSeen(t *testing.T) { assoc := newRackTestAssoc(t) // Start with an empty rolling window (no RTT samples). assoc.rackReoWnd = 40 * time.Millisecond assoc.rackReorderingSeen = false assoc.inFastRecovery = true // During recovery with no reordering observed, reoWnd must go to zero. assoc.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal(t, time.Duration(0), assoc.rackReoWnd, "reoWnd should be suppressed during recovery w/o reordering") assoc.inFastRecovery = false assoc.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal(t, time.Duration(0), assoc.rackReoWnd, "reoWnd should stay 0 until a minRTT sample exists") now := time.Now() assoc.rackMinRTTWnd.Push(now, 120*time.Millisecond) assoc.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal( t, 30*time.Millisecond, assoc.rackReoWnd, "reoWnd should re-initialize to base (minRTT/4) after first sample", ) } func TestRACK_ReoWndBoundedBySRTT(t *testing.T) { a := newRackTestAssoc(t) // Set a very large reoWnd, and a small SRTT (10ms). a.rackReoWnd = 200 * time.Millisecond a.srtt.Store(float64(10.0)) // Any onRackAfterSACK pass should bound reoWnd by SRTT. a.onRackAfterSACK(false, time.Time{}, 0, &chunkSelectiveAck{}) assert.Equal(t, 10*time.Millisecond, a.rackReoWnd, "reoWnd must be bounded by SRTT") } func TestRACK_PTO_ProbesLatestOutstanding_WhenNoPending(t *testing.T) { assoc := newRackTestAssoc(t) // Two outstanding, none acked/abandoned. now := time.Now() c0 := mkChunk(100, now.Add(-10*time.Millisecond)) c1 := mkChunk(101, now) assoc.inflightQueue.pushNoCheck(c0) assoc.inflightQueue.pushNoCheck(c1) // No pending -> PTO should mark latest outstanding for retransmit. 
assoc.onPTOTimer() got0, _ := assoc.inflightQueue.get(100) got1, _ := assoc.inflightQueue.get(101) require.NotNil(t, got0) require.NotNil(t, got1) assert.False(t, got0.retransmit, "older TSN should not be probed by PTO") assert.True(t, got1.retransmit, "latest outstanding should be probed by PTO") } func TestRACK_PTO_DoesNotProbe_WhenPendingExists(t *testing.T) { assoc := newRackTestAssoc(t) // One outstanding assoc.inflightQueue.pushNoCheck(mkChunk(100, time.Now())) // Add something pending (generic non-nil chunk). assoc.pendingQueue.push(&chunkPayloadData{ streamIdentifier: 2, beginningFragment: true, endingFragment: true, userData: []byte("pending"), }) // With pending data, PTO should NOT mark retransmit and simply wake sender. assoc.onPTOTimer() got, _ := assoc.inflightQueue.get(100) require.NotNil(t, got) assert.False(t, got.retransmit, "PTO must prefer sending pending data over probing") } func newTLRAssociationForTest(t *testing.T) (*Association, net.Conn) { t.Helper() c1, c2 := net.Pipe() a := createAssociation(Config{ Name: "tlr-test", NetConn: c1, MTU: 1200, LoggerFactory: logging.NewDefaultLoggerFactory(), RTOMax: 1000, }) return a, c2 } func shutdownTLRAssociationForTest(a *Association, peer net.Conn) { if peer != nil { _ = peer.Close() } if a == nil { return } a.closeWriteLoopOnce.Do(func() { close(a.closeWriteLoopCh) }) a.closeAllTimers() _ = a.netConn.Close() } func pushPendingFullPacketChunks(t *testing.T, a *Association, n int) { t.Helper() userLen := int(a.MTU()) - int(commonHeaderSize+dataChunkHeaderSize) assert.True(t, userLen > 0) for i := 0; i < n; i++ { a.pendingQueue.push(&chunkPayloadData{ streamIdentifier: 0, beginningFragment: true, endingFragment: true, userData: make([]byte, userLen), }) } } func pushInflightRetransmitFullPacketChunks(t *testing.T, a *Association, startTSN uint32, n int) { t.Helper() userLen := int(a.MTU()) - int(commonHeaderSize+dataChunkHeaderSize) assert.True(t, userLen > 0) for i := 0; i < n; i++ { tsn := startTSN + uint32(i) //nolint:gosec a.inflightQueue.pushNoCheck(&chunkPayloadData{ tsn: tsn, streamIdentifier: 0, userData: make([]byte, userLen), nSent: 1, retransmit: true, }) } } func TestTLR_AllowSend_BudgetGatingAndFirstSendAlwaysAllowed(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = time.Now() assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT assoc.lock.Unlock() t.Run("GatesAfterBudgetExhausted_FirstRTT_AllowsExactly4MTUsWorth", func(t *testing.T) { assoc.lock.Lock() defer assoc.lock.Unlock() budget := assoc.tlrCurrentBurstBudgetScaledLocked() consumed := false allowed := 0 mtu := int(assoc.MTU()) for assoc.tlrAllowSendLocked(&budget, &consumed, mtu) { allowed++ } assert.Equal(t, 4, allowed) }) t.Run("FirstSendAllowedEvenIfBudgetTooSmall", func(t *testing.T) { assoc.lock.Lock() defer assoc.lock.Unlock() budget := int64(1000) consumed := false ok1 := assoc.tlrAllowSendLocked(&budget, &consumed, int(assoc.MTU())) ok2 := assoc.tlrAllowSendLocked(&budget, &consumed, int(assoc.MTU())) assert.True(t, ok1) assert.False(t, ok2) assert.True(t, consumed) assert.Equal(t, int64(0), budget) }) } func TestTLR_Begin_SetsEndTSNToHighestOutstanding(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() assoc.cumulativeTSNAckPoint = 99 pushInflightRetransmitFullPacketChunks(t, assoc, 100, 5) // TSN 
100..104 exist assoc.tlrBeginLocked() assert.True(t, assoc.tlrActive) assert.True(t, assoc.tlrFirstRTT) assert.False(t, assoc.tlrHadAdditionalLoss) assert.Equal(t, uint32(104), assoc.tlrEndTSN) assert.False(t, assoc.tlrStartTime.IsZero()) } func TestTLR_PhaseSwitchesToLaterOnAckProgress(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = time.Now() assoc.tlrEndTSN = assoc.cumulativeTSNAckPoint + 100 // ensure we don't finish assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT assoc.tlrMaybeFinishLocked(true) // ackProgress triggers phase change assert.False(t, assoc.tlrFirstRTT) assert.Equal(t, int64(tlrBurstDefaultLaterRTT), assoc.tlrCurrentBurstUnitsLocked()) } func TestTLR_FirstRTTExpiresByTime_SRTTAndFallback(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) t.Run("UsesSRTTWhenAvailable", func(t *testing.T) { assoc.lock.Lock() defer assoc.lock.Unlock() assoc.srtt.Store(float64(100)) // 100ms now := time.Now() assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = now.Add(-150 * time.Millisecond) assoc.tlrUpdatePhaseLocked(now) assert.False(t, assoc.tlrFirstRTT) }) t.Run("FallsBackTo1sWhenNoSRTT", func(t *testing.T) { assoc.lock.Lock() defer assoc.lock.Unlock() assoc.srtt.Store(float64(0)) now := time.Now() assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = now.Add(-500 * time.Millisecond) assoc.tlrUpdatePhaseLocked(now) assert.True(t, assoc.tlrFirstRTT) assoc.tlrStartTime = now.Add(-1100 * time.Millisecond) assoc.tlrUpdatePhaseLocked(now) assert.False(t, assoc.tlrFirstRTT) }) } func TestTLR_ApplyAdditionalLoss_FirstRTT_StepDownAndClamp(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() assoc.srtt.Store(float64(1000)) now := time.Now() assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = now assoc.tlrHadAdditionalLoss = false assoc.tlrGoodOps = 7 assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT assoc.tlrApplyAdditionalLossLocked(now.Add(10 * time.Millisecond)) assert.True(t, assoc.tlrHadAdditionalLoss) assert.Equal(t, uint32(0), assoc.tlrGoodOps) assert.Equal(t, int64(12), assoc.tlrBurstFirstRTTUnits) // 16-4 assert.Equal(t, int64(tlrBurstDefaultLaterRTT), assoc.tlrBurstLaterRTTUnits) // should clamp at 8. 
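// Each additional loss during the first RTT appears to step the burst budget down by 4 units (16 -> 12 -> 8 here) and never below the tlrBurstMinFirstRTT floor.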
assoc.tlrApplyAdditionalLossLocked(now.Add(20 * time.Millisecond)) // 12 -> 8 assoc.tlrApplyAdditionalLossLocked(now.Add(30 * time.Millisecond)) // 8 -> 8 (clamped) assert.Equal(t, int64(tlrBurstMinFirstRTT), assoc.tlrBurstFirstRTTUnits) } func TestTLR_ApplyAdditionalLoss_LaterRTT_StepDownAndClamp(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() now := time.Now() assoc.tlrActive = true assoc.tlrFirstRTT = false assoc.tlrStartTime = now.Add(-10 * time.Second) // irrelevant when tlrFirstRTT=false assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT for i := 0; i < 10; i++ { assoc.tlrApplyAdditionalLossLocked(now) } assert.Equal(t, int64(tlrBurstDefaultFirstRTT), assoc.tlrBurstFirstRTTUnits) assert.Equal(t, int64(tlrBurstMinLaterRTT), assoc.tlrBurstLaterRTTUnits) // clamped at 5 } func TestTLR_MaybeFinish_EndsAndClearsState_AndGoodOpsReset(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() // pretend about to complete the 16th "good op". assoc.tlrActive = true assoc.tlrFirstRTT = false assoc.tlrHadAdditionalLoss = false assoc.tlrEndTSN = 200 assoc.cumulativeTSNAckPoint = 200 assoc.tlrBurstFirstRTTUnits = tlrBurstMinFirstRTT assoc.tlrBurstLaterRTTUnits = tlrBurstMinLaterRTT assoc.tlrGoodOps = tlrGoodOpsResetThreshold - 1 assoc.tlrMaybeFinishLocked(false) // finished -> cleared assert.False(t, assoc.tlrActive) assert.False(t, assoc.tlrFirstRTT) assert.False(t, assoc.tlrHadAdditionalLoss) assert.Equal(t, uint32(0), assoc.tlrEndTSN) // reset after 16 good ops assert.Equal(t, int64(tlrBurstDefaultFirstRTT), assoc.tlrBurstFirstRTTUnits) assert.Equal(t, int64(tlrBurstDefaultLaterRTT), assoc.tlrBurstLaterRTTUnits) assert.Equal(t, uint32(0), assoc.tlrGoodOps) } func TestTLR_PopPendingDataChunksToSend_RespectsBurstBudget_FirstRTT_OnePacketPerChunk(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() assoc.setCWND(1_000_000) assoc.setRWND(1_000_000) assoc.tlrActive = true assoc.tlrFirstRTT = true assoc.tlrStartTime = time.Now() assoc.tlrBurstFirstRTTUnits = tlrBurstDefaultFirstRTT // 4 MTU pushPendingFullPacketChunks(t, assoc, 10) budget := assoc.tlrCurrentBurstBudgetScaledLocked() consumed := false chunks, _ := assoc.popPendingDataChunksToSend(&budget, &consumed) // 4 MTU burst, each chunk == 1 full MTU packet, so 4 chunks moved. assert.Equal(t, 4, len(chunks)) assert.Equal(t, 4, assoc.inflightQueue.size()) assert.Equal(t, 6, assoc.pendingQueue.size()) assert.True(t, consumed) assert.Equal(t, int64(0), budget) } func TestTLR_GetDataPacketsToRetransmit_RespectsBurstBudget_LaterRTT(t *testing.T) { assoc, peer := newTLRAssociationForTest(t) defer shutdownTLRAssociationForTest(assoc, peer) assoc.lock.Lock() defer assoc.lock.Unlock() assoc.setCWND(1_000_000) assoc.setRWND(1_000_000) assoc.tlrActive = true assoc.tlrFirstRTT = false assoc.tlrBurstLaterRTTUnits = tlrBurstDefaultLaterRTT // 2 MTU assoc.cumulativeTSNAckPoint = 99 // 6 full-MTU packets are eligible for retransmit. pushInflightRetransmitFullPacketChunks(t, assoc, 100, 6) budget := assoc.tlrCurrentBurstBudgetScaledLocked() consumed := false pkts := assoc.getDataPacketsToRetransmit(&budget, &consumed) // Later RTT burst is 2 MTU; each retransmit is a full MTU packet => 2 packets. 
assert.Equal(t, 2, len(pkts)) nChunks := 0 for _, p := range pkts { nChunks += len(p.chunks) } assert.Equal(t, 2, nChunks) assert.True(t, consumed) } sctp-1.9.0/chunk.go000066400000000000000000000003601512256410600141320ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type chunk interface { unmarshal(raw []byte) error marshal() ([]byte, error) check() (bool, error) valueLength() int } sctp-1.9.0/chunk_abort.go000066400000000000000000000050031512256410600153200ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "errors" "fmt" ) /* Abort represents an SCTP Chunk of type ABORT The ABORT chunk is sent to the peer of an association to close the association. The ABORT chunk may contain Cause Parameters to inform the receiver about the reason of the abort. DATA chunks MUST NOT be bundled with ABORT. Control chunks (except for INIT, INIT ACK, and SHUTDOWN COMPLETE) MAY be bundled with an ABORT, but they MUST be placed before the ABORT in the SCTP packet or they will be ignored by the receiver. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 6 |Reserved |T| Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | zero or more Error Causes | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkAbort struct { chunkHeader errorCauses []errorCause } // Abort chunk errors. var ( ErrChunkTypeNotAbort = errors.New("ChunkType is not of type ABORT") ErrBuildAbortChunkFailed = errors.New("failed build Abort Chunk") ) func (a *chunkAbort) unmarshal(raw []byte) error { if err := a.chunkHeader.unmarshal(raw); err != nil { return err } if a.typ != ctAbort { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotAbort, a.typ.String()) } offset := chunkHeaderSize for len(raw)-offset >= 4 { e, err := buildErrorCause(raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrBuildAbortChunkFailed, err) //nolint:errorlint } offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } return nil } func (a *chunkAbort) marshal() ([]byte, error) { a.chunkHeader.typ = ctAbort a.flags = 0x00 a.raw = []byte{} for _, ec := range a.errorCauses { raw, err := ec.marshal() if err != nil { return nil, err } a.raw = append(a.raw, raw...) } return a.chunkHeader.marshal() } func (a *chunkAbort) check() (abort bool, err error) { return false, nil } // String makes chunkAbort printable. 
func (a *chunkAbort) String() string { res := a.chunkHeader.String() for _, cause := range a.errorCauses { res += fmt.Sprintf("\n - %s", cause) } return res } sctp-1.9.0/chunk_abort_test.go000066400000000000000000000032601512256410600163620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestAbortChunk(t *testing.T) { t.Run("One error cause", func(t *testing.T) { abort1 := &chunkAbort{ errorCauses: []errorCause{&errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }}, } bytes, err := abort1.marshal() assert.NoError(t, err, "should succeed") abort2 := &chunkAbort{} err = abort2.unmarshal(bytes) assert.NoError(t, err, "should succeed") assert.Equal(t, 1, len(abort2.errorCauses), "should have only one cause") assert.Equal(t, abort1.errorCauses[0].errorCauseCode(), abort2.errorCauses[0].errorCauseCode(), "errorCause code should match") }) t.Run("Many error causes", func(t *testing.T) { abort1 := &chunkAbort{ errorCauses: []errorCause{ &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: invalidMandatoryParameter}, }, &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: unrecognizedChunkType}, }, &errorCauseProtocolViolation{ errorCauseHeader: errorCauseHeader{code: protocolViolation}, }, }, } bytes, err := abort1.marshal() assert.NoError(t, err, "should succeed") abort2 := &chunkAbort{} err = abort2.unmarshal(bytes) assert.NoError(t, err, "should succeed") assert.Equal(t, 3, len(abort2.errorCauses), "should have only one cause") for i, errorCause := range abort1.errorCauses { assert.Equal(t, errorCause.errorCauseCode(), abort2.errorCauses[i].errorCauseCode(), "errorCause code should match") } }) } sctp-1.9.0/chunk_cookie_ack.go000066400000000000000000000024021512256410600163000ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* chunkCookieAck represents an SCTP Chunk of type chunkCookieAck 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 11 |Chunk Flags | Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkCookieAck struct { chunkHeader } // Cookie ack chunk errors. var ( ErrChunkTypeNotCookieAck = errors.New("ChunkType is not of type COOKIEACK") ) func (c *chunkCookieAck) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctCookieAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieAck, c.typ.String()) } return nil } func (c *chunkCookieAck) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieAck return c.chunkHeader.marshal() } func (c *chunkCookieAck) check() (abort bool, err error) { return false, nil } // String makes chunkCookieAck printable. 
func (c *chunkCookieAck) String() string { return c.chunkHeader.String() } sctp-1.9.0/chunk_cookie_echo.go000066400000000000000000000026251512256410600164670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* CookieEcho represents an SCTP Chunk of type CookieEcho 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 10 |Chunk Flags | Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cookie | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkCookieEcho struct { chunkHeader cookie []byte } // Cookie echo chunk errors. var ( ErrChunkTypeNotCookieEcho = errors.New("ChunkType is not of type COOKIEECHO") ) func (c *chunkCookieEcho) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctCookieEcho { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotCookieEcho, c.typ.String()) } c.cookie = c.raw return nil } func (c *chunkCookieEcho) marshal() ([]byte, error) { c.chunkHeader.typ = ctCookieEcho c.chunkHeader.raw = c.cookie return c.chunkHeader.marshal() } func (c *chunkCookieEcho) check() (abort bool, err error) { return false, nil } sctp-1.9.0/chunk_error.go000066400000000000000000000051101512256410600153410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "errors" "fmt" ) /* Operation Error (ERROR) (9) An endpoint sends this chunk to its peer endpoint to notify it of certain error conditions. It contains one or more error causes. An Operation Error is not considered fatal in and of itself, but may be used with an ERROR chunk to report a fatal condition. It has the following parameters: 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 9 | Chunk Flags | Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ \ \ / one or more Error Causes / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ Chunk Flags: 8 bits Set to 0 on transmit and ignored on receipt. Length: 16 bits (unsigned integer) Set to the size of the chunk in bytes, including the chunk header and all the Error Cause fields present. */ type chunkError struct { chunkHeader errorCauses []errorCause } // Error chunk errors. var ( ErrChunkTypeNotCtError = errors.New("ChunkType is not of type ctError") ErrBuildErrorChunkFailed = errors.New("failed build Error Chunk") ) func (a *chunkError) unmarshal(raw []byte) error { if err := a.chunkHeader.unmarshal(raw); err != nil { return err } if a.typ != ctError { return fmt.Errorf("%w, actually is %s", ErrChunkTypeNotCtError, a.typ.String()) } offset := chunkHeaderSize for len(raw)-offset >= 4 { e, err := buildErrorCause(raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrBuildErrorChunkFailed, err) //nolint:errorlint } offset += int(e.length()) a.errorCauses = append(a.errorCauses, e) } return nil } func (a *chunkError) marshal() ([]byte, error) { a.chunkHeader.typ = ctError a.flags = 0x00 a.raw = []byte{} for _, ec := range a.errorCauses { raw, err := ec.marshal() if err != nil { return nil, err } a.raw = append(a.raw, raw...) 
} return a.chunkHeader.marshal() } func (a *chunkError) check() (abort bool, err error) { return false, nil } // String makes chunkError printable. func (a *chunkError) String() string { res := a.chunkHeader.String() for _, cause := range a.errorCauses { res += fmt.Sprintf("\n - %s", cause) } return res } sctp-1.9.0/chunk_error_test.go000066400000000000000000000040041512256410600164010ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "reflect" "testing" "github.com/stretchr/testify/assert" ) func TestChunkErrorUnrecognizedChunkType(t *testing.T) { const chunkFlags byte = 0x00 orgUnrecognizedChunk := []byte{0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3} rawIn := append([]byte{byte(ctError), chunkFlags, 0x00, 0x10, 0x00, 0x06, 0x00, 0x0c}, orgUnrecognizedChunk...) t.Run("unmarshal", func(t *testing.T) { c := &chunkError{} err := c.unmarshal(rawIn) assert.Nil(t, err, "unmarshal should succeed") assert.Equal(t, ctError, c.typ, "chunk type should be ERROR") assert.Equal(t, 1, len(c.errorCauses), "there should be on errorCause") ec := c.errorCauses[0] assert.Equal(t, unrecognizedChunkType, ec.errorCauseCode(), "cause code should be unrecognizedChunkType") ecUnrecognizedChunkType, ok := ec.(*errorCauseUnrecognizedChunkType) assert.True(t, ok) unrecognizedChunk := ecUnrecognizedChunkType.unrecognizedChunk assert.True(t, reflect.DeepEqual(unrecognizedChunk, orgUnrecognizedChunk), "should have valid unrecognizedChunk") }) t.Run("marshal", func(t *testing.T) { ecUnrecognizedChunkType := &errorCauseUnrecognizedChunkType{ unrecognizedChunk: orgUnrecognizedChunk, } ec := &chunkError{ errorCauses: []errorCause{ errorCause(ecUnrecognizedChunkType), }, } raw, err := ec.marshal() assert.Nil(t, err, "marshal should succeed") assert.True(t, reflect.DeepEqual(raw, rawIn), "unexpected serialization result") }) t.Run("marshal with cause value being nil", func(t *testing.T) { expected := []byte{byte(ctError), chunkFlags, 0x00, 0x08, 0x00, 0x06, 0x00, 0x04} ecUnrecognizedChunkType := &errorCauseUnrecognizedChunkType{} ec := &chunkError{ errorCauses: []errorCause{ errorCause(ecUnrecognizedChunkType), }, } raw, err := ec.marshal() assert.Nil(t, err, "marshal should succeed") assert.True(t, reflect.DeepEqual(raw, expected), "unexpected serialization result") }) } sctp-1.9.0/chunk_forward_tsn.go000066400000000000000000000110611512256410600165420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) // This chunk shall be used by the data sender to inform the data // receiver to adjust its cumulative received TSN point forward because // some missing TSNs are associated with data chunks that SHOULD NOT be // transmitted or retransmitted by the sender. 
// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type = 192 | Flags = 0x00 | Length = Variable | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | New Cumulative TSN | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream-1 | Stream Sequence-1 | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ / // / \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream-N | Stream Sequence-N | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type chunkForwardTSN struct { chunkHeader // This indicates the new cumulative TSN to the data receiver. Upon // the reception of this value, the data receiver MUST consider // any missing TSNs earlier than or equal to this value as received, // and stop reporting them as gaps in any subsequent SACKs. newCumulativeTSN uint32 streams []chunkForwardTSNStream } const ( newCumulativeTSNLength = 4 forwardTSNStreamLength = 4 ) // Forward TSN chunk errors. var ( ErrMarshalStreamFailed = errors.New("failed to marshal stream") ErrChunkTooShort = errors.New("chunk too short") ) func (c *chunkForwardTSN) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if len(c.raw) < newCumulativeTSNLength { return ErrChunkTooShort } c.newCumulativeTSN = binary.BigEndian.Uint32(c.raw[0:]) offset := newCumulativeTSNLength remaining := len(c.raw) - offset for remaining > 0 { s := chunkForwardTSNStream{} if err := s.unmarshal(c.raw[offset:]); err != nil { return fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) //nolint:errorlint } c.streams = append(c.streams, s) offset += s.length() remaining -= s.length() } return nil } func (c *chunkForwardTSN) marshal() ([]byte, error) { out := make([]byte, newCumulativeTSNLength) binary.BigEndian.PutUint32(out[0:], c.newCumulativeTSN) for _, s := range c.streams { b, err := s.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrMarshalStreamFailed, err) //nolint:errorlint } out = append(out, b...) //nolint:makezero // TODO: fix } c.typ = ctForwardTSN c.raw = out return c.chunkHeader.marshal() } func (c *chunkForwardTSN) check() (abort bool, err error) { return true, nil } // String makes chunkForwardTSN printable. func (c *chunkForwardTSN) String() string { res := fmt.Sprintf("New Cumulative TSN: %d\n", c.newCumulativeTSN) for _, s := range c.streams { res += fmt.Sprintf(" - si=%d, ssn=%d\n", s.identifier, s.sequence) } return res } type chunkForwardTSNStream struct { // This field holds a stream number that was skipped by this // FWD-TSN. identifier uint16 // This field holds the sequence number associated with the stream // that was skipped. The stream sequence field holds the largest // stream sequence number in this stream being skipped. The receiver // of the FWD-TSN's can use the Stream-N and Stream Sequence-N fields // to enable delivery of any stranded TSN's that remain on the stream // re-ordering queues. This field MUST NOT report TSN's corresponding // to DATA chunks that are marked as unordered. For ordered DATA // chunks this field MUST be filled in. 
sequence uint16 } func (s *chunkForwardTSNStream) length() int { return forwardTSNStreamLength } func (s *chunkForwardTSNStream) unmarshal(raw []byte) error { if len(raw) < forwardTSNStreamLength { return ErrChunkTooShort } s.identifier = binary.BigEndian.Uint16(raw[0:]) s.sequence = binary.BigEndian.Uint16(raw[2:]) return nil } func (s *chunkForwardTSNStream) marshal() ([]byte, error) { // nolint:unparam out := make([]byte, forwardTSNStreamLength) binary.BigEndian.PutUint16(out[0:], s.identifier) binary.BigEndian.PutUint16(out[2:], s.sequence) return out, nil } sctp-1.9.0/chunk_forward_tsn_test.go000066400000000000000000000024641512256410600176100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkForwardTSN() []byte { return []byte{0xc0, 0x0, 0x0, 0x8, 0x0, 0x0, 0x0, 0x3} } func TestChunkForwardTSN_Success(t *testing.T) { tt := []struct { binary []byte }{ {testChunkForwardTSN()}, {[]byte{0xc0, 0x0, 0x0, 0xc, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5}}, {[]byte{0xc0, 0x0, 0x0, 0x10, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6, 0x0, 0x7}}, } for i, tc := range tt { actual := &chunkForwardTSN{} err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) b, err := actual.marshal() assert.NoError(t, err) assert.Equalf(t, tc.binary, b, "test %d not equal", i) } } func TestChunkForwardTSNUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"chunk header to short", []byte{0xc0}}, {"missing New Cumulative TSN", []byte{0xc0, 0x0, 0x0, 0x4}}, {"missing stream sequence", []byte{0xc0, 0x0, 0x0, 0xe, 0x0, 0x0, 0x0, 0x3, 0x0, 0x4, 0x0, 0x5, 0x0, 0x6}}, } for i, tc := range tt { actual := &chunkForwardTSN{} err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/chunk_heartbeat.go000066400000000000000000000067571512256410600161710ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* chunkHeartbeat represents an SCTP Chunk of type HEARTBEAT (RFC 9260 section 3.3.6) An endpoint sends this chunk to probe reachability of a destination address. The chunk MUST contain exactly one variable-length parameter: Variable Parameters Status Type Value ------------------------------------------------------------- Heartbeat Info Mandatory 1 nolint:godot */ type chunkHeartbeat struct { chunkHeader params []param } // Heartbeat chunk errors. 
var ( ErrChunkTypeNotHeartbeat = errors.New("ChunkType is not of type HEARTBEAT") ErrHeartbeatNotLongEnoughInfo = errors.New("heartbeat is not long enough to contain Heartbeat Info") ErrParseParamTypeFailed = errors.New("failed to parse param type") ErrHeartbeatParam = errors.New("heartbeat should only have HEARTBEAT param") ErrHeartbeatChunkUnmarshal = errors.New("failed unmarshalling param in Heartbeat Chunk") ErrHeartbeatExtraNonZero = errors.New("heartbeat has non-zero trailing bytes after last parameter") ErrHeartbeatMarshalNoInfo = errors.New("heartbeat marshal requires exactly one Heartbeat Info parameter") ) func (h *chunkHeartbeat) unmarshal(raw []byte) error { //nolint:cyclop if err := h.chunkHeader.unmarshal(raw); err != nil { return err } if h.typ != ctHeartbeat { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotHeartbeat, h.typ.String()) } // if the body is completely empty, accept it but don't populate params. if len(h.raw) == 0 { return nil } // need at least a parameter header present (TLV: 4 bytes minimum). if len(h.raw) < initOptionalVarHeaderLength { return fmt.Errorf("%w: %d", ErrHeartbeatNotLongEnoughInfo, len(h.raw)) } pType, err := parseParamType(h.raw) if err != nil { return fmt.Errorf("%w: %v", ErrParseParamTypeFailed, err) //nolint:errorlint } if pType != heartbeatInfo { return fmt.Errorf("%w: instead have %s", ErrHeartbeatParam, pType.String()) } var pHeader paramHeader if e := pHeader.unmarshal(h.raw); e != nil { return fmt.Errorf("%w: %v", ErrParseParamTypeFailed, e) //nolint:errorlint } plen := pHeader.length() if plen < initOptionalVarHeaderLength || plen > len(h.raw) { return ErrHeartbeatNotLongEnoughInfo } p, err := buildParam(pType, h.raw[:plen]) if err != nil { return fmt.Errorf("%w: %v", ErrHeartbeatChunkUnmarshal, err) //nolint:errorlint } h.params = append(h.params, p) // any trailing bytes beyond the single param must be all zeros. if rem := h.raw[plen:]; len(rem) > 0 && !allZero(rem) { return ErrHeartbeatExtraNonZero } return nil } func (h *chunkHeartbeat) Marshal() ([]byte, error) { // exactly one Heartbeat Info param is required. if len(h.params) != 1 { return nil, ErrHeartbeatMarshalNoInfo } // enforce correct concrete type via type assertion (param interface has no type getter). if _, ok := h.params[0].(*paramHeartbeatInfo); !ok { return nil, ErrHeartbeatParam } pp, err := h.params[0].marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrHeartbeatChunkUnmarshal, err) //nolint:errorlint } // single TLV, no inter-parameter padding within the chunk body. h.chunkHeader.typ = ctHeartbeat h.chunkHeader.flags = 0 // sender MUST set to 0 h.chunkHeader.raw = append([]byte(nil), pp...) return h.chunkHeader.marshal() } func (h *chunkHeartbeat) check() (abort bool, err error) { return false, nil } sctp-1.9.0/chunk_heartbeat_ack.go000066400000000000000000000110431512256410600167670ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* chunkHeartbeatAck represents an SCTP Chunk of type HEARTBEAT ACK An endpoint should send this chunk to its peer endpoint as a response to a HEARTBEAT chunk (see Section 8.3). A HEARTBEAT ACK is always sent to the source IP address of the IP datagram containing the HEARTBEAT chunk to which this ack is responding. The parameter field contains a variable-length opaque data structure. 
0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 5 | Chunk Flags | Heartbeat Ack Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Heartbeat Information TLV (Variable-Length) | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ Defined as a variable-length parameter using the format described in Section 3.2.1, i.e.: Variable Parameters Status Type Value ------------------------------------------------------------- Heartbeat Info Mandatory 1 . */ type chunkHeartbeatAck struct { chunkHeader params []param } // Heartbeat ack chunk errors. var ( // Deprecated: this error is no longer used but is kept for compatibility. ErrUnimplemented = errors.New("unimplemented") ErrChunkTypeNotHeartbeatAck = errors.New("chunk type is not of type HEARTBEAT ACK") ErrHeartbeatAckParams = errors.New("heartbeat Ack must have one param") ErrHeartbeatAckNotHeartbeatInfo = errors.New("heartbeat Ack must have one param, and it should be a HeartbeatInfo") ErrHeartbeatAckMarshalParam = errors.New("unable to marshal parameter for Heartbeat Ack") ) func (h *chunkHeartbeatAck) unmarshal(raw []byte) error { //nolint:cyclop if err := h.chunkHeader.unmarshal(raw); err != nil { return err } if h.typ != ctHeartbeatAck { return fmt.Errorf("%w %s", ErrChunkTypeNotHeartbeatAck, h.typ.String()) } // allow for an empty heartbeat: no RTT info -> ActiveHeartbeat just won't update SRTT. if len(h.raw) == 0 { h.params = nil return nil } if len(h.raw) < initOptionalVarHeaderLength { return fmt.Errorf("%w: %d", ErrHeartbeatAckParams, len(h.raw)) } pType, err := parseParamType(h.raw) if err != nil { return fmt.Errorf("%w: %v", ErrHeartbeatAckParams, err) //nolint:errorlint } if pType != heartbeatInfo { return fmt.Errorf("%w: instead have %s", ErrHeartbeatAckNotHeartbeatInfo, pType.String()) } var pHeader paramHeader if e := pHeader.unmarshal(h.raw); e != nil { return fmt.Errorf("%w: %v", ErrHeartbeatAckParams, e) //nolint:errorlint } plen := pHeader.length() if plen < initOptionalVarHeaderLength || plen > len(h.raw) { return fmt.Errorf("%w: %d", ErrHeartbeatAckParams, plen) } p, err := buildParam(pType, h.raw[:plen]) if err != nil { return fmt.Errorf("%w: %v", ErrHeartbeatAckMarshalParam, err) //nolint:errorlint } h.params = []param{p} // Any trailing bytes beyond the single param must be zero. if rem := h.raw[plen:]; len(rem) > 0 && !allZero(rem) { return ErrHeartbeatExtraNonZero } return nil } func (h *chunkHeartbeatAck) marshal() ([]byte, error) { if len(h.params) != 1 { return nil, ErrHeartbeatAckParams } switch h.params[0].(type) { case *paramHeartbeatInfo: // ParamHeartbeatInfo is valid default: return nil, ErrHeartbeatAckNotHeartbeatInfo } out := make([]byte, 0) for idx, p := range h.params { pp, err := p.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrHeartbeatAckMarshalParam, err) //nolint:errorlint } out = append(out, pp...) // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. *However, it does include padding of any variable-length // parameter except the last parameter in the chunk.* The receiver // MUST ignore the padding. 
if idx != len(h.params)-1 { out = padByte(out, getPadding(len(pp))) } } h.chunkHeader.typ = ctHeartbeatAck h.chunkHeader.raw = out return h.chunkHeader.marshal() } func (h *chunkHeartbeatAck) check() (abort bool, err error) { return false, nil } sctp-1.9.0/chunk_heartbeat_ack_test.go000066400000000000000000000122641512256410600200340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkHeartbeatAck_UnmarshalMarshal_Success(t *testing.T) { tt := []struct { name string info []byte }{ {"empty-info", []byte{}}, {"aligned-4", []byte{0x01, 0x02, 0x03, 0x04}}, {"non-aligned-5", []byte{0xaa, 0xbb, 0xcc, 0xdd, 0xee}}, } for _, tc := range tt { p := &paramHeartbeatInfo{heartbeatInformation: tc.info} pp, err := p.marshal() if !assert.NoErrorf(t, err, "marshal paramHeartbeatInfo for %s", tc.name) { continue } ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: pp, } encoded, err := ch.marshal() if !assert.NoErrorf(t, err, "marshal chunkHeader for %s", tc.name) { continue } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(encoded) if !assert.NoErrorf(t, err, "unmarshal HeartbeatAck for %s", tc.name) { continue } assert.Equalf(t, ctHeartbeatAck, hbAck.typ, "chunk type for %s", tc.name) if assert.Lenf(t, hbAck.params, 1, "params length for %s", tc.name) { got, ok := hbAck.params[0].(*paramHeartbeatInfo) if assert.Truef(t, ok, "param type for %s", tc.name) { assert.Equalf(t, tc.info, got.heartbeatInformation, "heartbeat info for %s", tc.name) } } roundTrip, err := hbAck.marshal() if !assert.NoErrorf(t, err, "marshal HeartbeatAck (round-trip) for %s", tc.name) { continue } assert.Equalf(t, encoded, roundTrip, "round-trip bytes for %s", tc.name) } } func TestChunkHeartbeatAck_Unmarshal_EmptyBody(t *testing.T) { ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: nil, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) if !assert.NoError(t, err) { return } assert.Nil(t, hbAck.params) } func TestChunkHeartbeatAck_Unmarshal_Failure_WrongChunkType(t *testing.T) { p := &paramHeartbeatInfo{heartbeatInformation: []byte{0x01, 0x02, 0x03, 0x04}} pp, err := p.marshal() if !assert.NoError(t, err) { return } ch := chunkHeader{ typ: ctHeartbeat, // wrong type on purpose flags: 0, raw: pp, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrChunkTypeNotHeartbeatAck) } func TestChunkHeartbeatAck_Unmarshal_Failure_BodyTooShort(t *testing.T) { // less than initOptionalVarHeaderLength bytes in the body. body := []byte{0xaa, 0xbb, 0xcc} ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: body, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatAckParams) } func TestChunkHeartbeatAck_Unmarshal_Failure_TruncatedParamBody(t *testing.T) { // build a valid HeartbeatInfo TLV, then truncate it so the advertised // length is larger than the available bytes.
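// The advertised TLV length then exceeds the available bytes, so unmarshal is expected to reject the chunk with ErrHeartbeatAckParams.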
info := []byte{0x11, 0x22, 0x33, 0x44} p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return } if !assert.GreaterOrEqual(t, len(pp), initOptionalVarHeaderLength+1) { return } truncated := pp[:len(pp)-1] ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: truncated, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatAckParams) } func TestChunkHeartbeatAck_Unmarshal_Failure_WrongParamType(t *testing.T) { // construct a TLV with a param type that is *not* heartbeatInfo but is otherwise valid. plen := uint16(initOptionalVarHeaderLength) body := []byte{ 0x00, 0x02, // some param type != heartbeatInfo (which is 1 per RFC) byte(plen >> 8), byte(plen & 0xff), } ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: body, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatAckNotHeartbeatInfo) } func TestChunkHeartbeatAck_Unmarshal_Failure_TrailingNonZero(t *testing.T) { info := []byte{0x01, 0x02, 0x03, 0x04} p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return } // append a non-zero byte after the single parameter. body := append(append([]byte{}, pp...), 0x01) ch := chunkHeader{ typ: ctHeartbeatAck, flags: 0, raw: body, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hbAck := &chunkHeartbeatAck{} err = hbAck.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatExtraNonZero) } func TestChunkHeartbeatAck_Marshal_Failure_ParamCount(t *testing.T) { // no params hbAck := &chunkHeartbeatAck{} _, err := hbAck.marshal() assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatAckParams) // too many params p := &paramHeartbeatInfo{} hbAck = &chunkHeartbeatAck{ params: []param{p, p}, } _, err = hbAck.marshal() assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatAckParams) } sctp-1.9.0/chunk_heartbeat_test.go000066400000000000000000000146751512256410600172240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) // helper to build a raw HEARTBEAT chunk with a single HeartbeatInfo param. func buildHeartbeatChunk(t *testing.T, info []byte) []byte { t.Helper() p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return nil } ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: pp, } raw, err := ch.marshal() if !assert.NoError(t, err) { return nil } return raw } func TestChunkHeartbeat_Unmarshal_Success_EmptyBody(t *testing.T) { // HEARTBEAT with no body (header only) should be accepted, params left empty.
ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: nil, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.NoError(t, err) assert.Equal(t, ctHeartbeat, hb.typ) assert.Len(t, hb.params, 0) } func TestChunkHeartbeat_Unmarshal_Success_WithInfo(t *testing.T) { info := []byte{0xde, 0xad, 0xbe, 0xef} raw := buildHeartbeatChunk(t, info) if raw == nil { return } hb := &chunkHeartbeat{} err := hb.unmarshal(raw) if !assert.NoError(t, err) { return } assert.Equal(t, ctHeartbeat, hb.typ) if assert.Len(t, hb.params, 1) { p, ok := hb.params[0].(*paramHeartbeatInfo) if assert.True(t, ok, "param should be *paramHeartbeatInfo") { assert.Equal(t, info, p.heartbeatInformation) } } } func TestChunkHeartbeat_Unmarshal_Success_WithZeroPaddingAfterTLV(t *testing.T) { info := []byte{0xaa, 0xbb, 0xcc, 0xdd} p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return } // add extra zero padding outside the TLV body := append([]byte{}, pp...) body = append(body, 0x00, 0x00, 0x00, 0x00) ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: body, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) if !assert.NoError(t, err) { return } if assert.Len(t, hb.params, 1) { p2, ok := hb.params[0].(*paramHeartbeatInfo) if assert.True(t, ok) { assert.Equal(t, info, p2.heartbeatInformation) } } } func TestChunkHeartbeat_Unmarshal_Failure_TrailingNonZeroBytes(t *testing.T) { info := []byte{0x01, 0x02, 0x03, 0x04} p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return } body := append([]byte{}, pp...) body = append(body, 0x00, 0x00, 0x00, 0x01) // non-zero trailing byte ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: body, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatExtraNonZero) } func TestChunkHeartbeat_Unmarshal_Failure_WrongChunkType(t *testing.T) { ch := chunkHeader{ typ: ctInit, // not ctHeartbeat flags: 0, raw: nil, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrChunkTypeNotHeartbeat) } func TestChunkHeartbeat_Unmarshal_Failure_TooShortBody(t *testing.T) { // raw shorter than initOptionalVarHeaderLength should fail. ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: []byte{0x00, 0x01, 0x02}, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatNotLongEnoughInfo) } func TestChunkHeartbeat_Unmarshal_Failure_TruncatedParamBody(t *testing.T) { // build a valid HeartbeatInfo TLV, then truncate it so the advertised // length is larger than the available bytes.
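// Here the truncated TLV should make the param header parse fail, surfacing ErrParseParamTypeFailed.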
info := []byte{0x11, 0x22, 0x33, 0x44} p := &paramHeartbeatInfo{heartbeatInformation: info} pp, err := p.marshal() if !assert.NoError(t, err) { return } if !assert.GreaterOrEqual(t, len(pp), initOptionalVarHeaderLength+1) { return } truncated := pp[:len(pp)-1] // chop one byte off -> length mismatch ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: truncated, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrParseParamTypeFailed) } func TestChunkHeartbeat_Unmarshal_Failure_ParamTypeNotHeartbeatInfo(t *testing.T) { // use a different param type (e.g., StateCookie) to trigger ErrHeartbeatParam. p := &paramStateCookie{cookie: []byte{0x01, 0x02, 0x03}} pp, err := p.marshal() if !assert.NoError(t, err) { return } ch := chunkHeader{ typ: ctHeartbeat, flags: 0, raw: pp, } raw, err := ch.marshal() if !assert.NoError(t, err) { return } hb := &chunkHeartbeat{} err = hb.unmarshal(raw) assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatParam) } func TestChunkHeartbeat_Marshal_RoundTrip(t *testing.T) { info := []byte{0xca, 0xfe, 0xba, 0xbe} hb := &chunkHeartbeat{ params: []param{ &paramHeartbeatInfo{ heartbeatInformation: info, }, }, } raw, err := hb.Marshal() if !assert.NoError(t, err) { return } hb2 := &chunkHeartbeat{} err = hb2.unmarshal(raw) if !assert.NoError(t, err) { return } if assert.Len(t, hb2.params, 1) { p, ok := hb2.params[0].(*paramHeartbeatInfo) if assert.True(t, ok) { assert.Equal(t, info, p.heartbeatInformation) } } } func TestChunkHeartbeat_Marshal_Failure_NoParams(t *testing.T) { hb := &chunkHeartbeat{ params: nil, } _, err := hb.Marshal() assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatMarshalNoInfo) } func TestChunkHeartbeat_Marshal_Failure_MoreThanOneParam(t *testing.T) { hb := &chunkHeartbeat{ params: []param{ &paramHeartbeatInfo{heartbeatInformation: []byte{0x01}}, &paramHeartbeatInfo{heartbeatInformation: []byte{0x02}}, }, } _, err := hb.Marshal() assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatMarshalNoInfo) } func TestChunkHeartbeat_Marshal_Failure_WrongParamType(t *testing.T) { // Not a *paramHeartbeatInfo -> ErrHeartbeatParam. hb := &chunkHeartbeat{ params: []param{ &paramStateCookie{cookie: []byte{0x01}}, }, } _, err := hb.Marshal() assert.Error(t, err) assert.ErrorIs(t, err, ErrHeartbeatParam) } func TestChunkHeartbeat_Check(t *testing.T) { hb := &chunkHeartbeat{} abort, err := hb.check() assert.False(t, abort) assert.NoError(t, err) } sctp-1.9.0/chunk_init.go000066400000000000000000000117201512256410600151570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "errors" "fmt" ) /* Init represents an SCTP Chunk of type INIT See chunkInitCommon for the fixed headers Variable Parameters Status Type Value ------------------------------------------------------------- IPv4 IP (Note 1) Optional 5 IPv6 IP (Note 1) Optional 6 Cookie Preservative Optional 9 Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) Host Name IP (Note 3) Optional 11 Supported IP Types (Note 4) Optional 12 */ type chunkInit struct { chunkHeader chunkInitCommon } // Init chunk errors.
var ( ErrChunkTypeNotTypeInit = errors.New("ChunkType is not of type INIT") ErrChunkValueNotLongEnough = errors.New("chunk Value isn't long enough for mandatory parameters exp") ErrChunkTypeInitFlagZero = errors.New("ChunkType of type INIT flags must be all 0") ErrChunkTypeInitUnmarshalFailed = errors.New("failed to unmarshal INIT body") ErrChunkTypeInitMarshalFailed = errors.New("failed marshaling INIT common data") ErrChunkTypeInitInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") ErrInitInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") ErrInitOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") ErrInitAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") ErrInitUnknownParam = errors.New("INIT with unknown param") ) func (i *chunkInit) unmarshal(raw []byte) error { if err := i.chunkHeader.unmarshal(raw); err != nil { return err } if i.typ != ctInit { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotTypeInit, i.typ.String()) } else if len(i.raw) < initChunkMinLength { return fmt.Errorf("%w: %d actual: %d", ErrChunkValueNotLongEnough, initChunkMinLength, len(i.raw)) } // The Chunk Flags field in INIT is reserved, and all bits in it should // be set to 0 by the sender and ignored by the receiver. The sequence // of parameters within an INIT can be processed in any order. if i.flags != 0 { return ErrChunkTypeInitFlagZero } if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { return fmt.Errorf("%w: %v", ErrChunkTypeInitUnmarshalFailed, err) //nolint:errorlint } return nil } func (i *chunkInit) marshal() ([]byte, error) { initShared, err := i.chunkInitCommon.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkTypeInitMarshalFailed, err) //nolint:errorlint } i.chunkHeader.typ = ctInit i.chunkHeader.raw = initShared return i.chunkHeader.marshal() } func (i *chunkInit) check() (abort bool, err error) { // The receiver of the INIT (the responding end) records the value of // the Initiate Tag parameter. This value MUST be placed into the // Verification Tag field of every SCTP packet that the receiver of // the INIT transmits within this association. // // The Initiate Tag is allowed to have any value except 0. See // Section 5.3.1 for more on the selection of the tag value. // // If the value of the Initiate Tag in a received INIT chunk is found // to be 0, the receiver MUST treat it as an error and close the // association by transmitting an ABORT. if i.initiateTag == 0 { return true, ErrChunkTypeInitInitateTagZero } // Defines the maximum number of streams the sender of this INIT // chunk allows the peer end to create in this association. The // value 0 MUST NOT be used. // // Note: There is no negotiation of the actual number of streams but // instead the two endpoints will use the min(requested, offered). // See Section 5.1.1 for details. // // Note: A receiver of an INIT with the MIS value of 0 SHOULD abort // the association. if i.numInboundStreams == 0 { return true, ErrInitInboundStreamRequestZero } // Defines the number of outbound streams the sender of this INIT // chunk wishes to create in this association. The value of 0 MUST // NOT be used. // // Note: A receiver of an INIT with the OS value set to 0 SHOULD // abort the association. if i.numOutboundStreams == 0 { return true, ErrInitOutboundStreamRequestZero } // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in // one SCTP packet. 
This means that an SCTP endpoint MUST NOT indicate // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT // ACK. if i.advertisedReceiverWindowCredit < 1500 { return true, ErrInitAdvertisedReceiver1500 } for _, p := range i.unrecognizedParams { if p.unrecognizedAction == paramHeaderUnrecognizedActionStop || p.unrecognizedAction == paramHeaderUnrecognizedActionStopAndReport { return true, ErrInitUnknownParam } } return false, nil } // String makes chunkInit printable. func (i *chunkInit) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } sctp-1.9.0/chunk_init_ack.go000066400000000000000000000117461512256410600160050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "errors" "fmt" ) /* chunkInitAck represents an SCTP Chunk of type INIT ACK See chunkInitCommon for the fixed headers Variable Parameters Status Type Value ------------------------------------------------------------- State Cookie Mandatory 7 IPv4 IP (Note 1) Optional 5 IPv6 IP (Note 1) Optional 6 Unrecognized Parameter Optional 8 Reserved for ECN Capable (Note 2) Optional 32768 (0x8000) Host Name IP (Note 3) Optional 11 */ type chunkInitAck struct { chunkHeader chunkInitCommon } // Init ack chunk errors. var ( ErrChunkTypeNotInitAck = errors.New("ChunkType is not of type INIT ACK") ErrChunkNotLongEnoughForParams = errors.New("chunk Value isn't long enough for mandatory parameters exp") ErrChunkTypeInitAckFlagZero = errors.New("ChunkType of type INIT ACK flags must be all 0") ErrInitAckUnmarshalFailed = errors.New("failed to unmarshal INIT body") ErrInitCommonDataMarshalFailed = errors.New("failed marshaling INIT common data") ErrChunkTypeInitAckInitateTagZero = errors.New("ChunkType of type INIT ACK InitiateTag must not be 0") ErrInitAckInboundStreamRequestZero = errors.New("INIT ACK inbound stream request must be > 0") ErrInitAckOutboundStreamRequestZero = errors.New("INIT ACK outbound stream request must be > 0") ErrInitAckAdvertisedReceiver1500 = errors.New("INIT ACK Advertised Receiver Window Credit (a_rwnd) must be >= 1500") ) func (i *chunkInitAck) unmarshal(raw []byte) error { if err := i.chunkHeader.unmarshal(raw); err != nil { return err } if i.typ != ctInitAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotInitAck, i.typ.String()) } else if len(i.raw) < initChunkMinLength { return fmt.Errorf("%w: %d actual: %d", ErrChunkNotLongEnoughForParams, initChunkMinLength, len(i.raw)) } // The Chunk Flags field in INIT is reserved, and all bits in it should // be set to 0 by the sender and ignored by the receiver. The sequence // of parameters within an INIT can be processed in any order. if i.flags != 0 { return ErrChunkTypeInitAckFlagZero } if err := i.chunkInitCommon.unmarshal(i.raw); err != nil { return fmt.Errorf("%w: %v", ErrInitAckUnmarshalFailed, err) //nolint:errorlint } return nil } func (i *chunkInitAck) marshal() ([]byte, error) { initShared, err := i.chunkInitCommon.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrInitCommonDataMarshalFailed, err) //nolint:errorlint } i.chunkHeader.typ = ctInitAck i.chunkHeader.raw = initShared return i.chunkHeader.marshal() } func (i *chunkInitAck) check() (abort bool, err error) { // The receiver of the INIT ACK records the value of the Initiate Tag // parameter. This value MUST be placed into the Verification Tag // field of every SCTP packet that the INIT ACK receiver transmits // within this association. 
// // The Initiate Tag MUST NOT take the value 0. See Section 5.3.1 for // more on the selection of the Initiate Tag value. // // If the value of the Initiate Tag in a received INIT ACK chunk is // found to be 0, the receiver MUST destroy the association // discarding its TCB. The receiver MAY send an ABORT for debugging // purpose. if i.initiateTag == 0 { abort = true return abort, ErrChunkTypeInitAckInitateTagZero } // Defines the maximum number of streams the sender of this INIT ACK // chunk allows the peer end to create in this association. The // value 0 MUST NOT be used. // // Note: There is no negotiation of the actual number of streams but // instead the two endpoints will use the min(requested, offered). // See Section 5.1.1 for details. // // Note: A receiver of an INIT ACK with the MIS value set to 0 SHOULD // destroy the association discarding its TCB. if i.numInboundStreams == 0 { abort = true return abort, ErrInitAckInboundStreamRequestZero } // Defines the number of outbound streams the sender of this INIT ACK // chunk wishes to create in this association. The value of 0 MUST // NOT be used, and the value MUST NOT be greater than the MIS value // sent in the INIT chunk. // // Note: A receiver of an INIT ACK with the OS value set to 0 SHOULD // destroy the association discarding its TCB. if i.numOutboundStreams == 0 { abort = true return abort, ErrInitAckOutboundStreamRequestZero } // An SCTP receiver MUST be able to receive a minimum of 1500 bytes in // one SCTP packet. This means that an SCTP endpoint MUST NOT indicate // less than 1500 bytes in its initial a_rwnd sent in the INIT or INIT // ACK. if i.advertisedReceiverWindowCredit < 1500 { abort = true return abort, ErrInitAckAdvertisedReceiver1500 } return false, nil } // String makes chunkInitAck printable. func (i *chunkInitAck) String() string { return fmt.Sprintf("%s\n%s", i.chunkHeader, i.chunkInitCommon) } sctp-1.9.0/chunk_init_common.go000066400000000000000000000140201512256410600165230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkInitCommon represents an SCTP Chunk body of type INIT and INIT ACK 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 1 | Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Initiate Tag | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Advertised Receiver Window Credit (a_rwnd) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Number of Outbound Streams | Number of Inbound Streams | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Initial TSN | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Optional/Variable-Length Parameters | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ The INIT chunk contains the following parameters. Unless otherwise noted, each parameter MUST only be included once in the INIT chunk. 
Fixed Parameters Status ---------------------------------------------- Initiate Tag Mandatory Advertised Receiver Window Credit Mandatory Number of Outbound Streams Mandatory Number of Inbound Streams Mandatory Initial TSN Mandatory */ type chunkInitCommon struct { initiateTag uint32 advertisedReceiverWindowCredit uint32 numOutboundStreams uint16 numInboundStreams uint16 initialTSN uint32 params []param unrecognizedParams []paramHeader } const ( initChunkMinLength = 16 initOptionalVarHeaderLength = 4 ) // Init chunk errors. var ( ErrInitChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrInitAckMarshalParam = errors.New("unable to marshal parameter for INIT/INITACK") ) func (i *chunkInitCommon) unmarshal(raw []byte) error { i.initiateTag = binary.BigEndian.Uint32(raw[0:]) i.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(raw[4:]) i.numOutboundStreams = binary.BigEndian.Uint16(raw[8:]) i.numInboundStreams = binary.BigEndian.Uint16(raw[10:]) i.initialTSN = binary.BigEndian.Uint32(raw[12:]) // https://tools.ietf.org/html/rfc4960#section-3.2.1 // // Chunk values of SCTP control chunks consist of a chunk-type-specific // header of required fields, followed by zero or more parameters. The // optional and variable-length parameters contained in a chunk are // defined in a Type-Length-Value format as shown below. // // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type | Parameter Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | | // | Parameter Value | // | | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ offset := initChunkMinLength remaining := len(raw) - offset for remaining > 0 { if remaining > initOptionalVarHeaderLength { var pHeader paramHeader if err := pHeader.unmarshal(raw[offset:]); err != nil { return fmt.Errorf("%w: %v", ErrInitChunkParseParamTypeFailed, err) //nolint:errorlint } p, err := buildParam(pHeader.typ, raw[offset:]) if err != nil { i.unrecognizedParams = append(i.unrecognizedParams, pHeader) } else { i.params = append(i.params, p) } padding := getPadding(pHeader.length()) offset += pHeader.length() + padding remaining -= pHeader.length() + padding } else { break } } return nil } func (i *chunkInitCommon) marshal() ([]byte, error) { out := make([]byte, initChunkMinLength) binary.BigEndian.PutUint32(out[0:], i.initiateTag) binary.BigEndian.PutUint32(out[4:], i.advertisedReceiverWindowCredit) binary.BigEndian.PutUint16(out[8:], i.numOutboundStreams) binary.BigEndian.PutUint16(out[10:], i.numInboundStreams) binary.BigEndian.PutUint32(out[12:], i.initialTSN) for idx, p := range i.params { pp, err := p.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrInitAckMarshalParam, err) //nolint:errorlint } out = append(out, pp...) //nolint:makezero // TODO: fix // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. *However, it does include padding of any variable-length // parameter except the last parameter in the chunk.* The receiver // MUST ignore the padding. if idx != len(i.params)-1 { out = padByte(out, getPadding(len(pp))) } } return out, nil } // String makes chunkInitCommon printable. 
func (i chunkInitCommon) String() string { format := `initiateTag: %d advertisedReceiverWindowCredit: %d numOutboundStreams: %d numInboundStreams: %d initialTSN: %d` res := fmt.Sprintf(format, i.initiateTag, i.advertisedReceiverWindowCredit, i.numOutboundStreams, i.numInboundStreams, i.initialTSN, ) for i, param := range i.params { res += fmt.Sprintf("Param %d:\n %s", i, param) } return res } // allZero returns true if every byte is 0x00. func allZero(b []byte) bool { for _, v := range b { if v != 0 { return false } } return true } sctp-1.9.0/chunk_init_test.go000066400000000000000000000023501512256410600162150ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkInit_UnrecognizedParameters(t *testing.T) { initChunkHeader := []byte{ 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, } unrecognizedSkip := append([]byte{}, initChunkHeader...) unrecognizedSkip = append(unrecognizedSkip, byte(paramHeaderUnrecognizedActionSkip), 0xFF, 0x00, 0x04, 0x00) initCommonChunk := &chunkInitCommon{} assert.NoError(t, initCommonChunk.unmarshal(unrecognizedSkip)) assert.Equal(t, 1, len(initCommonChunk.unrecognizedParams)) assert.Equal(t, paramHeaderUnrecognizedActionSkip, initCommonChunk.unrecognizedParams[0].unrecognizedAction) unrecognizedStop := append([]byte{}, initChunkHeader...) unrecognizedStop = append(unrecognizedStop, byte(paramHeaderUnrecognizedActionStop), 0xFF, 0x00, 0x04, 0x00) initCommonChunk = &chunkInitCommon{} assert.NoError(t, initCommonChunk.unmarshal(unrecognizedStop)) assert.Equal(t, 1, len(initCommonChunk.unrecognizedParams)) assert.Equal(t, paramHeaderUnrecognizedActionStop, initCommonChunk.unrecognizedParams[0].unrecognizedAction) } sctp-1.9.0/chunk_payload_data.go000066400000000000000000000151211512256410600166350ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" "time" ) /* chunkPayloadData represents an SCTP Chunk of type DATA 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 0 | Reserved|U|B|E| Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | TSN | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Stream Identifier S | Stream Sequence Number n | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Payload Protocol Identifier | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | User Data (seq n of Stream S) | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ An unfragmented user message shall have both the B and E bits set to '1'. 
Setting both B and E bits to '0' indicates a middle fragment of a multi-fragment user message, as summarized in the following table: B E Description ============================================================ | 1 0 | First piece of a fragmented user message | +----------------------------------------------------------+ | 0 0 | Middle piece of a fragmented user message | +----------------------------------------------------------+ | 0 1 | Last piece of a fragmented user message | +----------------------------------------------------------+ | 1 1 | Unfragmented message | ============================================================ | Table 1: Fragment Description Flags | ============================================================ */ type chunkPayloadData struct { chunkHeader unordered bool beginningFragment bool endingFragment bool immediateSack bool tsn uint32 streamIdentifier uint16 streamSequenceNumber uint16 payloadType PayloadProtocolIdentifier userData []byte // Whether this data chunk was acknowledged (received by peer) acked bool missIndicator uint32 // Partial-reliability parameters used only by sender since time.Time nSent uint32 // number of transmission made for this chunk _abandoned bool _allInflight bool // valid only with the first fragment // Retransmission flag set when T1-RTX timeout occurred and this // chunk is still in the inflight queue retransmit bool head *chunkPayloadData // link to the head of the fragment rackPrev *chunkPayloadData rackNext *chunkPayloadData rackInList bool } const ( payloadDataEndingFragmentBitmask = 1 payloadDataBeginingFragmentBitmask = 2 payloadDataUnorderedBitmask = 4 payloadDataImmediateSACK = 8 payloadDataHeaderSize = 12 ) // PayloadProtocolIdentifier is an enum for DataChannel payload types. type PayloadProtocolIdentifier uint32 // PayloadProtocolIdentifier enums // https://www.iana.org/assignments/sctp-parameters/sctp-parameters.xhtml#sctp-parameters-25 const ( PayloadTypeUnknown PayloadProtocolIdentifier = 0 PayloadTypeWebRTCDCEP PayloadProtocolIdentifier = 50 PayloadTypeWebRTCString PayloadProtocolIdentifier = 51 PayloadTypeWebRTCBinary PayloadProtocolIdentifier = 53 PayloadTypeWebRTCStringEmpty PayloadProtocolIdentifier = 56 PayloadTypeWebRTCBinaryEmpty PayloadProtocolIdentifier = 57 ) // Data chunk errors. 
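// ErrChunkPayloadSmall is returned by unmarshal when the chunk value is shorter than
// payloadDataHeaderSize (12 bytes: TSN, stream identifier, stream sequence number, and
// payload protocol identifier), i.e. too small to hold the fixed DATA header.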
var ( ErrChunkPayloadSmall = errors.New("packet is smaller than the header size") ) func (p PayloadProtocolIdentifier) String() string { switch p { case PayloadTypeWebRTCDCEP: return "WebRTC DCEP" case PayloadTypeWebRTCString: return "WebRTC String" case PayloadTypeWebRTCBinary: return "WebRTC Binary" case PayloadTypeWebRTCStringEmpty: return "WebRTC String (Empty)" case PayloadTypeWebRTCBinaryEmpty: return "WebRTC Binary (Empty)" default: return fmt.Sprintf("Unknown Payload Protocol Identifier: %d", p) } } func (p *chunkPayloadData) unmarshal(raw []byte) error { if err := p.chunkHeader.unmarshal(raw); err != nil { return err } p.immediateSack = p.flags&payloadDataImmediateSACK != 0 p.unordered = p.flags&payloadDataUnorderedBitmask != 0 p.beginningFragment = p.flags&payloadDataBeginingFragmentBitmask != 0 p.endingFragment = p.flags&payloadDataEndingFragmentBitmask != 0 if len(p.raw) < payloadDataHeaderSize { return ErrChunkPayloadSmall } p.tsn = binary.BigEndian.Uint32(p.raw[0:]) p.streamIdentifier = binary.BigEndian.Uint16(p.raw[4:]) p.streamSequenceNumber = binary.BigEndian.Uint16(p.raw[6:]) p.payloadType = PayloadProtocolIdentifier(binary.BigEndian.Uint32(p.raw[8:])) p.userData = p.raw[payloadDataHeaderSize:] return nil } func (p *chunkPayloadData) marshal() ([]byte, error) { payRaw := make([]byte, payloadDataHeaderSize+len(p.userData)) binary.BigEndian.PutUint32(payRaw[0:], p.tsn) binary.BigEndian.PutUint16(payRaw[4:], p.streamIdentifier) binary.BigEndian.PutUint16(payRaw[6:], p.streamSequenceNumber) binary.BigEndian.PutUint32(payRaw[8:], uint32(p.payloadType)) copy(payRaw[payloadDataHeaderSize:], p.userData) flags := uint8(0) if p.endingFragment { flags = 1 } if p.beginningFragment { flags |= 1 << 1 } if p.unordered { flags |= 1 << 2 } if p.immediateSack { flags |= 1 << 3 } p.chunkHeader.flags = flags p.chunkHeader.typ = ctPayloadData p.chunkHeader.raw = payRaw return p.chunkHeader.marshal() } func (p *chunkPayloadData) check() (abort bool, err error) { return false, nil } // String makes chunkPayloadData printable. func (p *chunkPayloadData) String() string { return fmt.Sprintf("%s\n%d", p.chunkHeader, p.tsn) } func (p *chunkPayloadData) abandoned() bool { if p.head != nil { return p.head._abandoned && p.head._allInflight } return p._abandoned && p._allInflight } func (p *chunkPayloadData) setAbandoned(abandoned bool) { if p.head != nil { p.head._abandoned = abandoned return } p._abandoned = abandoned } func (p *chunkPayloadData) setAllInflight() { if p.endingFragment { if p.head != nil { p.head._allInflight = true } else { p._allInflight = true } } } func (p *chunkPayloadData) isFragmented() bool { return p.head != nil || !p.beginningFragment || !p.endingFragment } sctp-1.9.0/chunk_reconfig.go000066400000000000000000000062011512256410600160060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) // https://tools.ietf.org/html/rfc6525#section-3.1 // chunkReconfig represents an SCTP Chunk used to reconfigure streams. 
// // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type = 130 | Chunk Flags | Chunk Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ \ // / Re-configuration Parameter / // \ \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // \ \ // / Re-configuration Parameter (optional) / // \ \ // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type chunkReconfig struct { chunkHeader paramA param paramB param } // Reconfigure chunk errors. var ( ErrChunkParseParamTypeFailed = errors.New("failed to parse param type") ErrChunkMarshalParamAReconfigFailed = errors.New("unable to marshal parameter A for reconfig") ErrChunkMarshalParamBReconfigFailed = errors.New("unable to marshal parameter B for reconfig") ) func (c *chunkReconfig) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } pType, err := parseParamType(c.raw) if err != nil { return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) //nolint:errorlint } a, err := buildParam(pType, c.raw) if err != nil { return err } c.paramA = a padding := getPadding(a.length()) offset := a.length() + padding if len(c.raw) > offset { pType, err := parseParamType(c.raw[offset:]) if err != nil { return fmt.Errorf("%w: %v", ErrChunkParseParamTypeFailed, err) //nolint:errorlint } b, err := buildParam(pType, c.raw[offset:]) if err != nil { return err } c.paramB = b } return nil } func (c *chunkReconfig) marshal() ([]byte, error) { out, err := c.paramA.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamAReconfigFailed, err) //nolint:errorlint } if c.paramB != nil { // Pad param A out = padByte(out, getPadding(len(out))) outB, err := c.paramB.marshal() if err != nil { return nil, fmt.Errorf("%w: %v", ErrChunkMarshalParamBReconfigFailed, err) //nolint:errorlint } out = append(out, outB...) } c.typ = ctReconfig c.raw = out return c.chunkHeader.marshal() } func (c *chunkReconfig) check() (abort bool, err error) { // nolint:godox // TODO: check allowed combinations: // https://tools.ietf.org/html/rfc6525#section-3.1 return true, nil } // String makes chunkReconfig printable. 
func (c *chunkReconfig) String() string { res := fmt.Sprintf("Param A:\n %s", c.paramA) if c.paramB != nil { res += fmt.Sprintf("Param B:\n %s", c.paramB) } return res } sctp-1.9.0/chunk_reconfig_test.go000066400000000000000000000033741512256410600170550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkReconfig_Success(t *testing.T) { tt := []struct { binary []byte }{ { // Note: chunk trailing padding is added in packet.marshal append( []byte{0x82, 0x0, 0x0, 0x1a}, testChunkReconfigParamA()..., ), }, {append([]byte{0x82, 0x0, 0x0, 0x14}, testChunkReconfigParamB()...)}, {append([]byte{0x82, 0x0, 0x0, 0x10}, testChunkReconfigResponce()...)}, { append( append([]byte{0x82, 0x0, 0x0, 0x2c}, padByte(testChunkReconfigParamA(), 2)...), testChunkReconfigParamB()...), }, { // Note: chunk trailing padding is added in packet.marshal append( append([]byte{0x82, 0x0, 0x0, 0x2a}, testChunkReconfigParamB()...), testChunkReconfigParamA()..., ), }, } for i, tc := range tt { actual := &chunkReconfig{} err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d: %v", i) b, err := actual.marshal() assert.NoError(t, err) assert.Equalf(t, tc.binary, b, "test %d not equal", i) } } func TestChunkReconfigUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"chunk header to short", []byte{0x82}}, {"missing parse param type (A)", []byte{0x82, 0x0, 0x0, 0x4}}, {"wrong param (A)", []byte{0x82, 0x0, 0x0, 0x8, 0x0, 0xd, 0x0, 0x0}}, { "wrong param (B)", append(append([]byte{0x82, 0x0, 0x0, 0x18}, testChunkReconfigParamB()...), []byte{0x0, 0xd, 0x0, 0x0}...), }, } for i, tc := range tt { actual := &chunkReconfig{} err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/chunk_selective_ack.go000066400000000000000000000124471512256410600170240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkSelectiveAck represents an SCTP Chunk of type SACK This chunk is sent to the peer endpoint to acknowledge received DATA chunks and to inform the peer endpoint of gaps in the received subsequences of DATA chunks as represented by their TSNs. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 3 |Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cumulative TSN Ack | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Advertised Receiver Window Credit (a_rwnd) | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Number of Gap Ack Blocks = N | Number of Duplicate TSNs = X | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Gap Ack Block #1 Start | Gap Ack Block #1 End | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / \ ... \ / / +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Gap Ack Block #N Start | Gap Ack Block #N End | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Duplicate TSN 1 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / / \ ... 
\ / / +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Duplicate TSN X | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type gapAckBlock struct { start uint16 end uint16 } // Selective ack chunk errors. var ( ErrChunkTypeNotSack = errors.New("ChunkType is not of type SACK") ErrSackSizeNotLargeEnoughInfo = errors.New("SACK Chunk size is not large enough to contain header") ErrSackSizeNotMatchPredicted = errors.New("SACK Chunk size does not match predicted amount from header values") ) // String makes gapAckBlock printable. func (g gapAckBlock) String() string { return fmt.Sprintf("%d - %d", g.start, g.end) } type chunkSelectiveAck struct { chunkHeader cumulativeTSNAck uint32 advertisedReceiverWindowCredit uint32 gapAckBlocks []gapAckBlock duplicateTSN []uint32 } const ( selectiveAckHeaderSize = 12 ) func (s *chunkSelectiveAck) unmarshal(raw []byte) error { if err := s.chunkHeader.unmarshal(raw); err != nil { return err } if s.typ != ctSack { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotSack, s.typ.String()) } if len(s.raw) < selectiveAckHeaderSize { return fmt.Errorf("%w: %v remaining, needs %v bytes", ErrSackSizeNotLargeEnoughInfo, len(s.raw), selectiveAckHeaderSize) } s.cumulativeTSNAck = binary.BigEndian.Uint32(s.raw[0:]) s.advertisedReceiverWindowCredit = binary.BigEndian.Uint32(s.raw[4:]) s.gapAckBlocks = make([]gapAckBlock, binary.BigEndian.Uint16(s.raw[8:])) s.duplicateTSN = make([]uint32, binary.BigEndian.Uint16(s.raw[10:])) if len(s.raw) != selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN))) { return ErrSackSizeNotMatchPredicted } offset := selectiveAckHeaderSize for i := range s.gapAckBlocks { s.gapAckBlocks[i].start = binary.BigEndian.Uint16(s.raw[offset:]) s.gapAckBlocks[i].end = binary.BigEndian.Uint16(s.raw[offset+2:]) offset += 4 } for i := range s.duplicateTSN { s.duplicateTSN[i] = binary.BigEndian.Uint32(s.raw[offset:]) offset += 4 } return nil } func (s *chunkSelectiveAck) marshal() ([]byte, error) { sackRaw := make([]byte, selectiveAckHeaderSize+(4*len(s.gapAckBlocks)+(4*len(s.duplicateTSN)))) binary.BigEndian.PutUint32(sackRaw[0:], s.cumulativeTSNAck) binary.BigEndian.PutUint32(sackRaw[4:], s.advertisedReceiverWindowCredit) binary.BigEndian.PutUint16(sackRaw[8:], uint16(len(s.gapAckBlocks))) //nolint:gosec // G115 binary.BigEndian.PutUint16(sackRaw[10:], uint16(len(s.duplicateTSN))) //nolint:gosec // G115 offset := selectiveAckHeaderSize for _, g := range s.gapAckBlocks { binary.BigEndian.PutUint16(sackRaw[offset:], g.start) binary.BigEndian.PutUint16(sackRaw[offset+2:], g.end) offset += 4 } for _, t := range s.duplicateTSN { binary.BigEndian.PutUint32(sackRaw[offset:], t) offset += 4 } s.chunkHeader.typ = ctSack s.chunkHeader.raw = sackRaw return s.chunkHeader.marshal() } func (s *chunkSelectiveAck) check() (abort bool, err error) { return false, nil } // String makes chunkSelectiveAck printable. 
func (s *chunkSelectiveAck) String() string { res := fmt.Sprintf("SACK cumTsnAck=%d arwnd=%d dupTsn=%d", s.cumulativeTSNAck, s.advertisedReceiverWindowCredit, s.duplicateTSN) for _, gap := range s.gapAckBlocks { res = fmt.Sprintf("%s\n gap ack: %s", res, gap) } return res } sctp-1.9.0/chunk_shutdown.go000066400000000000000000000033561512256410600160750ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkShutdown represents an SCTP Chunk of type chunkShutdown 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 7 | Chunk Flags | Length = 8 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cumulative TSN Ack | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdown struct { chunkHeader cumulativeTSNAck uint32 } const ( cumulativeTSNAckLength = 4 ) // Shutdown chunk errors. var ( ErrInvalidChunkSize = errors.New("invalid chunk size") ErrChunkTypeNotShutdown = errors.New("ChunkType is not of type SHUTDOWN") ) func (c *chunkShutdown) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdown { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdown, c.typ.String()) } if len(c.raw) != cumulativeTSNAckLength { return ErrInvalidChunkSize } c.cumulativeTSNAck = binary.BigEndian.Uint32(c.raw[0:]) return nil } func (c *chunkShutdown) marshal() ([]byte, error) { out := make([]byte, cumulativeTSNAckLength) binary.BigEndian.PutUint32(out[0:], c.cumulativeTSNAck) c.typ = ctShutdown c.raw = out return c.chunkHeader.marshal() } func (c *chunkShutdown) check() (abort bool, err error) { return false, nil } // String makes chunkShutdown printable. func (c *chunkShutdown) String() string { return c.chunkHeader.String() } sctp-1.9.0/chunk_shutdown_ack.go000066400000000000000000000024151512256410600167060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* chunkShutdownAck represents an SCTP Chunk of type chunkShutdownAck 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 8 | Chunk Flags | Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdownAck struct { chunkHeader } // Shutdown ack chunk errors. var ( ErrChunkTypeNotShutdownAck = errors.New("ChunkType is not of type SHUTDOWN-ACK") ) func (c *chunkShutdownAck) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdownAck { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownAck, c.typ.String()) } return nil } func (c *chunkShutdownAck) marshal() ([]byte, error) { c.typ = ctShutdownAck return c.chunkHeader.marshal() } func (c *chunkShutdownAck) check() (abort bool, err error) { return false, nil } // String makes chunkShutdownAck printable. 
func (c *chunkShutdownAck) String() string { return c.chunkHeader.String() } sctp-1.9.0/chunk_shutdown_ack_test.go000066400000000000000000000020641512256410600177450ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //nolint:dupl package sctp import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestChunkShutdownAck_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x08, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownAck{} err := actual.unmarshal(tc.binary) require.NoErrorf(t, err, "failed to unmarshal #%d", i) b, err := actual.marshal() require.NoError(t, err) assert.Equalf(t, tc.binary, b, "test %d not equal", i) } } func TestChunkShutdownAck_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"length too short", []byte{0x08, 0x00, 0x00}}, {"length too long", []byte{0x08, 0x00, 0x00, 0x04, 0x12}}, {"invalid type", []byte{0x0f, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownAck{} err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/chunk_shutdown_complete.go000066400000000000000000000025231512256410600177600ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* chunkShutdownComplete represents an SCTP Chunk of type chunkShutdownComplete 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Type = 14 |Reserved |T| Length = 4 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+. */ type chunkShutdownComplete struct { chunkHeader } // Shutdown complete chunk errors. var ( ErrChunkTypeNotShutdownComplete = errors.New("ChunkType is not of type SHUTDOWN-COMPLETE") ) func (c *chunkShutdownComplete) unmarshal(raw []byte) error { if err := c.chunkHeader.unmarshal(raw); err != nil { return err } if c.typ != ctShutdownComplete { return fmt.Errorf("%w: actually is %s", ErrChunkTypeNotShutdownComplete, c.typ.String()) } return nil } func (c *chunkShutdownComplete) marshal() ([]byte, error) { c.typ = ctShutdownComplete return c.chunkHeader.marshal() } func (c *chunkShutdownComplete) check() (abort bool, err error) { return false, nil } // String makes chunkShutdownComplete printable. 
func (c *chunkShutdownComplete) String() string { return c.chunkHeader.String() } sctp-1.9.0/chunk_shutdown_complete_test.go000066400000000000000000000021271512256410600210170ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //nolint:dupl package sctp import ( "testing" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) func TestChunkShutdownComplete_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x0e, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownComplete{} err := actual.unmarshal(tc.binary) require.NoErrorf(t, err, "failed to unmarshal #%d", i) b, err := actual.marshal() require.NoError(t, err) assert.Equalf(t, tc.binary, b, "test %d not equal", i) } } func TestChunkShutdownComplete_Failure(t *testing.T) { //nolint:dupl tt := []struct { name string binary []byte }{ {"length too short", []byte{0x0e, 0x00, 0x00}}, {"length too long", []byte{0x0e, 0x00, 0x00, 0x04, 0x12}}, {"invalid type", []byte{0x0f, 0x00, 0x00, 0x04}}, } for i, tc := range tt { actual := &chunkShutdownComplete{} err := actual.unmarshal(tc.binary) require.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/chunk_shutdown_test.go000066400000000000000000000023661512256410600171340ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkShutdown_Success(t *testing.T) { tt := []struct { binary []byte }{ {[]byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}}, } for i, tc := range tt { actual := &chunkShutdown{} err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d: %v", i) b, err := actual.marshal() assert.NoError(t, err) assert.Equalf(t, tc.binary, b, "test %d not equal", i) } } func TestChunkShutdown_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"length too short", []byte{0x07, 0x00, 0x00, 0x07, 0x12, 0x34, 0x56, 0x78}}, {"length too long", []byte{0x07, 0x00, 0x00, 0x09, 0x12, 0x34, 0x56, 0x78}}, {"payload too short", []byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56}}, {"payload too long", []byte{0x07, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78, 0x9f}}, {"invalid type", []byte{0x08, 0x00, 0x00, 0x08, 0x12, 0x34, 0x56, 0x78}}, } for i, tc := range tt { actual := &chunkShutdown{} err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/chunk_test.go000066400000000000000000000240351512256410600151760ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestInitChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) initChunk, ok := 
pkt.chunks[0].(*chunkInit) assert.True(t, ok, "Failed to cast Chunk -> Init") assert.NoError(t, err, "Unmarshal init Chunk failed") assert.Equalf(t, initChunk.initiateTag, uint32(1438213285), "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 1438213285, initChunk.initiateTag) assert.Equalf(t, initChunk.advertisedReceiverWindowCredit, uint32(131072), "Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", 131072, initChunk.advertisedReceiverWindowCredit) assert.Equalf(t, initChunk.numOutboundStreams, uint16(1024), "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1024, initChunk.numOutboundStreams) assert.Equalf(t, initChunk.numInboundStreams, uint16(2048), "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 2048, initChunk.numInboundStreams) assert.Equalf(t, initChunk.initialTSN, uint32(3899461680), "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", uint32(3899461680), initChunk.initialTSN) } func TestInitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0x96, 0x19, 0xe8, 0xb2, 0x02, 0x00, 0x00, 0x1c, 0xeb, 0x81, 0x4e, 0x01, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x50, 0xdf, 0x90, 0xd9, 0x00, 0x07, 0x00, 0x08, 0x94, 0x06, 0x2f, 0x93, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) _, ok := pkt.chunks[0].(*chunkInitAck) assert.True(t, ok, "Failed to cast Chunk -> InitAck") assert.NoError(t, err) } func TestChromeChunk1Init(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0xbc, 0xb3, 0x45, 0xa2, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) rawPkt2, err := pkt.marshal(true) assert.NoError(t, err) assert.Equal(t, rawPkt, rawPkt2) } func TestChromeChunk2InitAck(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0xce, 0x15, 0x79, 0xa2, 0xb5, 0xdb, 0x2d, 0x93, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x00, 0x07, 0x01, 0x38, 0x4b, 0x41, 0x4d, 0x45, 0x2d, 0x42, 0x53, 0x44, 0x20, 0x31, 0x2e, 0x31, 0x00, 0x00, 0x00, 0x00, 0x9c, 0x1e, 0x49, 0x5b, 0x00, 0x00, 0x00, 0x00, 0xd2, 0x42, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x60, 0xea, 0x00, 0x00, 0xc4, 0x13, 0x3d, 0xe9, 0x86, 0xb1, 0x85, 0x75, 0xa2, 0x79, 0x15, 0xce, 0x9b, 0xd5, 0xb3, 0x6f, 0x20, 0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x20, 
0xe0, 0x9f, 0x89, 0xe0, 0x27, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x01, 0x00, 0x01, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x56, 0xce, 0x15, 0x79, 0xa2, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0x94, 0x57, 0x95, 0xc0, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0xff, 0x5c, 0x49, 0x19, 0x4a, 0x94, 0xe8, 0x2a, 0xec, 0x58, 0x55, 0x62, 0x29, 0x1f, 0x8e, 0x23, 0xcd, 0x7c, 0xe8, 0x46, 0xba, 0x58, 0x1b, 0x3d, 0xab, 0xd7, 0x7e, 0x50, 0xf2, 0x41, 0xb1, 0x2e, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0x02, 0x00, 0x01, 0x90, 0x9b, 0xd5, 0xb3, 0x6f, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xef, 0xb4, 0x72, 0x87, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x2e, 0xf9, 0x9c, 0x10, 0x63, 0x72, 0xed, 0x0d, 0x33, 0xc2, 0xdc, 0x7f, 0x9f, 0xd7, 0xef, 0x1b, 0xc9, 0xc4, 0xa7, 0x41, 0x9a, 0x07, 0x68, 0x6b, 0x66, 0xfb, 0x6a, 0x4e, 0x32, 0x5d, 0xe4, 0x25, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, 0xca, 0x0c, 0x21, 0x11, 0xce, 0xf4, 0xfc, 0xb3, 0x66, 0x99, 0x4f, 0xdb, 0x4f, 0x95, 0x6b, 0x6f, 0x3b, 0xb1, 0xdb, 0x5a, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) rawPkt2, err := pkt.marshal(true) assert.NoError(t, err) assert.Equal(t, rawPkt, rawPkt2) } func TestInitMarshalUnmarshal(t *testing.T) { sctpPacket := &packet{} sctpPacket.destinationPort = 1 sctpPacket.sourcePort = 1 sctpPacket.verificationTag = 123 initAck := &chunkInitAck{} initAck.initialTSN = 123 initAck.numOutboundStreams = 1 initAck.numInboundStreams = 1 initAck.initiateTag = 123 initAck.advertisedReceiverWindowCredit = 1024 cookie, ErrRand := newRandomStateCookie() assert.NoError(t, ErrRand, "Failed to generate random state cookie") initAck.params = []param{cookie} sctpPacket.chunks = []chunk{initAck} rawPkt, err := sctpPacket.marshal(true) assert.NoError(t, err) pkt := &packet{} err = pkt.unmarshal(true, rawPkt) assert.NoError(t, err) initAckChunk, ok := pkt.chunks[0].(*chunkInitAck) assert.True(t, ok, "Failed to cast Chunk -> InitAck") assert.NoError(t, err, "Unmarshal init ack Chunk failed") assert.Equalf(t, initAckChunk.initiateTag, uint32(123), "Unmarshal passed for SCTP packet, but got incorrect initiate tag exp: %d act: %d", 123, initAckChunk.initiateTag) assert.Equalf(t, initAckChunk.advertisedReceiverWindowCredit, uint32(1024), "Unmarshal passed for SCTP packet, but got incorrect advertisedReceiverWindowCredit exp: %d act: %d", 1024, initAckChunk.advertisedReceiverWindowCredit) assert.Equalf(t, initAckChunk.numOutboundStreams, uint16(1), "Unmarshal passed for SCTP packet, but got incorrect numOutboundStreams tag exp: %d act: %d", 1, initAckChunk.numOutboundStreams) assert.Equalf(t, initAckChunk.numInboundStreams, uint16(1), "Unmarshal passed for SCTP packet, but got incorrect numInboundStreams exp: %d act: %d", 1, initAckChunk.numInboundStreams) assert.Equalf(t, initAckChunk.initialTSN, uint32(123), "Unmarshal passed for SCTP packet, but got incorrect initialTSN exp: %d act: %d", 123, initAckChunk.initialTSN) } func TestPayloadDataMarshalUnmarshal(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0xfc, 0xd6, 0x3f, 0xc6, 0xbe, 0xfa, 0xdc, 0x52, 0x0a, 0x00, 0x00, 0x24, 0x9b, 0x28, 0x7e, 0x48, 0xa3, 0x7b, 
0xc1, 0x83, 0xc4, 0x4b, 0x41, 0x04, 0xa4, 0xf7, 0xed, 0x4c, 0x93, 0x62, 0xc3, 0x49, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x1f, 0xa8, 0x79, 0xa1, 0xc7, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x32, 0x03, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x03, 0x00, 0x00, 0x66, 0x6f, 0x6f, 0x00, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) _, ok := pkt.chunks[1].(*chunkPayloadData) assert.True(t, ok, "Failed to cast Chunk -> PayloadData") } func TestSelectAckChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0xc2, 0x98, 0x98, 0x0f, 0x42, 0x31, 0xea, 0x78, 0x03, 0x00, 0x00, 0x14, 0x87, 0x73, 0xbd, 0xa4, 0x00, 0x01, 0xfe, 0x74, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x02, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) _, ok := pkt.chunks[0].(*chunkSelectiveAck) assert.True(t, ok, "Failed to cast Chunk -> SelectiveAck") } func TestReconfigChunk(t *testing.T) { pkt := &packet{} rawPkt := []byte{ 0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x75, 0x3b, 0x12, 0xd3, 0x82, 0x0, 0x0, 0x16, 0x0, 0xd, 0x0, 0x12, 0x4e, 0x1c, 0xb9, 0xe6, 0x3a, 0x74, 0x8d, 0xff, 0x4e, 0x1c, 0xb9, 0xe6, 0x0, 0x1, 0x0, 0x0, } err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) c, ok := pkt.chunks[0].(*chunkReconfig) assert.True(t, ok, "Failed to cast Chunk -> Reconfig") iden := c.paramA.(*paramOutgoingResetRequest).streamIdentifiers[0] //nolint:forcetypeassert assert.Equalf(t, iden, uint16(1), "unexpected stream identifier: %d", iden) } func TestForwardTSNChunk(t *testing.T) { pkt := &packet{} rawPkt := append( []byte{0x13, 0x88, 0x13, 0x88, 0xb6, 0xa5, 0x12, 0xe5, 0x1f, 0x9d, 0xa0, 0xfb}, testChunkForwardTSN()..., ) err := pkt.unmarshal(true, rawPkt) assert.NoError(t, err) c, ok := pkt.chunks[0].(*chunkForwardTSN) assert.True(t, ok, "Failed to cast Chunk -> Forward TSN") assert.Equalf(t, c.newCumulativeTSN, uint32(3), "unexpected New Cumulative TSN") } sctp-1.9.0/chunkheader.go000066400000000000000000000065451512256410600153160ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) /* chunkHeader represents a SCTP Chunk header, defined in https://tools.ietf.org/html/rfc4960#section-3.2 The figure below illustrates the field format for the chunks to be transmitted in the SCTP packet. Each chunk is formatted with a Chunk Type field, a chunk-specific Flag field, a Chunk Length field, and a Value field. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk Type | Chunk Flags | Chunk Length | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | | | Chunk Value | | | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type chunkHeader struct { typ chunkType flags byte raw []byte } const ( chunkHeaderSize = 4 ) // SCTP chunk header errors. 
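// ErrChunkHeaderNotEnoughSpace is reported when the Chunk Length field claims more
// bytes than remain in the packet, and ErrChunkHeaderPaddingNonZero when any of the
// up-to-three trailing padding bytes after the chunk value is non-zero (see unmarshal below).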
var ( ErrChunkHeaderTooSmall = errors.New("raw is too small for a SCTP chunk") ErrChunkHeaderNotEnoughSpace = errors.New("not enough data left in SCTP packet to satisfy requested length") ErrChunkHeaderPaddingNonZero = errors.New("chunk padding is non-zero at offset") ) func (c *chunkHeader) unmarshal(raw []byte) error { if len(raw) < chunkHeaderSize { return fmt.Errorf( "%w: raw only %d bytes, %d is the minimum length", ErrChunkHeaderTooSmall, len(raw), chunkHeaderSize, ) } c.typ = chunkType(raw[0]) c.flags = raw[1] length := binary.BigEndian.Uint16(raw[2:]) // Length includes Chunk header valueLength := int(length - chunkHeaderSize) lengthAfterValue := len(raw) - (chunkHeaderSize + valueLength) if lengthAfterValue < 0 { return fmt.Errorf("%w: remain %d req %d ", ErrChunkHeaderNotEnoughSpace, valueLength, len(raw)-chunkHeaderSize) } else if lengthAfterValue < 4 { // https://tools.ietf.org/html/rfc4960#section-3.2 // The Chunk Length field does not count any chunk padding. // Chunks (including Type, Length, and Value fields) are padded out // by the sender with all zero bytes to be a multiple of 4 bytes // long. This padding MUST NOT be more than 3 bytes in total. The // Chunk Length value does not include terminating padding of the // chunk. However, it does include padding of any variable-length // parameter except the last parameter in the chunk. The receiver // MUST ignore the padding. for i := lengthAfterValue; i > 0; i-- { paddingOffset := chunkHeaderSize + valueLength + (i - 1) if raw[paddingOffset] != 0 { return fmt.Errorf("%w: %d ", ErrChunkHeaderPaddingNonZero, paddingOffset) } } } c.raw = raw[chunkHeaderSize : chunkHeaderSize+valueLength] return nil } func (c *chunkHeader) marshal() ([]byte, error) { raw := make([]byte, 4+len(c.raw)) raw[0] = uint8(c.typ) raw[1] = c.flags binary.BigEndian.PutUint16(raw[2:], uint16(len(c.raw)+chunkHeaderSize)) //nolint:gosec // G115 copy(raw[4:], c.raw) return raw, nil } func (c *chunkHeader) valueLength() int { return len(c.raw) } // String makes chunkHeader printable. func (c chunkHeader) String() string { return c.typ.String() } sctp-1.9.0/chunkheader_test.go000066400000000000000000000001631512256410600163430ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp sctp-1.9.0/chunktype.go000066400000000000000000000032361512256410600150410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import "fmt" // chunkType is an enum for SCTP Chunk Type field // This field identifies the type of information contained in the // Chunk Value field. type chunkType uint8 // List of known chunkType enums. 
const ( ctPayloadData chunkType = 0 ctInit chunkType = 1 ctInitAck chunkType = 2 ctSack chunkType = 3 ctHeartbeat chunkType = 4 ctHeartbeatAck chunkType = 5 ctAbort chunkType = 6 ctShutdown chunkType = 7 ctShutdownAck chunkType = 8 ctError chunkType = 9 ctCookieEcho chunkType = 10 ctCookieAck chunkType = 11 ctCWR chunkType = 13 ctShutdownComplete chunkType = 14 ctReconfig chunkType = 130 ctForwardTSN chunkType = 192 ) func (c chunkType) String() string { //nolint:cyclop switch c { case ctPayloadData: return "DATA" case ctInit: return "INIT" case ctInitAck: return "INIT-ACK" case ctSack: return "SACK" case ctHeartbeat: return "HEARTBEAT" case ctHeartbeatAck: return "HEARTBEAT-ACK" case ctAbort: return "ABORT" case ctShutdown: return "SHUTDOWN" case ctShutdownAck: return "SHUTDOWN-ACK" case ctError: return "ERROR" case ctCookieEcho: return "COOKIE-ECHO" case ctCookieAck: return "COOKIE-ACK" case ctCWR: return "ECNE" // Explicit Congestion Notification Echo case ctShutdownComplete: return "SHUTDOWN-COMPLETE" case ctReconfig: return "RECONFIG" // Re-configuration case ctForwardTSN: return "FORWARD-TSN" default: return fmt.Sprintf("Unknown ChunkType: %d", c) } } sctp-1.9.0/chunktype_test.go000066400000000000000000000016431512256410600161000ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestChunkType_String(t *testing.T) { tt := []struct { chunkType chunkType expected string }{ {ctPayloadData, "DATA"}, {ctInit, "INIT"}, {ctInitAck, "INIT-ACK"}, {ctSack, "SACK"}, {ctHeartbeat, "HEARTBEAT"}, {ctHeartbeatAck, "HEARTBEAT-ACK"}, {ctAbort, "ABORT"}, {ctShutdown, "SHUTDOWN"}, {ctShutdownAck, "SHUTDOWN-ACK"}, {ctError, "ERROR"}, {ctCookieEcho, "COOKIE-ECHO"}, {ctCookieAck, "COOKIE-ACK"}, {ctCWR, "ECNE"}, {ctShutdownComplete, "SHUTDOWN-COMPLETE"}, {ctReconfig, "RECONFIG"}, {ctForwardTSN, "FORWARD-TSN"}, {chunkType(255), "Unknown ChunkType: 255"}, } for _, tc := range tt { assert.Equalf(t, tc.expected, tc.chunkType.String(), "chunkType %v should be %s", tc.chunkType, tc.expected) } } sctp-1.9.0/codecov.yml000066400000000000000000000007151512256410600146440ustar00rootroot00000000000000# # DO NOT EDIT THIS FILE # # It is automatically copied from https://github.com/pion/.goassets repository. # # SPDX-FileCopyrightText: 2023 The Pion community # SPDX-License-Identifier: MIT coverage: status: project: default: # Allow decreasing 2% of total coverage to avoid noise. threshold: 2% patch: default: target: 70% only_pulls: true ignore: - "examples/*" - "examples/**/*" sctp-1.9.0/control_queue.go000066400000000000000000000011271512256410600157100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // control queue type controlQueue struct { queue []*packet } func newControlQueue() *controlQueue { return &controlQueue{queue: []*packet{}} } func (q *controlQueue) push(c *packet) { q.queue = append(q.queue, c) } func (q *controlQueue) pushAll(packets []*packet) { q.queue = append(q.queue, packets...) 
} func (q *controlQueue) popAll() []*packet { packets := q.queue q.queue = []*packet{} return packets } func (q *controlQueue) size() int { return len(q.queue) } sctp-1.9.0/error_cause.go000066400000000000000000000057231512256410600153430ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) // errorCauseCode is a cause code that appears in either a ERROR or ABORT chunk. type errorCauseCode uint16 type errorCause interface { unmarshal([]byte) error marshal() ([]byte, error) length() uint16 String() string errorCauseCode() errorCauseCode } // Error and abort chunk errors. var ( ErrBuildErrorCaseHandle = errors.New("BuildErrorCause does not handle") ) // buildErrorCause delegates the building of a error cause from raw bytes to the correct structure. func buildErrorCause(raw []byte) (errorCause, error) { var errCause errorCause c := errorCauseCode(binary.BigEndian.Uint16(raw[0:])) switch c { case invalidMandatoryParameter: errCause = &errorCauseInvalidMandatoryParameter{} case unrecognizedChunkType: errCause = &errorCauseUnrecognizedChunkType{} case protocolViolation: errCause = &errorCauseProtocolViolation{} case userInitiatedAbort: errCause = &errorCauseUserInitiatedAbort{} default: return nil, fmt.Errorf("%w: %s", ErrBuildErrorCaseHandle, c.String()) } if err := errCause.unmarshal(raw); err != nil { return nil, err } return errCause, nil } const ( invalidStreamIdentifier errorCauseCode = 1 missingMandatoryParameter errorCauseCode = 2 staleCookieError errorCauseCode = 3 outOfResource errorCauseCode = 4 unresolvableAddress errorCauseCode = 5 unrecognizedChunkType errorCauseCode = 6 invalidMandatoryParameter errorCauseCode = 7 unrecognizedParameters errorCauseCode = 8 noUserData errorCauseCode = 9 cookieReceivedWhileShuttingDown errorCauseCode = 10 restartOfAnAssociationWithNewAddresses errorCauseCode = 11 userInitiatedAbort errorCauseCode = 12 protocolViolation errorCauseCode = 13 ) func (e errorCauseCode) String() string { //nolint:cyclop switch e { case invalidStreamIdentifier: return "Invalid Stream Identifier" case missingMandatoryParameter: return "Missing Mandatory Parameter" case staleCookieError: return "Stale Cookie Error" case outOfResource: return "Out Of Resource" case unresolvableAddress: return "Unresolvable IP" case unrecognizedChunkType: return "Unrecognized Chunk Type" case invalidMandatoryParameter: return "Invalid Mandatory Parameter" case unrecognizedParameters: return "Unrecognized Parameters" case noUserData: return "No User Data" case cookieReceivedWhileShuttingDown: return "Cookie Received While Shutting Down" case restartOfAnAssociationWithNewAddresses: return "Restart Of An Association With New Addresses" case userInitiatedAbort: return "User Initiated Abort" case protocolViolation: return "Protocol Violation" default: return fmt.Sprintf("Unknown CauseCode: %d", e) } } sctp-1.9.0/error_cause_header.go000066400000000000000000000026641512256410600166540ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" ) // errorCauseHeader represents the shared header that is shared by all error causes. type errorCauseHeader struct { code errorCauseCode len uint16 raw []byte } const ( errorCauseHeaderLength = 4 ) // ErrInvalidSCTPChunk is returned when an SCTP chunk is invalid. 
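// It is reported by errorCauseHeader.unmarshal when the cause length field is smaller
// than the 4-byte error-cause header or larger than the number of bytes remaining in raw.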
var ErrInvalidSCTPChunk = errors.New("invalid SCTP chunk") func (e *errorCauseHeader) marshal() ([]byte, error) { e.len = uint16(len(e.raw)) + uint16(errorCauseHeaderLength) //nolint:gosec // G115 raw := make([]byte, e.len) binary.BigEndian.PutUint16(raw[0:], uint16(e.code)) binary.BigEndian.PutUint16(raw[2:], e.len) copy(raw[errorCauseHeaderLength:], e.raw) return raw, nil } func (e *errorCauseHeader) unmarshal(raw []byte) error { e.code = errorCauseCode(binary.BigEndian.Uint16(raw[0:])) e.len = binary.BigEndian.Uint16(raw[2:]) if e.len < errorCauseHeaderLength || int(e.len) > len(raw) { return ErrInvalidSCTPChunk } valueLength := e.len - errorCauseHeaderLength e.raw = raw[errorCauseHeaderLength : errorCauseHeaderLength+valueLength] return nil } func (e *errorCauseHeader) length() uint16 { return e.len } func (e *errorCauseHeader) errorCauseCode() errorCauseCode { return e.code } // String makes errorCauseHeader printable. func (e errorCauseHeader) String() string { return e.code.String() } sctp-1.9.0/error_cause_invalid_mandatory_parameter.go000066400000000000000000000012221512256410600231550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // errorCauseInvalidMandatoryParameter represents an SCTP error cause. type errorCauseInvalidMandatoryParameter struct { errorCauseHeader } func (e *errorCauseInvalidMandatoryParameter) marshal() ([]byte, error) { return e.errorCauseHeader.marshal() } func (e *errorCauseInvalidMandatoryParameter) unmarshal(raw []byte) error { return e.errorCauseHeader.unmarshal(raw) } // String makes errorCauseInvalidMandatoryParameter printable. func (e *errorCauseInvalidMandatoryParameter) String() string { return e.errorCauseHeader.String() } sctp-1.9.0/error_cause_protocol_violation.go000066400000000000000000000034471512256410600213510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) /* This error cause MAY be included in ABORT chunks that are sent because an SCTP endpoint detects a protocol violation of the peer that is not covered by the error causes described in Section 3.3.10.1 to Section 3.3.10.12. An implementation MAY provide additional information specifying what kind of protocol violation has been detected. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cause Code=13 | Cause Length=Variable | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / Additional Information / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type errorCauseProtocolViolation struct { errorCauseHeader additionalInformation []byte } // Abort chunk errors. var ( ErrProtocolViolationUnmarshal = errors.New("unable to unmarshal Protocol Violation error") ) func (e *errorCauseProtocolViolation) marshal() ([]byte, error) { e.raw = e.additionalInformation return e.errorCauseHeader.marshal() } func (e *errorCauseProtocolViolation) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return fmt.Errorf("%w: %v", ErrProtocolViolationUnmarshal, err) //nolint:errorlint } e.additionalInformation = e.raw return nil } // String makes errorCauseProtocolViolation printable. 
func (e *errorCauseProtocolViolation) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader, e.additionalInformation) } sctp-1.9.0/error_cause_unrecognized_chunk_type.go000066400000000000000000000015001512256410600223350ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // errorCauseUnrecognizedChunkType represents an SCTP error cause. type errorCauseUnrecognizedChunkType struct { errorCauseHeader unrecognizedChunk []byte } func (e *errorCauseUnrecognizedChunkType) marshal() ([]byte, error) { e.code = unrecognizedChunkType e.errorCauseHeader.raw = e.unrecognizedChunk return e.errorCauseHeader.marshal() } func (e *errorCauseUnrecognizedChunkType) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return err } e.unrecognizedChunk = e.errorCauseHeader.raw return nil } // String makes errorCauseUnrecognizedChunkType printable. func (e *errorCauseUnrecognizedChunkType) String() string { return e.errorCauseHeader.String() } sctp-1.9.0/error_cause_user_initiated_abort.go000066400000000000000000000031301512256410600216100ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "fmt" ) /* This error cause MAY be included in ABORT chunks that are sent because of an upper-layer request. The upper layer can specify an Upper Layer Abort Reason that is transported by SCTP transparently and MAY be delivered to the upper-layer protocol at the peer. 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Cause Code=12 | Cause Length=Variable | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ / Upper Layer Abort Reason / \ \ +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type errorCauseUserInitiatedAbort struct { errorCauseHeader upperLayerAbortReason []byte } func (e *errorCauseUserInitiatedAbort) marshal() ([]byte, error) { e.code = userInitiatedAbort e.errorCauseHeader.raw = e.upperLayerAbortReason return e.errorCauseHeader.marshal() } func (e *errorCauseUserInitiatedAbort) unmarshal(raw []byte) error { err := e.errorCauseHeader.unmarshal(raw) if err != nil { return err } e.upperLayerAbortReason = e.errorCauseHeader.raw return nil } // String makes errorCauseUserInitiatedAbort printable. func (e *errorCauseUserInitiatedAbort) String() string { return fmt.Sprintf("%s: %s", e.errorCauseHeader.String(), e.upperLayerAbortReason) } sctp-1.9.0/examples/000077500000000000000000000000001512256410600143125ustar00rootroot00000000000000sctp-1.9.0/examples/ping-pong/000077500000000000000000000000001512256410600162105ustar00rootroot00000000000000sctp-1.9.0/examples/ping-pong/README.md000066400000000000000000000011271512256410600174700ustar00rootroot00000000000000# ping-pong ping-pong is a sctp example that shows how you can send/recv messages. In this example, there are 2 types of peers: **ping** and **pong**. **Ping** will always send `ping ` messages to **pong** and receive `pong ` messages from **pong**. **Pong** will always receive `ping ` from **ping** and send `pong ` messages to **ping**. 
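The core of the **ping** peer is just a handful of pion/sctp calls — dial UDP, create a client association, open a stream, and write. The sketch below is condensed from `ping/main.go` (shown in full further down); the receive loop and most error handling are left out for brevity.

```go
// A trimmed-down ping client; see ping/main.go below for the full version.
package main

import (
	"net"

	"github.com/pion/logging"
	"github.com/pion/sctp"
)

func main() {
	// reach the pong listener
	conn, err := net.Dial("udp", "127.0.0.1:9899")
	if err != nil {
		panic(err)
	}

	// client side of the SCTP association over the UDP conn
	assoc, err := sctp.Client(sctp.Config{
		NetConn:       conn,
		LoggerFactory: logging.NewDefaultLoggerFactory(),
	})
	if err != nil {
		panic(err)
	}

	stream, err := assoc.OpenStream(0, sctp.PayloadTypeWebRTCString)
	if err != nil {
		panic(err)
	}

	// unordered = true, 10ms threshold for dropping packets
	stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, 10)

	_, _ = stream.Write([]byte("ping 0")) // pong replies with "pong 0"
}
```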
## Instruction ### Run ping and pong ### Run pong ```sh go run github.com/pion/sctp/examples/ping-pong/pong@latest ``` ### Run ping ```sh go run github.com/pion/sctp/examples/ping-pong/ping@latest ``` sctp-1.9.0/examples/ping-pong/ping/000077500000000000000000000000001512256410600171455ustar00rootroot00000000000000sctp-1.9.0/examples/ping-pong/ping/conn.go000066400000000000000000000035461512256410600204410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a simple ping-pong example package main import ( "net" "sync" "time" ) // Reference: https://github.com/pion/sctp/blob/master/association_test.go // Since UDP is connectionless, as a server, it doesn't know how to reply // simply using the `Write` method. So, to make it work, `disconnectedPacketConn` // will infer the last packet that it reads as the reply address for `Write` type disconnectedPacketConn struct { // nolint: unused mu sync.RWMutex rAddr net.Addr pConn net.PacketConn } // Read. func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { return 0, err } c.mu.Lock() c.rAddr = rAddr c.mu.Unlock() return i, err } // Write writes len(p) bytes from p to the DTLS connection. func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) { //nolint:unused return c.pConn.WriteTo(p, c.RemoteAddr()) } // Close closes the conn and releases any Read calls. func (c *disconnectedPacketConn) Close() error { //nolint:unused return c.pConn.Close() } // LocalAddr is a stub. func (c *disconnectedPacketConn) LocalAddr() net.Addr { //nolint:unused if c.pConn != nil { return c.pConn.LocalAddr() } return nil } // RemoteAddr is a stub. func (c *disconnectedPacketConn) RemoteAddr() net.Addr { //nolint:unused c.mu.RLock() defer c.mu.RUnlock() return c.rAddr } // SetDeadline is a stub. func (c *disconnectedPacketConn) SetDeadline(time.Time) error { //nolint:unused return nil } // SetReadDeadline is a stub. func (c *disconnectedPacketConn) SetReadDeadline(time.Time) error { //nolint:unused return nil } // SetWriteDeadline is a stub. 
func (c *disconnectedPacketConn) SetWriteDeadline(time.Time) error { //nolint:unused return nil } sctp-1.9.0/examples/ping-pong/ping/main.go000066400000000000000000000030611512256410600204200ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT //go:build !pong // +build !pong package main import ( "fmt" "log" "net" "time" "github.com/pion/logging" "github.com/pion/sctp" ) func main() { //nolint:cyclop conn, err := net.Dial("udp", "127.0.0.1:9899") //nolint: noctx if err != nil { log.Panic(err) } defer func() { if closeErr := conn.Close(); closeErr != nil { log.Panic(err) } }() fmt.Println("dialed udp ponger") config := sctp.Config{ NetConn: conn, LoggerFactory: logging.NewDefaultLoggerFactory(), } a, err := sctp.Client(config) if err != nil { log.Panic(err) } defer func() { if closeErr := a.Close(); closeErr != nil { log.Panic(err) } }() fmt.Println("created a client") stream, err := a.OpenStream(0, sctp.PayloadTypeWebRTCString) if err != nil { log.Panic(err) } defer func() { if closeErr := stream.Close(); closeErr != nil { log.Panic(err) } }() fmt.Println("opened a stream") // set unordered = true and 10ms treshold for dropping packets stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, 10) go func() { var pingSeqNum int for { pingMsg := fmt.Sprintf("ping %d", pingSeqNum) _, err = stream.Write([]byte(pingMsg)) if err != nil { log.Panic(err) } fmt.Println("sent:", pingMsg) pingSeqNum++ time.Sleep(time.Second) } }() for { buff := make([]byte, 1024) _, err = stream.Read(buff) if err != nil { log.Panic(err) } pongMsg := string(buff) fmt.Println("received:", pongMsg) } } sctp-1.9.0/examples/ping-pong/pong/000077500000000000000000000000001512256410600171535ustar00rootroot00000000000000sctp-1.9.0/examples/ping-pong/pong/conn.go000066400000000000000000000035461512256410600204470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package main implements a simple ping-pong example package main import ( "net" "sync" "time" ) // Reference: https://github.com/pion/sctp/blob/master/association_test.go // Since UDP is connectionless, as a server, it doesn't know how to reply // simply using the `Write` method. So, to make it work, `disconnectedPacketConn` // will infer the last packet that it reads as the reply address for `Write` type disconnectedPacketConn struct { // nolint: unused mu sync.RWMutex rAddr net.Addr pConn net.PacketConn } // Read. func (c *disconnectedPacketConn) Read(p []byte) (int, error) { //nolint:unused i, rAddr, err := c.pConn.ReadFrom(p) if err != nil { return 0, err } c.mu.Lock() c.rAddr = rAddr c.mu.Unlock() return i, err } // Write writes len(p) bytes from p to the DTLS connection. func (c *disconnectedPacketConn) Write(p []byte) (n int, err error) { //nolint:unused return c.pConn.WriteTo(p, c.RemoteAddr()) } // Close closes the conn and releases any Read calls. func (c *disconnectedPacketConn) Close() error { //nolint:unused return c.pConn.Close() } // LocalAddr is a stub. func (c *disconnectedPacketConn) LocalAddr() net.Addr { //nolint:unused if c.pConn != nil { return c.pConn.LocalAddr() } return nil } // RemoteAddr is a stub. func (c *disconnectedPacketConn) RemoteAddr() net.Addr { //nolint:unused c.mu.RLock() defer c.mu.RUnlock() return c.rAddr } // SetDeadline is a stub. func (c *disconnectedPacketConn) SetDeadline(time.Time) error { //nolint:unused return nil } // SetReadDeadline is a stub. 
func (c *disconnectedPacketConn) SetReadDeadline(time.Time) error { //nolint:unused return nil } // SetWriteDeadline is a stub. func (c *disconnectedPacketConn) SetWriteDeadline(time.Time) error { //nolint:unused return nil } sctp-1.9.0/examples/ping-pong/pong/main.go000066400000000000000000000032121512256410600204240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package main import ( "fmt" "log" "net" "time" "github.com/pion/logging" "github.com/pion/sctp" ) func main() { //nolint:cyclop addr := net.UDPAddr{ IP: net.IPv4(127, 0, 0, 1), Port: 9899, } conn, err := net.ListenUDP("udp", &addr) if err != nil { log.Panic(err) } defer func() { if closeErr := conn.Close(); closeErr != nil { log.Panic(closeErr) } }() fmt.Println("created a udp listener") config := sctp.Config{ NetConn: &disconnectedPacketConn{pConn: conn}, LoggerFactory: logging.NewDefaultLoggerFactory(), } a, err := sctp.Server(config) if err != nil { log.Panic(err) } defer func() { if closeErr := a.Close(); closeErr != nil { log.Panic(closeErr) } }() defer fmt.Println("created a server") stream, err := a.AcceptStream() if err != nil { log.Panic(err) } defer func() { if closeErr := stream.Close(); closeErr != nil { log.Panic(closeErr) } }() fmt.Println("accepted a stream") // set unordered = true and 10ms treshold for dropping packets stream.SetReliabilityParams(true, sctp.ReliabilityTypeTimed, 10) var pongSeqNum int for { buff := make([]byte, 1024) _, err = stream.Read(buff) if err != nil { log.Panic(err) } pingMsg := string(buff) fmt.Println("received:", pingMsg) _, err = fmt.Sscanf(pingMsg, "ping %d", &pongSeqNum) if err != nil { log.Panic(err) } pongMsg := fmt.Sprintf("pong %d", pongSeqNum) _, err = stream.Write([]byte(pongMsg)) if err != nil { log.Panic(err) } fmt.Println("sent:", pongMsg) time.Sleep(time.Second) } } sctp-1.9.0/go.mod000066400000000000000000000010441512256410600136010ustar00rootroot00000000000000module github.com/pion/sctp require ( github.com/pion/logging v0.2.4 github.com/pion/randutil v0.1.0 github.com/pion/transport/v3 v3.1.1 github.com/stretchr/testify v1.11.1 ) require ( github.com/davecgh/go-spew v1.1.1 // indirect github.com/kr/pretty v0.1.0 // indirect github.com/pmezard/go-difflib v1.0.0 // indirect gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect ) go 1.21 // Retract version with ZeroChecksum misinterpretation (bi-directional/global handling) retract v1.8.12 sctp-1.9.0/go.sum000066400000000000000000000037751512256410600136430ustar00rootroot00000000000000github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/pion/logging v0.2.4 h1:tTew+7cmQ+Mc1pTBLKH2puKsOvhm32dROumOZ655zB8= github.com/pion/logging v0.2.4/go.mod h1:DffhXTKYdNZU+KtJ5pyQDjvOAh/GsNSyv1lbkFbe3so= github.com/pion/randutil v0.1.0 h1:CFG1UdESneORglEsnimhUjf33Rwjubwj6xfiOXBa3mA= github.com/pion/randutil v0.1.0/go.mod h1:XcJrSMMbbMRhASFVOlj/5hQial/Y8oH/HVo7TBZq+j8= github.com/pion/transport/v3 v3.1.1 
h1:Tr684+fnnKlhPceU+ICdrw6KKkTms+5qHMgw6bIkYOM= github.com/pion/transport/v3 v3.1.1/go.mod h1:+c2eewC5WJQHiAA46fkMMzoYZSuGzA/7E2FPrOYHctQ= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= github.com/stretchr/testify v1.11.1 h1:7s2iGBzp5EwR7/aIZr8ao5+dra3wiQyKjjFuvgVKu7U= github.com/stretchr/testify v1.11.1/go.mod h1:wZwfW3scLgRK+23gO65QZefKpKQRnfz6sD981Nm4B6U= golang.org/x/net v0.34.0 h1:Mb7Mrk043xzHgnRM88suvJFwzVrRfHEHJEl5/71CKw0= golang.org/x/net v0.34.0/go.mod h1:di0qlW3YNM5oh6GqDGQr92MyTozJPmybPK4Ev/Gm31k= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15 h1:YR8cESwS4TdDjEe65xsg0ogRM/Nc3DYOhEAlW+xobZo= gopkg.in/check.v1 v1.0.0-20190902080502-41f04d3bba15/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= sctp-1.9.0/packet.go000066400000000000000000000166761512256410600143120ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" "hash/crc32" ) // Create the crc32 table we'll use for the checksum. var castagnoliTable = crc32.MakeTable(crc32.Castagnoli) // nolint:gochecknoglobals // Allocate and zero this data once. // We need to use it for the checksum and don't want to allocate/clear each time. var fourZeroes [4]byte // nolint:gochecknoglobals /* Packet represents an SCTP packet, defined in https://tools.ietf.org/html/rfc4960#section-3 An SCTP packet is composed of a common header and chunks. A chunk contains either control information or user data. SCTP Packet Format 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Common Header | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk #1 | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | ... | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Chunk #n | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ SCTP Common Header Format 0 1 2 3 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Source Value Number | Destination Value Number | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Verification Tag | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ | Checksum | +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ */ type packet struct { sourcePort uint16 destinationPort uint16 verificationTag uint32 chunks []chunk } const ( packetHeaderSize = 12 ) // SCTP packet errors. 
var ( ErrPacketRawTooSmall = errors.New("raw is smaller than the minimum length for a SCTP packet") ErrParseSCTPChunkNotEnoughData = errors.New("unable to parse SCTP chunk, not enough data for complete header") ErrUnmarshalUnknownChunkType = errors.New("failed to unmarshal, contains unknown chunk type") ErrChecksumMismatch = errors.New("checksum mismatch theirs") ) func (p *packet) unmarshal(doChecksum bool, raw []byte) error { //nolint:cyclop if len(raw) < packetHeaderSize { return fmt.Errorf("%w: raw only %d bytes, %d is the minimum length", ErrPacketRawTooSmall, len(raw), packetHeaderSize) } offset := packetHeaderSize // Check if doing CRC32c is required. // Without having SCTP AUTH implemented, this depends only on the type // of the first chunk. if offset+chunkHeaderSize <= len(raw) { switch chunkType(raw[offset]) { case ctInit, ctCookieEcho: doChecksum = true default: } } theirChecksum := binary.LittleEndian.Uint32(raw[8:]) if theirChecksum != 0 || doChecksum { ourChecksum := generatePacketChecksum(raw) if theirChecksum != ourChecksum { return fmt.Errorf("%w: %d ours: %d", ErrChecksumMismatch, theirChecksum, ourChecksum) } } p.sourcePort = binary.BigEndian.Uint16(raw[0:]) p.destinationPort = binary.BigEndian.Uint16(raw[2:]) p.verificationTag = binary.BigEndian.Uint32(raw[4:]) for offset < len(raw) { // guaranteed to be safe by loop condition remaining := raw[offset:] // nolint:gosec // must have at least a full chunk header to continue. if len(remaining) < chunkHeaderSize { return fmt.Errorf("%w: offset %d remaining %d", ErrParseSCTPChunkNotEnoughData, offset, len(remaining)) } ctype := chunkType(remaining[0]) var dataChunk chunk switch ctype { case ctInit: dataChunk = &chunkInit{} case ctInitAck: dataChunk = &chunkInitAck{} case ctAbort: dataChunk = &chunkAbort{} case ctCookieEcho: dataChunk = &chunkCookieEcho{} case ctCookieAck: dataChunk = &chunkCookieAck{} case ctHeartbeat: dataChunk = &chunkHeartbeat{} case ctPayloadData: dataChunk = &chunkPayloadData{} case ctSack: dataChunk = &chunkSelectiveAck{} case ctReconfig: dataChunk = &chunkReconfig{} case ctForwardTSN: dataChunk = &chunkForwardTSN{} case ctError: dataChunk = &chunkError{} case ctShutdown: dataChunk = &chunkShutdown{} case ctShutdownAck: dataChunk = &chunkShutdownAck{} case ctShutdownComplete: dataChunk = &chunkShutdownComplete{} default: return fmt.Errorf("%w: %s", ErrUnmarshalUnknownChunkType, ctype.String()) } if err := dataChunk.unmarshal(remaining); err != nil { return err } p.chunks = append(p.chunks, dataChunk) chunkValuePadding := getPadding(dataChunk.valueLength()) offset += chunkHeaderSize + dataChunk.valueLength() + chunkValuePadding } // if we overshot then should error.
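// If offset ends up past len(raw), the last chunk's self-reported length ran beyond the buffer;
// if it ends up short of len(raw), trailing bytes were left unparsed. Both cases are reported
// below as ErrParseSCTPChunkNotEnoughData, together with the offending offset and length.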
if offset != len(raw) { if offset > len(raw) { overshoot := offset - len(raw) return fmt.Errorf("%w: parsed past end of buffer by %d bytes (offset %d, length %d)", ErrParseSCTPChunkNotEnoughData, overshoot, offset, len(raw)) } remaining := len(raw) - offset return fmt.Errorf("%w: unparsed data remaining: %d bytes (offset %d, length %d)", ErrParseSCTPChunkNotEnoughData, remaining, offset, len(raw)) } return nil } func (p *packet) marshal(doChecksum bool) ([]byte, error) { raw := make([]byte, packetHeaderSize) // Populate static headers // 8-12 is Checksum which will be populated when packet is complete binary.BigEndian.PutUint16(raw[0:], p.sourcePort) binary.BigEndian.PutUint16(raw[2:], p.destinationPort) binary.BigEndian.PutUint32(raw[4:], p.verificationTag) // Populate chunks for _, c := range p.chunks { chunkRaw, err := c.marshal() if err != nil { return nil, err } raw = append(raw, chunkRaw...) //nolint:makezero // todo:fix paddingNeeded := getPadding(len(raw)) if paddingNeeded != 0 { raw = append(raw, make([]byte, paddingNeeded)...) //nolint:makezero // todo:fix } } if doChecksum { // golang CRC32C uses reflected input and reflected output, the // net result of this is to have the bytes flipped compared to // the non reflected variant that the spec expects. // // Use LittleEndian.PutUint32 to avoid flipping the bytes in to // the spec compliant checksum order binary.LittleEndian.PutUint32(raw[8:], generatePacketChecksum(raw)) } return raw, nil } func generatePacketChecksum(raw []byte) (sum uint32) { // Fastest way to do a crc32 without allocating. sum = crc32.Update(sum, castagnoliTable, raw[0:8]) sum = crc32.Update(sum, castagnoliTable, fourZeroes[:]) sum = crc32.Update(sum, castagnoliTable, raw[12:]) return sum } // String makes packet printable. func (p *packet) String() string { format := `Packet: sourcePort: %d destinationPort: %d verificationTag: %d ` res := fmt.Sprintf(format, p.sourcePort, p.destinationPort, p.verificationTag, ) for i, chunk := range p.chunks { res += fmt.Sprintf("Chunk %d:\n %s", i, chunk) } return res } // TryMarshalUnmarshal attempts to marshal and unmarshal a message. Added for fuzzing. 
func TryMarshalUnmarshal(msg []byte) int { p := &packet{} err := p.unmarshal(false, msg) if err != nil { return 0 } _, err = p.marshal(false) if err != nil { return 0 } return 1 } sctp-1.9.0/packet_test.go000066400000000000000000000047521512256410600153410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestPacketUnmarshal(t *testing.T) { pkt := &packet{} assert.Error(t, pkt.unmarshal(true, []byte{}), "Unmarshal should fail when a packet is too small to be SCTP") headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} err := pkt.unmarshal(true, headerOnly) assert.NoError(t, err, "Unmarshal failed for SCTP packet with no chunks") assert.Equal(t, uint16(defaultSCTPSrcDstPort), pkt.sourcePort, "Unmarshal passed for SCTP packet, but got incorrect source port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.sourcePort) assert.Equal(t, uint16(defaultSCTPSrcDstPort), pkt.destinationPort, "Unmarshal passed for SCTP packet, but got incorrect destination port exp: %d act: %d", defaultSCTPSrcDstPort, pkt.destinationPort) assert.Equal(t, uint32(0), pkt.verificationTag, "Unmarshal passed for SCTP packet, but got incorrect verification tag exp: %d act: %d", 0, pkt.verificationTag) rawChunk := []byte{ 0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x81, 0x46, 0x9d, 0xfc, 0x01, 0x00, 0x00, 0x56, 0x55, 0xb9, 0x64, 0xa5, 0x00, 0x02, 0x00, 0x00, 0x04, 0x00, 0x08, 0x00, 0xe8, 0x6d, 0x10, 0x30, 0xc0, 0x00, 0x00, 0x04, 0x80, 0x08, 0x00, 0x09, 0xc0, 0x0f, 0xc1, 0x80, 0x82, 0x00, 0x00, 0x00, 0x80, 0x02, 0x00, 0x24, 0x9f, 0xeb, 0xbb, 0x5c, 0x50, 0xc9, 0xbf, 0x75, 0x9c, 0xb1, 0x2c, 0x57, 0x4f, 0xa4, 0x5a, 0x51, 0xba, 0x60, 0x17, 0x78, 0x27, 0x94, 0x5c, 0x31, 0xe6, 0x5d, 0x5b, 0x09, 0x47, 0xe2, 0x22, 0x06, 0x80, 0x04, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x80, 0x03, 0x00, 0x06, 0x80, 0xc1, 0x00, 0x00, } assert.NoError(t, pkt.unmarshal(true, rawChunk)) } func TestPacketMarshal(t *testing.T) { pkt := &packet{} headerOnly := []byte{0x13, 0x88, 0x13, 0x88, 0x00, 0x00, 0x00, 0x00, 0x06, 0xa9, 0x00, 0xe1} assert.NoError(t, pkt.unmarshal(true, headerOnly), "Unmarshal failed for SCTP packet with no chunks") headerOnlyMarshaled, err := pkt.marshal(true) if assert.NoError(t, err, "Marshal failed for SCTP packet with no chunks") { assert.Equal(t, headerOnly, headerOnlyMarshaled, "Unmarshal/Marshaled header only packet did not match \nheaderOnly: % 02x \nheaderOnlyMarshaled % 02x", headerOnly, headerOnlyMarshaled) } } func BenchmarkPacketGenerateChecksum(b *testing.B) { var data [1024]byte for i := 0; i < b.N; i++ { _ = generatePacketChecksum(data[:]) } } sctp-1.9.0/param.go000066400000000000000000000025341512256410600141270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" ) type param interface { marshal() ([]byte, error) length() int } // ErrParamTypeUnhandled is returned if unknown parameter type is specified. 
var ErrParamTypeUnhandled = errors.New("unhandled ParamType") func buildParam(typeParam paramType, rawParam []byte) (param, error) { //nolint:cyclop switch typeParam { case forwardTSNSupp: return (¶mForwardTSNSupported{}).unmarshal(rawParam) case supportedExt: return (¶mSupportedExtensions{}).unmarshal(rawParam) case ecnCapable: return (¶mECNCapable{}).unmarshal(rawParam) case random: return (¶mRandom{}).unmarshal(rawParam) case reqHMACAlgo: return (¶mRequestedHMACAlgorithm{}).unmarshal(rawParam) case chunkList: return (¶mChunkList{}).unmarshal(rawParam) case stateCookie: return (¶mStateCookie{}).unmarshal(rawParam) case heartbeatInfo: return (¶mHeartbeatInfo{}).unmarshal(rawParam) case outSSNResetReq: return (¶mOutgoingResetRequest{}).unmarshal(rawParam) case reconfigResp: return (¶mReconfigResponse{}).unmarshal(rawParam) case zeroChecksumAcceptable: return (¶mZeroChecksumAcceptable{}).unmarshal(rawParam) default: return nil, fmt.Errorf("%w: %v", ErrParamTypeUnhandled, typeParam) } } sctp-1.9.0/param_chunk_list.go000066400000000000000000000011671512256410600163530ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type paramChunkList struct { paramHeader chunkTypes []chunkType } func (c *paramChunkList) marshal() ([]byte, error) { c.typ = chunkList c.raw = make([]byte, len(c.chunkTypes)) for i, t := range c.chunkTypes { c.raw[i] = byte(t) } return c.paramHeader.marshal() } func (c *paramChunkList) unmarshal(raw []byte) (param, error) { err := c.paramHeader.unmarshal(raw) if err != nil { return nil, err } for _, t := range c.raw { c.chunkTypes = append(c.chunkTypes, chunkType(t)) } return c, nil } sctp-1.9.0/param_ecn_capable.go000066400000000000000000000007001512256410600164140ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type paramECNCapable struct { paramHeader } func (r *paramECNCapable) marshal() ([]byte, error) { r.typ = ecnCapable r.raw = []byte{} return r.paramHeader.marshal() } func (r *paramECNCapable) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } return r, nil } sctp-1.9.0/param_ecn_capable_test.go000066400000000000000000000024141512256410600174570ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "testing" "github.com/stretchr/testify/assert" ) func testParamECNCapabale() []byte { return []byte{0x80, 0x0, 0x0, 0x4} } func TestParamECNCapabale_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramECNCapable }{ { testParamECNCapabale(), ¶mECNCapable{ paramHeader: paramHeader{ typ: ecnCapable, unrecognizedAction: paramHeaderUnrecognizedActionSkip, len: 4, raw: []byte{}, }, }, }, } for i, tc := range tt { actual := ¶mECNCapable{} _, err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.binary, b) } } func TestParamECNCapabale_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"param too short", []byte{0x0, 0xd, 0x0}}, } for i, tc := range tt { actual := ¶mECNCapable{} _, err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } 
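// The parameter types in this package all follow the same pattern: a paramHeader plus a
// type-specific payload, constructed through the buildParam factory in param.go. As a rough
// sketch (assuming it lives inside the sctp package, since parseParamType and buildParam are
// unexported), a raw parameter TLV can be round-tripped like this — essentially what
// TestBuildParam_Success in param_test.go, further down in this archive, exercises:
//
//	func roundTripParam(raw []byte) ([]byte, error) {
//		pType, err := parseParamType(raw) // read the 16-bit parameter type
//		if err != nil {
//			return nil, err
//		}
//		p, err := buildParam(pType, raw) // dispatch to the concrete param and unmarshal
//		if err != nil {
//			return nil, err
//		}
//		return p.marshal() // re-encode; should reproduce the original bytes
//	}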
sctp-1.9.0/param_forward_tsn_supported.go000066400000000000000000000016761512256410600206520ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // At the initialization of the association, the sender of the INIT or // INIT ACK chunk MAY include this OPTIONAL parameter to inform its peer // that it is able to support the Forward TSN chunk // // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 49152 | Parameter Length = 4 | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramForwardTSNSupported struct { paramHeader } func (f *paramForwardTSNSupported) marshal() ([]byte, error) { f.typ = forwardTSNSupp f.raw = []byte{} return f.paramHeader.marshal() } func (f *paramForwardTSNSupported) unmarshal(raw []byte) (param, error) { err := f.paramHeader.unmarshal(raw) if err != nil { return nil, err } return f, nil } sctp-1.9.0/param_forward_tsn_supported_test.go000066400000000000000000000025351512256410600217040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp // nolint:dupl import ( "testing" "github.com/stretchr/testify/assert" ) func testParamForwardTSNSupported() []byte { return []byte{0xc0, 0x0, 0x0, 0x4} } func TestParamForwardTSNSupported_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramForwardTSNSupported }{ { testParamForwardTSNSupported(), ¶mForwardTSNSupported{ paramHeader: paramHeader{ typ: forwardTSNSupp, len: 4, unrecognizedAction: paramHeaderUnrecognizedActionSkipAndReport, raw: []byte{}, }, }, }, } for i, tc := range tt { actual := ¶mForwardTSNSupported{} _, err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.binary, b) } } func TestParamForwardTSNSupported_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"param too short", []byte{0x0, 0xd, 0x0}}, } for i, tc := range tt { actual := ¶mForwardTSNSupported{} _, err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/param_heartbeat_info.go000066400000000000000000000010271512256410600171550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type paramHeartbeatInfo struct { paramHeader heartbeatInformation []byte } func (h *paramHeartbeatInfo) marshal() ([]byte, error) { h.typ = heartbeatInfo h.raw = h.heartbeatInformation return h.paramHeader.marshal() } func (h *paramHeartbeatInfo) unmarshal(raw []byte) (param, error) { err := h.paramHeader.unmarshal(raw) if err != nil { return nil, err } h.heartbeatInformation = h.raw return h, nil } sctp-1.9.0/param_outgoing_reset_request.go000066400000000000000000000077161512256410600210230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" ) const ( paramOutgoingResetRequestStreamIdentifiersOffset = 12 ) // This parameter is used by the sender to request the reset of some or // all outgoing streams. 
// 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 13 | Parameter Length = 16 + 2 * N | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Request Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Response Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Sender's Last Assigned TSN | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream Number 1 (optional) | Stream Number 2 (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // / ...... / // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Stream Number N-1 (optional) | Stream Number N (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramOutgoingResetRequest struct { paramHeader // reconfigRequestSequenceNumber is used to identify the request. It is a monotonically // increasing number that is initialized to the same value as the // initial TSN. It is increased by 1 whenever sending a new Re- // configuration Request Parameter. reconfigRequestSequenceNumber uint32 // When this Outgoing SSN Reset Request Parameter is sent in response // to an Incoming SSN Reset Request Parameter, this parameter is also // an implicit response to the incoming request. This field then // holds the Re-configuration Request Sequence Number of the incoming // request. In other cases, it holds the next expected // Re-configuration Request Sequence Number minus 1. reconfigResponseSequenceNumber uint32 // This value holds the next TSN minus 1 -- in other words, the last // TSN that this sender assigned. senderLastTSN uint32 // This optional field, if included, is used to indicate specific // streams that are to be reset. If no streams are listed, then all // streams are to be reset. streamIdentifiers []uint16 } // Outgoing reset request parameter errors. 
var ( ErrSSNResetRequestParamTooShort = errors.New("outgoing SSN reset request parameter too short") ) func (r *paramOutgoingResetRequest) marshal() ([]byte, error) { r.typ = outSSNResetReq r.raw = make([]byte, paramOutgoingResetRequestStreamIdentifiersOffset+2*len(r.streamIdentifiers)) binary.BigEndian.PutUint32(r.raw, r.reconfigRequestSequenceNumber) binary.BigEndian.PutUint32(r.raw[4:], r.reconfigResponseSequenceNumber) binary.BigEndian.PutUint32(r.raw[8:], r.senderLastTSN) for i, sID := range r.streamIdentifiers { binary.BigEndian.PutUint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:], sID) } return r.paramHeader.marshal() } func (r *paramOutgoingResetRequest) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw) < paramOutgoingResetRequestStreamIdentifiersOffset { return nil, ErrSSNResetRequestParamTooShort } r.reconfigRequestSequenceNumber = binary.BigEndian.Uint32(r.raw) r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw[4:]) r.senderLastTSN = binary.BigEndian.Uint32(r.raw[8:]) lim := (len(r.raw) - paramOutgoingResetRequestStreamIdentifiersOffset) / 2 r.streamIdentifiers = make([]uint16, lim) for i := 0; i < lim; i++ { r.streamIdentifiers[i] = binary.BigEndian.Uint16(r.raw[paramOutgoingResetRequestStreamIdentifiersOffset+2*i:]) } return r, nil } sctp-1.9.0/param_outgoing_reset_request_test.go000066400000000000000000000040751512256410600220550ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkReconfigParamA() []byte { return []byte{ 0x00, 0x0d, 0x00, 0x16, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x02, 0x00, 0x00, 0x00, 0x03, 0x00, 0x04, 0x00, 0x05, 0x00, 0x06, } } func testChunkReconfigParamB() []byte { return []byte{0x0, 0xd, 0x0, 0x10, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x2, 0x0, 0x0, 0x0, 0x3} } func TestParamOutgoingResetRequest_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramOutgoingResetRequest }{ { testChunkReconfigParamA(), ¶mOutgoingResetRequest{ paramHeader: paramHeader{ typ: outSSNResetReq, len: 22, raw: testChunkReconfigParamA()[4:], }, reconfigRequestSequenceNumber: 1, reconfigResponseSequenceNumber: 2, senderLastTSN: 3, streamIdentifiers: []uint16{4, 5, 6}, }, }, { testChunkReconfigParamB(), ¶mOutgoingResetRequest{ paramHeader: paramHeader{ typ: outSSNResetReq, len: 16, raw: testChunkReconfigParamB()[4:], }, reconfigRequestSequenceNumber: 1, reconfigResponseSequenceNumber: 2, senderLastTSN: 3, streamIdentifiers: []uint16{}, }, }, } for i, tc := range tt { actual := ¶mOutgoingResetRequest{} _, err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoErrorf(t, err, "failed to marshal #%d", i) assert.Equal(t, tc.binary, b) } } func TestParamOutgoingResetRequest_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"packet too short", testChunkReconfigParamA()[:8]}, {"param too short", []byte{0x0, 0xd, 0x0, 0x4}}, } for i, tc := range tt { actual := ¶mOutgoingResetRequest{} _, err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/param_random.go000066400000000000000000000007351512256410600154700ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type 
paramRandom struct { paramHeader randomData []byte } func (r *paramRandom) marshal() ([]byte, error) { r.typ = random r.raw = r.randomData return r.paramHeader.marshal() } func (r *paramRandom) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } r.randomData = r.raw return r, nil } sctp-1.9.0/param_reconfig_response.go000066400000000000000000000065031512256410600177210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) // This parameter is used by the receiver of a Re-configuration Request // Parameter to respond to the request. // // 0 1 2 3 // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Parameter Type = 16 | Parameter Length | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Re-configuration Response Sequence Number | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Result | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Sender's Next TSN (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Receiver's Next TSN (optional) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramReconfigResponse struct { paramHeader // This value is copied from the request parameter and is used by the // receiver of the Re-configuration Response Parameter to tie the // response to the request. reconfigResponseSequenceNumber uint32 // This value describes the result of the processing of the request. result reconfigResult } type reconfigResult uint32 const ( reconfigResultSuccessNOP reconfigResult = 0 reconfigResultSuccessPerformed reconfigResult = 1 reconfigResultDenied reconfigResult = 2 reconfigResultErrorWrongSSN reconfigResult = 3 reconfigResultErrorRequestAlreadyInProgress reconfigResult = 4 reconfigResultErrorBadSequenceNumber reconfigResult = 5 reconfigResultInProgress reconfigResult = 6 ) // Reconfiguration response errors. 
var ( ErrReconfigRespParamTooShort = errors.New("reconfig response parameter too short") ) func (t reconfigResult) String() string { switch t { case reconfigResultSuccessNOP: return "0: Success - Nothing to do" case reconfigResultSuccessPerformed: return "1: Success - Performed" case reconfigResultDenied: return "2: Denied" case reconfigResultErrorWrongSSN: return "3: Error - Wrong SSN" case reconfigResultErrorRequestAlreadyInProgress: return "4: Error - Request already in progress" case reconfigResultErrorBadSequenceNumber: return "5: Error - Bad Sequence Number" case reconfigResultInProgress: return "6: In progress" default: return fmt.Sprintf("Unknown reconfigResult: %d", t) } } func (r *paramReconfigResponse) marshal() ([]byte, error) { r.typ = reconfigResp r.raw = make([]byte, 8) binary.BigEndian.PutUint32(r.raw, r.reconfigResponseSequenceNumber) binary.BigEndian.PutUint32(r.raw[4:], uint32(r.result)) return r.paramHeader.marshal() } func (r *paramReconfigResponse) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw) < 8 { return nil, ErrReconfigRespParamTooShort } r.reconfigResponseSequenceNumber = binary.BigEndian.Uint32(r.raw) r.result = reconfigResult(binary.BigEndian.Uint32(r.raw[4:])) return r, nil } sctp-1.9.0/param_reconfig_response_test.go000066400000000000000000000041201512256410600207510ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testChunkReconfigResponce() []byte { return []byte{0x0, 0x10, 0x0, 0xc, 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x1} } func TestParamReconfigResponse_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramReconfigResponse }{ { testChunkReconfigResponce(), ¶mReconfigResponse{ paramHeader: paramHeader{ typ: reconfigResp, len: 12, raw: testChunkReconfigResponce()[4:], }, reconfigResponseSequenceNumber: 1, result: reconfigResultSuccessPerformed, }, }, } for i, tc := range tt { actual := ¶mReconfigResponse{} _, err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d: %v", i) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoErrorf(t, err, "failed to marshal #%d: %v", i) assert.Equal(t, tc.binary, b) } } func TestParamReconfigResponse_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"packet too short", testChunkReconfigParamA()[:8]}, {"param too short", []byte{0x0, 0x10, 0x0, 0x4}}, } for i, tc := range tt { actual := ¶mReconfigResponse{} _, err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } func TestReconfigResultStringer(t *testing.T) { tt := []struct { result reconfigResult expected string }{ {reconfigResultSuccessNOP, "0: Success - Nothing to do"}, {reconfigResultSuccessPerformed, "1: Success - Performed"}, {reconfigResultDenied, "2: Denied"}, {reconfigResultErrorWrongSSN, "3: Error - Wrong SSN"}, {reconfigResultErrorRequestAlreadyInProgress, "4: Error - Request already in progress"}, {reconfigResultErrorBadSequenceNumber, "5: Error - Bad Sequence Number"}, {reconfigResultInProgress, "6: In progress"}, } for i, tc := range tt { actual := tc.result.String() assert.Equalf(t, tc.expected, actual, "Test case %d", i) } } sctp-1.9.0/param_requested_hmac_algorithm.go000066400000000000000000000034531512256410600212470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // 
SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) type hmacAlgorithm uint16 const ( hmacResv1 hmacAlgorithm = 0 hmacSHA128 hmacAlgorithm = 1 hmacResv2 hmacAlgorithm = 2 hmacSHA256 hmacAlgorithm = 3 ) // ErrInvalidAlgorithmType is returned if unknown auth algorithm is specified. var ErrInvalidAlgorithmType = errors.New("invalid algorithm type") // ErrInvalidChunkLength is returned if the chunk length is invalid. var ErrInvalidChunkLength = errors.New("invalid chunk length") func (c hmacAlgorithm) String() string { switch c { case hmacResv1: return "HMAC Reserved (0x00)" case hmacSHA128: return "HMAC SHA-128" case hmacResv2: return "HMAC Reserved (0x02)" case hmacSHA256: return "HMAC SHA-256" default: return fmt.Sprintf("Unknown HMAC Algorithm type: %d", c) } } type paramRequestedHMACAlgorithm struct { paramHeader availableAlgorithms []hmacAlgorithm } func (r *paramRequestedHMACAlgorithm) marshal() ([]byte, error) { r.typ = reqHMACAlgo r.raw = make([]byte, len(r.availableAlgorithms)*2) i := 0 for _, a := range r.availableAlgorithms { binary.BigEndian.PutUint16(r.raw[i:], uint16(a)) i += 2 } return r.paramHeader.marshal() } func (r *paramRequestedHMACAlgorithm) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw)%2 == 1 { return nil, ErrInvalidChunkLength } i := 0 for i < len(r.raw) { a := hmacAlgorithm(binary.BigEndian.Uint16(r.raw[i:])) switch a { case hmacSHA128: fallthrough case hmacSHA256: r.availableAlgorithms = append(r.availableAlgorithms, a) default: return nil, fmt.Errorf("%w: %v", ErrInvalidAlgorithmType, a) } i += 2 } return r, nil } sctp-1.9.0/param_state_cookie.go000066400000000000000000000017041512256410600166560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "crypto/rand" "fmt" ) type paramStateCookie struct { paramHeader cookie []byte } func newRandomStateCookie() (*paramStateCookie, error) { randCookie := make([]byte, 32) _, err := rand.Read(randCookie) // crypto/rand.Read returns n == len(b) if and only if err == nil. if err != nil { return nil, err } s := ¶mStateCookie{ cookie: randCookie, } return s, nil } func (s *paramStateCookie) marshal() ([]byte, error) { s.typ = stateCookie s.raw = s.cookie return s.paramHeader.marshal() } func (s *paramStateCookie) unmarshal(raw []byte) (param, error) { err := s.paramHeader.unmarshal(raw) if err != nil { return nil, err } s.cookie = s.raw return s, nil } // String makes paramStateCookie printable. 
func (s *paramStateCookie) String() string { return fmt.Sprintf("%s: %s", s.paramHeader, s.cookie) } sctp-1.9.0/param_supported_extensions.go000066400000000000000000000012311512256410600205040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type paramSupportedExtensions struct { paramHeader ChunkTypes []chunkType } func (s *paramSupportedExtensions) marshal() ([]byte, error) { s.typ = supportedExt s.raw = make([]byte, len(s.ChunkTypes)) for i, c := range s.ChunkTypes { s.raw[i] = byte(c) } return s.paramHeader.marshal() } func (s *paramSupportedExtensions) unmarshal(raw []byte) (param, error) { err := s.paramHeader.unmarshal(raw) if err != nil { return nil, err } for _, t := range s.raw { s.ChunkTypes = append(s.ChunkTypes, chunkType(t)) } return s, nil } sctp-1.9.0/param_test.go000066400000000000000000000021031512256410600151560ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestBuildParam_Success(t *testing.T) { tt := []struct { binary []byte }{ {testChunkReconfigParamA()}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) assert.NoErrorf(t, err, "failed to parse param type #%d", i) p, err := buildParam(pType, tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) b, err := p.marshal() assert.NoErrorf(t, err, "failed to marshal #%d", i) assert.Equal(t, tc.binary, b) } } func TestBuildParam_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"invalid ParamType", []byte{0x0, 0x0}}, {"build failure", testChunkReconfigParamA()[:8]}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) assert.NoErrorf(t, err, "failed to parse param type #%d", i) _, err = buildParam(pType, tc.binary) assert.Errorf(t, err, "expected buildParam #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/param_zero_checksum.go000066400000000000000000000031441512256410600170460ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" ) // This parameter is used to inform the receiver that a sender is willing to // accept zero as checksum if some other error detection method is used // instead. // // 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type = 0x8001 (suggested) | Length = 8 | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Error Detection Method Identifier (EDMID) | // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ type paramZeroChecksumAcceptable struct { paramHeader // The Error Detection Method Identifier (EDMID) specifies an alternate // error detection method the sender of this parameter is willing to use for // received packets. edmid uint32 } // Zero Checksum parameter error. 
var ( ErrZeroChecksumParamTooShort = errors.New("zero checksum parameter too short") ) const ( dtlsErrorDetectionMethod uint32 = 1 ) func (r *paramZeroChecksumAcceptable) marshal() ([]byte, error) { r.typ = zeroChecksumAcceptable r.raw = make([]byte, 4) binary.BigEndian.PutUint32(r.raw, r.edmid) return r.paramHeader.marshal() } func (r *paramZeroChecksumAcceptable) unmarshal(raw []byte) (param, error) { err := r.paramHeader.unmarshal(raw) if err != nil { return nil, err } if len(r.raw) < 4 { return nil, ErrZeroChecksumParamTooShort } r.edmid = binary.BigEndian.Uint32(r.raw) return r, nil } sctp-1.9.0/param_zero_checksum_test.go000066400000000000000000000017501512256410600201060ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "fmt" "testing" "github.com/stretchr/testify/assert" ) func TestParamZeroChecksum(t *testing.T) { tt := []struct { binary []byte parsed *paramZeroChecksumAcceptable }{ { binary: []byte{0x80, 0x01, 0x00, 0x08, 0x00, 0x00, 0x00, 0x01}, parsed: ¶mZeroChecksumAcceptable{ paramHeader: paramHeader{ typ: zeroChecksumAcceptable, unrecognizedAction: paramHeaderUnrecognizedActionSkip, len: 8, raw: []byte{0x00, 0x00, 0x00, 0x01}, }, edmid: 1, }, }, } for i, tc := range tt { t.Run(fmt.Sprintf("%v", i), func(t *testing.T) { actual := ¶mZeroChecksumAcceptable{} _, err := actual.unmarshal(tc.binary) assert.NoError(t, err) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoError(t, err) assert.Equal(t, tc.binary, b) }) } } sctp-1.9.0/paramheader.go000066400000000000000000000067531512256410600153070ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "encoding/hex" "errors" "fmt" ) type paramHeaderUnrecognizedAction byte type paramHeader struct { typ paramType unrecognizedAction paramHeaderUnrecognizedAction len int raw []byte } /* The Parameter Types are encoded such that the highest-order 2 bits specify the action that is taken if the processing endpoint does not recognize the Parameter Type. 00 - Stop processing this parameter and do not process any further parameters within this chunk. 01 - Stop processing this parameter, do not process any further parameters within this chunk, and report the unrecognized parameter, as described in Section 3.2.2. 10 - Skip this parameter and continue processing. 11 - Skip this parameter and continue processing, but report the unrecognized parameter, as described in Section 3.2.2. https://www.rfc-editor.org/rfc/rfc9260.html#section-3.2.1 */ const ( paramHeaderUnrecognizedActionMask = 0b11000000 paramHeaderUnrecognizedActionStop paramHeaderUnrecognizedAction = 0b00000000 paramHeaderUnrecognizedActionStopAndReport paramHeaderUnrecognizedAction = 0b01000000 paramHeaderUnrecognizedActionSkip paramHeaderUnrecognizedAction = 0b10000000 paramHeaderUnrecognizedActionSkipAndReport paramHeaderUnrecognizedAction = 0b11000000 paramHeaderLength = 4 ) // Parameter header parse errors. 
var ( ErrParamHeaderTooShort = errors.New("param header too short") ErrParamHeaderSelfReportedLengthShorter = errors.New("param self reported length is shorter than header length") ErrParamHeaderSelfReportedLengthLonger = errors.New("param self reported length is longer than header length") ErrParamHeaderParseFailed = errors.New("failed to parse param type") ) func (p *paramHeader) marshal() ([]byte, error) { paramLengthPlusHeader := paramHeaderLength + len(p.raw) rawParam := make([]byte, paramLengthPlusHeader) binary.BigEndian.PutUint16(rawParam[0:], uint16(p.typ)) binary.BigEndian.PutUint16(rawParam[2:], uint16(paramLengthPlusHeader)) //nolint:gosec // G115 copy(rawParam[paramHeaderLength:], p.raw) return rawParam, nil } func (p *paramHeader) unmarshal(raw []byte) error { if len(raw) < paramHeaderLength { return ErrParamHeaderTooShort } paramLengthPlusHeader := binary.BigEndian.Uint16(raw[2:]) if int(paramLengthPlusHeader) < paramHeaderLength { return fmt.Errorf( "%w: param self reported length (%d) shorter than header length (%d)", ErrParamHeaderSelfReportedLengthShorter, int(paramLengthPlusHeader), paramHeaderLength, ) } if len(raw) < int(paramLengthPlusHeader) { return fmt.Errorf( "%w: param length (%d) shorter than its self reported length (%d)", ErrParamHeaderSelfReportedLengthLonger, len(raw), int(paramLengthPlusHeader), ) } typ, err := parseParamType(raw[0:]) if err != nil { return fmt.Errorf("%w: %v", ErrParamHeaderParseFailed, err) //nolint:errorlint } p.typ = typ p.unrecognizedAction = paramHeaderUnrecognizedAction(raw[0] & paramHeaderUnrecognizedActionMask) p.raw = raw[paramHeaderLength:paramLengthPlusHeader] p.len = int(paramLengthPlusHeader) return nil } func (p *paramHeader) length() int { return p.len } // String makes paramHeader printable. func (p paramHeader) String() string { return fmt.Sprintf("%s (%d): %s", p.typ, p.len, hex.Dump(p.raw)) } sctp-1.9.0/paramheader_test.go000066400000000000000000000024621512256410600163370ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func testParamHeader() []byte { return []byte{0x0, 0x1, 0x0, 0x4} } func TestParamHeader_Success(t *testing.T) { tt := []struct { binary []byte parsed *paramHeader }{ { testParamHeader(), ¶mHeader{ typ: heartbeatInfo, len: 4, raw: []byte{}, }, }, } for i, tc := range tt { actual := ¶mHeader{} err := actual.unmarshal(tc.binary) assert.NoErrorf(t, err, "failed to unmarshal #%d", i) assert.Equal(t, tc.parsed, actual) b, err := actual.marshal() assert.NoErrorf(t, err, "failed to marshal #%d", i) assert.Equal(t, tc.binary, b) } } func TestParamHeaderUnmarshal_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"header too short", testParamHeader()[:2]}, // {"wrong param type", []byte{0x0, 0x0, 0x0, 0x4}}, // Not possible to fail parseParamType atm. {"reported length below header length", []byte{0x0, 0xd, 0x0, 0x3}}, {"wrong reported length", testChunkReconfigParamA()[:4]}, } for i, tc := range tt { actual := ¶mHeader{} err := actual.unmarshal(tc.binary) assert.Errorf(t, err, "expected unmarshal #%d: '%s' to fail.", i, tc.name) } } sctp-1.9.0/paramtype.go000066400000000000000000000105011512256410600150220ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "encoding/binary" "errors" "fmt" ) // paramType represents a SCTP INIT/INITACK parameter. 
type paramType uint16 const ( heartbeatInfo paramType = 1 // Heartbeat Info [RFC9260] ipV4Addr paramType = 5 // IPv4 IP [RFC9260] ipV6Addr paramType = 6 // IPv6 IP [RFC9260] stateCookie paramType = 7 // State Cookie [RFC9260] unrecognizedParam paramType = 8 // Unrecognized Parameters [RFC9260] cookiePreservative paramType = 9 // Cookie Preservative [RFC9260] hostNameAddr paramType = 11 // Host Name Address [RFC9260] supportedAddrTypes paramType = 12 // Supported IP Types [RFC9260] outSSNResetReq paramType = 13 // Outgoing SSN Reset Request Parameter [RFC6525] incSSNResetReq paramType = 14 // Incoming SSN Reset Request Parameter [RFC6525] ssnTSNResetReq paramType = 15 // SSN/TSN Reset Request Parameter [RFC6525] reconfigResp paramType = 16 // Re-configuration Response Parameter [RFC6525] addOutStreamsReq paramType = 17 // Add Outgoing Streams Request Parameter [RFC6525] addIncStreamsReq paramType = 18 // Add Incoming Streams Request Parameter [RFC6525] ecnCapable paramType = 32768 // ECN Capable (0x8000) [RFC2960] zeroChecksumAcceptable paramType = 32769 // Zero Checksum Acceptable [draft-ietf-tsvwg-sctp-zero-checksum-00] random paramType = 32770 // Random (0x8002) [RFC4895] chunkList paramType = 32771 // Chunk List (0x8003) [RFC4895] reqHMACAlgo paramType = 32772 // Requested HMAC Algorithm Parameter (0x8004) [RFC4895] padding paramType = 32773 // Padding (0x8005) supportedExt paramType = 32776 // Supported Extensions (0x8008) [RFC5061] forwardTSNSupp paramType = 49152 // Forward TSN supported (0xC000) [RFC3758] addIPAddr paramType = 49153 // Add IP Address (0xC001) [RFC5061] delIPAddr paramType = 49154 // Delete IP Address (0xC002) [RFC5061] errClauseInd paramType = 49155 // Error Cause Indication (0xC003) [RFC5061] setPriAddr paramType = 49156 // Set Primary IP (0xC004) [RFC5061] successInd paramType = 49157 // Success Indication (0xC005) [RFC5061] adaptLayerInd paramType = 49158 // Adaptation Layer Indication (0xC006) [RFC5061] ) // Parameter packet errors. 
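// parseParamType below only inspects the first two bytes, so it works on a
// full parameter buffer as well as on a bare header. A minimal sketch
// (bytes illustrative only):
//
//	pt, err := parseParamType([]byte{0x80, 0x01})
//	// err == nil, pt == zeroChecksumAcceptable (0x8001)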
var ( ErrParamPacketTooShort = errors.New("packet too short") ) func parseParamType(raw []byte) (paramType, error) { if len(raw) < 2 { return paramType(0), ErrParamPacketTooShort } return paramType(binary.BigEndian.Uint16(raw)), nil } func (p paramType) String() string { //nolint:cyclop switch p { case heartbeatInfo: return "Heartbeat Info" case ipV4Addr: return "IPv4 IP" case ipV6Addr: return "IPv6 IP" case stateCookie: return "State Cookie" case unrecognizedParam: return "Unrecognized Parameters" case cookiePreservative: return "Cookie Preservative" case hostNameAddr: return "Host Name Address" case supportedAddrTypes: return "Supported IP Types" case outSSNResetReq: return "Outgoing SSN Reset Request Parameter" case incSSNResetReq: return "Incoming SSN Reset Request Parameter" case ssnTSNResetReq: return "SSN/TSN Reset Request Parameter" case reconfigResp: return "Re-configuration Response Parameter" case addOutStreamsReq: return "Add Outgoing Streams Request Parameter" case addIncStreamsReq: return "Add Incoming Streams Request Parameter" case ecnCapable: return "ECN Capable" case zeroChecksumAcceptable: return "Zero Checksum Acceptable" case random: return "Random" case chunkList: return "Chunk List" case reqHMACAlgo: return "Requested HMAC Algorithm Parameter" case padding: return "Padding" case supportedExt: return "Supported Extensions" case forwardTSNSupp: return "Forward TSN supported" case addIPAddr: return "Add IP Address" case delIPAddr: return "Delete IP Address" case errClauseInd: return "Error Cause Indication" case setPriAddr: return "Set Primary IP" case successInd: return "Success Indication" case adaptLayerInd: return "Adaptation Layer Indication" default: return fmt.Sprintf("Unknown ParamType: %d", p) } } sctp-1.9.0/paramtype_test.go000066400000000000000000000054121512256410600160660ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "fmt" "testing" "github.com/stretchr/testify/assert" ) func TestParseParamType_Success(t *testing.T) { tt := []struct { binary []byte expected paramType }{ {[]byte{0x0, 0x1}, heartbeatInfo}, {[]byte{0x0, 0xd}, outSSNResetReq}, } for i, tc := range tt { pType, err := parseParamType(tc.binary) assert.NoErrorf(t, err, "failed to parse paramType #%d", i) assert.Equal(t, tc.expected, pType) } } func TestParseParamType_Failure(t *testing.T) { tt := []struct { name string binary []byte }{ {"empty packet", []byte{}}, } for i, tc := range tt { _, err := parseParamType(tc.binary) assert.Errorf(t, err, "expected parseParamType #%d: '%s' to fail.", i, tc.name) } } func TestParamType_String(t *testing.T) { tests := []struct { name string in paramType want string }{ {"heartbeatInfo", heartbeatInfo, "Heartbeat Info"}, {"ipV4Addr", ipV4Addr, "IPv4 IP"}, {"ipV6Addr", ipV6Addr, "IPv6 IP"}, {"stateCookie", stateCookie, "State Cookie"}, {"unrecognizedParam", unrecognizedParam, "Unrecognized Parameters"}, {"cookiePreservative", cookiePreservative, "Cookie Preservative"}, {"hostNameAddr", hostNameAddr, "Host Name Address"}, {"supportedAddrTypes", supportedAddrTypes, "Supported IP Types"}, {"outSSNResetReq", outSSNResetReq, "Outgoing SSN Reset Request Parameter"}, {"incSSNResetReq", incSSNResetReq, "Incoming SSN Reset Request Parameter"}, {"ssnTSNResetReq", ssnTSNResetReq, "SSN/TSN Reset Request Parameter"}, {"reconfigResp", reconfigResp, "Re-configuration Response Parameter"}, {"addOutStreamsReq", addOutStreamsReq, "Add Outgoing Streams Request Parameter"}, {"addIncStreamsReq", 
addIncStreamsReq, "Add Incoming Streams Request Parameter"}, {"ecnCapable", ecnCapable, "ECN Capable"}, {"zeroChecksumAcceptable", zeroChecksumAcceptable, "Zero Checksum Acceptable"}, {"random", random, "Random"}, {"chunkList", chunkList, "Chunk List"}, {"reqHMACAlgo", reqHMACAlgo, "Requested HMAC Algorithm Parameter"}, {"padding", padding, "Padding"}, {"supportedExt", supportedExt, "Supported Extensions"}, {"forwardTSNSupp", forwardTSNSupp, "Forward TSN supported"}, {"addIPAddr", addIPAddr, "Add IP Address"}, {"delIPAddr", delIPAddr, "Delete IP Address"}, {"errClauseInd", errClauseInd, "Error Cause Indication"}, {"setPriAddr", setPriAddr, "Set Primary IP"}, {"successInd", successInd, "Success Indication"}, {"adaptLayerInd", adaptLayerInd, "Adaptation Layer Indication"}, {"unknownValue", paramType(0x4242), fmt.Sprintf("Unknown ParamType: %d", 0x4242)}, } for _, tc := range tests { t.Run(tc.name, func(t *testing.T) { got := tc.in.String() assert.Equal(t, tc.want, got) }) } } sctp-1.9.0/payload_queue.go000066400000000000000000000031111512256410600156540ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type payloadQueue struct { chunks *queue[*chunkPayloadData] nBytes int } func newPayloadQueue() *payloadQueue { return &payloadQueue{chunks: newQueue[*chunkPayloadData](128)} } func (q *payloadQueue) pushNoCheck(p *chunkPayloadData) { q.chunks.PushBack(p) q.nBytes += len(p.userData) } // pop pops only if the oldest chunk's TSN matches the given TSN. func (q *payloadQueue) pop(tsn uint32) (*chunkPayloadData, bool) { if q.chunks.Len() > 0 && tsn == q.chunks.Front().tsn { c := q.chunks.PopFront() q.nBytes -= len(c.userData) return c, true } return nil, false } // get returns reference to chunkPayloadData with the given TSN value. 
func (q *payloadQueue) get(tsn uint32) (*chunkPayloadData, bool) { length := q.chunks.Len() if length == 0 { return nil, false } head := q.chunks.Front().tsn if tsn < head || int(tsn-head) >= length { return nil, false } return q.chunks.At(int(tsn - head)), true } func (q *payloadQueue) markAsAcked(tsn uint32) int { var nBytesAcked int if c, ok := q.get(tsn); ok { c.acked = true c.retransmit = false nBytesAcked = len(c.userData) q.nBytes -= nBytesAcked c.userData = []byte{} } return nBytesAcked } func (q *payloadQueue) markAllToRetrasmit() { for i := 0; i < q.chunks.Len(); i++ { c := q.chunks.At(i) if c.acked || c.abandoned() { continue } c.retransmit = true } } func (q *payloadQueue) getNumBytes() int { return q.nBytes } func (q *payloadQueue) size() int { return q.chunks.Len() } sctp-1.9.0/payload_queue_test.go000066400000000000000000000060361512256410600167240ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func makePayload(tsn uint32, nBytes int) *chunkPayloadData { return &chunkPayloadData{tsn: tsn, userData: make([]byte, nBytes)} } func TestPayloadQueue(t *testing.T) { t.Run("pushNoCheck", func(t *testing.T) { pq := newPayloadQueue() pq.pushNoCheck(makePayload(0, 10)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 1, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(1, 11)) assert.Equal(t, 21, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 2, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(2, 12)) assert.Equal(t, 33, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 3, pq.size(), "item count mismatch") for i := uint32(0); i < 3; i++ { c, ok := pq.pop(i) assert.True(t, ok, "pop should succeed") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 0, pq.size(), "item count mismatch") pq.pushNoCheck(makePayload(3, 13)) assert.Equal(t, 13, pq.getNumBytes(), "total bytes mismatch") pq.pushNoCheck(makePayload(4, 14)) assert.Equal(t, 27, pq.getNumBytes(), "total bytes mismatch") for i := uint32(3); i < 5; i++ { c, ok := pq.pop(i) assert.True(t, ok, "pop should succeed") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") assert.Equal(t, 0, pq.size(), "item count mismatch") }) t.Run("markAllToRetrasmit", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 3; i++ { pq.pushNoCheck(makePayload(uint32(i+1), 10)) //nolint:gosec // G115 } pq.markAsAcked(2) pq.markAllToRetrasmit() c, ok := pq.get(1) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(2) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, "should NOT be marked as retransmit") c, ok = pq.get(3) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") }) t.Run("reset retransmit flag on ack", func(t *testing.T) { pq := newPayloadQueue() for i := 0; i < 4; i++ { pq.pushNoCheck(makePayload(uint32(i+1), 10)) //nolint:gosec // G115 } pq.markAllToRetrasmit() pq.markAsAcked(2) // should cancel retransmission for TSN 2 pq.markAsAcked(4) // should cancel retransmission for TSN 4 c, ok := pq.get(1) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(2) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, 
"should NOT be marked as retransmit") c, ok = pq.get(3) assert.True(t, ok, "should be true") assert.True(t, c.retransmit, "should be marked as retransmit") c, ok = pq.get(4) assert.True(t, ok, "should be true") assert.False(t, c.retransmit, "should NOT be marked as retransmit") }) } sctp-1.9.0/pending_queue.go000066400000000000000000000067041512256410600156620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" ) // pendingBaseQueue type pendingBaseQueue struct { queue []*chunkPayloadData } func newPendingBaseQueue() *pendingBaseQueue { return &pendingBaseQueue{queue: []*chunkPayloadData{}} } func (q *pendingBaseQueue) push(c *chunkPayloadData) { q.queue = append(q.queue, c) } func (q *pendingBaseQueue) pop() *chunkPayloadData { if len(q.queue) == 0 { return nil } c := q.queue[0] q.queue[0] = nil if len(q.queue) == 0 { q.queue = nil } else { q.queue = q.queue[1:] } return c } func (q *pendingBaseQueue) get(i int) *chunkPayloadData { if len(q.queue) == 0 || i < 0 || i >= len(q.queue) { return nil } return q.queue[i] } func (q *pendingBaseQueue) size() int { return len(q.queue) } // pendingQueue type pendingQueue struct { unorderedQueue *pendingBaseQueue orderedQueue *pendingBaseQueue nBytes int selected bool unorderedIsSelected bool } // Pending queue errors. var ( ErrUnexpectedChunkPoppedUnordered = errors.New("unexpected chunk popped (unordered)") ErrUnexpectedChunkPoppedOrdered = errors.New("unexpected chunk popped (ordered)") ErrUnexpectedQState = errors.New("unexpected q state (should've been selected)") // Deprecated: use ErrUnexpectedChunkPoppedUnordered. ErrUnexpectedChuckPoppedUnordered = ErrUnexpectedChunkPoppedUnordered // Deprecated: use ErrUnexpectedChunkPoppedOrdered. ErrUnexpectedChuckPoppedOrdered = ErrUnexpectedChunkPoppedOrdered ) func newPendingQueue() *pendingQueue { return &pendingQueue{ unorderedQueue: newPendingBaseQueue(), orderedQueue: newPendingBaseQueue(), } } func (q *pendingQueue) push(c *chunkPayloadData) { if c.unordered { q.unorderedQueue.push(c) } else { q.orderedQueue.push(c) } q.nBytes += len(c.userData) } func (q *pendingQueue) peek() *chunkPayloadData { if q.selected { if q.unorderedIsSelected { return q.unorderedQueue.get(0) } return q.orderedQueue.get(0) } if c := q.unorderedQueue.get(0); c != nil { return c } return q.orderedQueue.get(0) } func (q *pendingQueue) pop(chunkPayload *chunkPayloadData) error { //nolint:cyclop if q.selected { //nolint:nestif var popped *chunkPayloadData if q.unorderedIsSelected { popped = q.unorderedQueue.pop() if popped != chunkPayload { return ErrUnexpectedChunkPoppedUnordered } } else { popped = q.orderedQueue.pop() if popped != chunkPayload { return ErrUnexpectedChunkPoppedOrdered } } if popped.endingFragment { q.selected = false } } else { if !chunkPayload.beginningFragment { return ErrUnexpectedQState } if chunkPayload.unordered { popped := q.unorderedQueue.pop() if popped != chunkPayload { return ErrUnexpectedChunkPoppedUnordered } if !popped.endingFragment { q.selected = true q.unorderedIsSelected = true } } else { popped := q.orderedQueue.pop() if popped != chunkPayload { return ErrUnexpectedChunkPoppedOrdered } if !popped.endingFragment { q.selected = true q.unorderedIsSelected = false } } } // guard against negative values (should never happen, but just in case). 
q.nBytes -= len(chunkPayload.userData) if q.nBytes < 0 { q.nBytes = 0 } return nil } func (q *pendingQueue) getNumBytes() int { return q.nBytes } func (q *pendingQueue) size() int { return q.unorderedQueue.size() + q.orderedQueue.size() } sctp-1.9.0/pending_queue_test.go000066400000000000000000000171021512256410600167130ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) const ( noFragment = iota fragBegin fragMiddle fragEnd ) func makeDataChunk(tsn uint32, unordered bool, frag int) *chunkPayloadData { var begin, end bool switch frag { case noFragment: begin = true end = true case fragBegin: begin = true case fragEnd: end = true } return &chunkPayloadData{ tsn: tsn, unordered: unordered, beginningFragment: begin, endingFragment: end, userData: make([]byte, 10), // always 10 bytes } } func TestPendingBaseQueue(t *testing.T) { t.Run("push and pop", func(t *testing.T) { pq := newPendingBaseQueue() pq.push(makeDataChunk(0, false, noFragment)) pq.push(makeDataChunk(1, false, noFragment)) pq.push(makeDataChunk(2, false, noFragment)) for i := uint32(0); i < 3; i++ { c := pq.get(int(i)) assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } for i := uint32(0); i < 3; i++ { c := pq.pop() assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } pq.push(makeDataChunk(3, false, noFragment)) pq.push(makeDataChunk(4, false, noFragment)) for i := uint32(3); i < 5; i++ { c := pq.pop() assert.NotNil(t, c, "should not be nil") assert.Equal(t, i, c.tsn, "TSN should match") } }) t.Run("out of bounds", func(t *testing.T) { pq := newPendingBaseQueue() assert.Nil(t, pq.pop(), "should be nil") assert.Nil(t, pq.get(0), "should be nil") pq.push(makeDataChunk(0, false, noFragment)) assert.Nil(t, pq.get(-1), "should be nil") assert.Nil(t, pq.get(1), "should be nil") }) } func TestPendingQueue(t *testing.T) { // NOTE: TSN is not used in pendingQueue in the actual usage. // Following tests use TSN field as a chunk ID. 
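// Ordering inside pendingQueue comes from push order and from the
// ordered/unordered split, not from the TSN itself, which is why a plain
// counter is good enough as an identifier here.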
t.Run("push and pop", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(1, false, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(2, false, noFragment)) assert.Equal(t, 30, pq.getNumBytes(), "total bytes mismatch") for i := uint32(0); i < 3; i++ { c := pq.peek() err := pq.pop(c) assert.Nil(t, err, "should not error") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(3, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(4, false, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") for i := uint32(3); i < 5; i++ { c := pq.peek() err := pq.pop(c) assert.Nil(t, err, "should not error") assert.Equal(t, i, c.tsn, "TSN should match") } assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") }) t.Run("unordered wins", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, noFragment)) assert.Equal(t, 10, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(1, true, noFragment)) assert.Equal(t, 20, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(2, false, noFragment)) assert.Equal(t, 30, pq.getNumBytes(), "total bytes mismatch") pq.push(makeDataChunk(3, true, noFragment)) assert.Equal(t, 40, pq.getNumBytes(), "total bytes mismatch") chunkPayload := pq.peek() err := pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(1), chunkPayload.tsn, "TSN should match") chunkPayload = pq.peek() err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(3), chunkPayload.tsn, "TSN should match") chunkPayload = pq.peek() err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(0), chunkPayload.tsn, "TSN should match") chunkPayload = pq.peek() err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(2), chunkPayload.tsn, "TSN should match") assert.Equal(t, 0, pq.getNumBytes(), "total bytes mismatch") }) t.Run("fragments", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, fragBegin)) pq.push(makeDataChunk(1, false, fragMiddle)) pq.push(makeDataChunk(2, false, fragEnd)) pq.push(makeDataChunk(3, true, fragBegin)) pq.push(makeDataChunk(4, true, fragMiddle)) pq.push(makeDataChunk(5, true, fragEnd)) expects := []uint32{3, 4, 5, 0, 1, 2} for _, exp := range expects { c := pq.peek() err := pq.pop(c) assert.NoError(t, err, "should not error") assert.Equal(t, exp, c.tsn, "TSN should match") } }) // Once decided ordered or unordered, the decision should persist until // it pops a chunk with endingFragment flags set to true. 
t.Run("selection persistence", func(t *testing.T) { pq := newPendingQueue() pq.push(makeDataChunk(0, false, fragBegin)) chunkPayload := pq.peek() err := pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, uint32(0), chunkPayload.tsn, "TSN should match") pq.push(makeDataChunk(1, true, noFragment)) pq.push(makeDataChunk(2, false, fragMiddle)) pq.push(makeDataChunk(3, false, fragEnd)) expects := []uint32{2, 3, 1} for _, exp := range expects { chunkPayload = pq.peek() err = pq.pop(chunkPayload) assert.NoError(t, err, "should not error") assert.Equal(t, exp, chunkPayload.tsn, "TSN should match") } }) } func TestPendingQueue_PopErrors(t *testing.T) { t.Run("ErrUnexpectedQState when not selected and not beginningFragment", func(t *testing.T) { pq := newPendingQueue() mid := makeDataChunk(100, false, fragMiddle) err := pq.pop(mid) assert.ErrorIs(t, err, ErrUnexpectedQState) }) t.Run("ErrUnexpectedChunkPoppedUnordered (not selected path)", func(t *testing.T) { pq := newPendingQueue() u1 := makeDataChunk(1, true, noFragment) u2 := makeDataChunk(2, true, noFragment) pq.push(u1) pq.push(u2) err := pq.pop(u2) assert.ErrorIs(t, err, ErrUnexpectedChunkPoppedUnordered) }) t.Run("ErrUnexpectedChunkPoppedOrdered (not selected path)", func(t *testing.T) { pq := newPendingQueue() o1 := makeDataChunk(10, false, noFragment) o2 := makeDataChunk(11, false, noFragment) pq.push(o1) pq.push(o2) err := pq.pop(o2) assert.ErrorIs(t, err, ErrUnexpectedChunkPoppedOrdered) }) t.Run("ErrUnexpectedChunkPoppedUnordered (selected unordered path)", func(t *testing.T) { pq := newPendingQueue() uBegin := makeDataChunk(21, true, fragBegin) uMid := makeDataChunk(22, true, fragMiddle) uEnd := makeDataChunk(23, true, fragEnd) pq.push(uBegin) pq.push(uMid) pq.push(uEnd) err := pq.pop(uBegin) assert.NoError(t, err) err = pq.pop(uEnd) assert.ErrorIs(t, err, ErrUnexpectedChunkPoppedUnordered) }) t.Run("ErrUnexpectedChunkPoppedOrdered (selected ordered path)", func(t *testing.T) { pq := newPendingQueue() oBegin := makeDataChunk(31, false, fragBegin) oMid := makeDataChunk(32, false, fragMiddle) oEnd := makeDataChunk(33, false, fragEnd) pq.push(oBegin) pq.push(oMid) pq.push(oEnd) err := pq.pop(oBegin) assert.NoError(t, err) err = pq.pop(oEnd) assert.ErrorIs(t, err, ErrUnexpectedChunkPoppedOrdered) }) t.Run("nBytes guard clamps to zero when underflows", func(t *testing.T) { pq := newPendingQueue() c := makeDataChunk(40, false, noFragment) pq.push(c) pq.nBytes = 5 peek := pq.peek() err := pq.pop(peek) assert.NoError(t, err) assert.Equal(t, 0, pq.getNumBytes()) }) } sctp-1.9.0/queue.go000066400000000000000000000023051512256410600141470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp type queue[T any] struct { buf []T head int tail int count int } const minCap = 16 func newQueue[T any](capacity int) *queue[T] { queueCap := minCap for queueCap < capacity { queueCap <<= 1 } return &queue[T]{ buf: make([]T, queueCap), } } func (q *queue[T]) Len() int { return q.count } func (q *queue[T]) PushBack(ele T) { q.growIfFull() q.buf[q.tail] = ele q.tail = (q.tail + 1) % len(q.buf) q.count++ } func (q *queue[T]) PopFront() T { ele := q.buf[q.head] var zeroVal T q.buf[q.head] = zeroVal q.head = (q.head + 1) % len(q.buf) q.count-- return ele } func (q *queue[T]) Front() T { return q.buf[q.head] } func (q *queue[T]) Back() T { return q.buf[(q.tail-1+len(q.buf))%len(q.buf)] } func (q *queue[T]) At(i int) T { return q.buf[(q.head+i)%(len(q.buf))] } func (q 
*queue[T]) growIfFull() { if q.count < len(q.buf) { return } newBuf := make([]T, q.count<<1) if q.tail > q.head { copy(newBuf, q.buf[q.head:q.tail]) } else { n := copy(newBuf, q.buf[q.head:]) copy(newBuf[n:], q.buf[:q.tail]) } q.head = 0 q.tail = q.count q.buf = newBuf } sctp-1.9.0/queue_test.go000066400000000000000000000045001512256410600152050ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "runtime" "runtime/debug" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestQueue(t *testing.T) { queu := newQueue[int](32) assert.Zero(t, queu.Len()) // test push & pop for i := 1; i < 33; i++ { queu.PushBack(i) } assert.Equal(t, 32, queu.Len()) assert.Equal(t, 5, queu.At(4)) for i := 1; i < 33; i++ { assert.Equal(t, i, queu.Front()) assert.Equal(t, i, queu.PopFront()) } assert.Zero(t, queu.Len()) queu.PushBack(10) queu.PushBack(11) assert.Equal(t, 2, queu.Len()) assert.Equal(t, 11, queu.At(1)) assert.Equal(t, 10, queu.Front()) assert.Equal(t, 10, queu.PopFront()) assert.Equal(t, 11, queu.PopFront()) // test grow capacity for i := 0; i < 64; i++ { queu.PushBack(i) } assert.Equal(t, 64, queu.Len()) assert.Equal(t, 2, queu.At(2)) for i := 0; i < 64; i++ { assert.Equal(t, i, queu.Front()) assert.Equal(t, i, queu.PopFront()) } } // waitForFinalizers spins until at least target have run or timeout hits. func waitForFinalizers(got *int32, target int32, timeout time.Duration) bool { deadline := time.Now().Add(timeout) for time.Now().Before(deadline) { runtime.GC() if atomic.LoadInt32(got) >= target { return true } time.Sleep(10 * time.Millisecond) } return atomic.LoadInt32(got) >= target } func TestPendingBaseQueuePopReleasesReferences(t *testing.T) { // Make GC more aggressive for the duration of this test. prev := debug.SetGCPercent(10) defer debug.SetGCPercent(prev) bufSize := 256 << 10 queue := newPendingBaseQueue() var finalized int32 // add 64 chunks, each with a finalizer to count collection. for i := 0; i < 64; i++ { c := &chunkPayloadData{ userData: make([]byte, bufSize), } // count when the chunk struct becomes unreachable. runtime.SetFinalizer(c, func(*chunkPayloadData) { atomic.AddInt32(&finalized, 1) }) queue.push(c) } // pop 63 chunks so only 1 is left for i := 0; i < 63; i++ { queue.pop() } assert.Equal(t, queue.size(), 1) wantAtLeast := int32(64 - 4) // wait for GC scheduling ok := waitForFinalizers(&finalized, wantAtLeast, 3*time.Second) assert.True(t, ok) // Now pop the last element; queue should be empty. queue.pop() assert.Equal(t, queue.size(), 0) } sctp-1.9.0/reassembly_queue.go000066400000000000000000000203611512256410600163770ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "io" "sort" "sync/atomic" ) func sortChunksByTSN(a []*chunkPayloadData) { sort.Slice(a, func(i, j int) bool { return sna32LT(a[i].tsn, a[j].tsn) }) } func sortChunksBySSN(a []*chunkSet) { sort.Slice(a, func(i, j int) bool { return sna16LT(a[i].ssn, a[j].ssn) }) } // chunkSet is a set of chunks that share the same SSN. 
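// A chunkSet collects the fragments of a single user message; it is complete
// once it begins with beginningFragment, ends with endingFragment, and the
// TSNs in between are contiguous. A minimal sketch (values illustrative only):
//
//	set := newChunkSet(0, PayloadTypeWebRTCBinary)
//	set.push(&chunkPayloadData{tsn: 1, beginningFragment: true, userData: []byte("AB")})
//	complete := set.push(&chunkPayloadData{tsn: 2, endingFragment: true, userData: []byte("C")})
//	// complete is now true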
type chunkSet struct { ssn uint16 // used only with the ordered chunks ppi PayloadProtocolIdentifier chunks []*chunkPayloadData } func newChunkSet(ssn uint16, ppi PayloadProtocolIdentifier) *chunkSet { return &chunkSet{ ssn: ssn, ppi: ppi, chunks: []*chunkPayloadData{}, } } func (set *chunkSet) push(chunk *chunkPayloadData) bool { // check if dup for _, c := range set.chunks { if c.tsn == chunk.tsn { return false } } // append and sort set.chunks = append(set.chunks, chunk) sortChunksByTSN(set.chunks) // Check if we now have a complete set complete := set.isComplete() return complete } func (set *chunkSet) isComplete() bool { // Condition for complete set // 0. Has at least one chunk. // 1. Begins with beginningFragment set to true // 2. Ends with endingFragment set to true // 3. TSN monotinically increase by 1 from beginning to end // 0. nChunks := len(set.chunks) if nChunks == 0 { return false } // 1. if !set.chunks[0].beginningFragment { return false } // 2. if !set.chunks[nChunks-1].endingFragment { return false } // 3. var lastTSN uint32 for i, chunk := range set.chunks { if i > 0 { // Fragments must have contiguous TSN // From RFC 4960 Section 3.3.1: // When a user message is fragmented into multiple chunks, the TSNs are // used by the receiver to reassemble the message. This means that the // TSNs for each fragment of a fragmented user message MUST be strictly // sequential. if chunk.tsn != lastTSN+1 { // mid or end fragment is missing return false } } lastTSN = chunk.tsn } return true } type reassemblyQueue struct { si uint16 nextSSN uint16 // expected SSN for next ordered chunk ordered []*chunkSet unordered []*chunkSet unorderedChunks []*chunkPayloadData nBytes uint64 } var errTryAgain = errors.New("try again") func newReassemblyQueue(si uint16) *reassemblyQueue { // From RFC 4960 Sec 6.5: // The Stream Sequence Number in all the streams MUST start from 0 when // the association is established. Also, when the Stream Sequence // Number reaches the value 65535 the next Stream Sequence Number MUST // be set to 0. return &reassemblyQueue{ si: si, nextSSN: 0, // From RFC 4960 Sec 6.5: ordered: make([]*chunkSet, 0), unordered: make([]*chunkSet, 0), } } func (r *reassemblyQueue) push(chunk *chunkPayloadData) bool { //nolint:cyclop var cset *chunkSet if chunk.streamIdentifier != r.si { return false } if chunk.unordered { // First, insert into unorderedChunks array r.unorderedChunks = append(r.unorderedChunks, chunk) atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) sortChunksByTSN(r.unorderedChunks) // Scan unorderedChunks that are contiguous (in TSN) cset = r.findCompleteUnorderedChunkSet() // If found, append the complete set to the unordered array if cset != nil { r.unordered = append(r.unordered, cset) return true } return false } // This is an ordered chunk if sna16LT(chunk.streamSequenceNumber, r.nextSSN) { return false } // Check if a fragmented chunkSet with the fragmented SSN already exists if chunk.isFragmented() { for _, set := range r.ordered { // nolint:godox // TODO: add caution around SSN wrapping here... this helps only a little bit // by ensuring we don't add to an unfragmented cset (1 chunk). There's // a case where if the SSN does wrap around, we may see the same SSN // for a different chunk. // nolint:godox // TODO: this slice can get pretty big; it may be worth maintaining a map // for O(1) lookups at the cost of 2x memory. 
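// Only match against sets whose first chunk is itself fragmented, so a new
// fragment is never appended to an already-complete single-chunk message
// that happens to carry the same (possibly wrapped) SSN.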
if set.ssn == chunk.streamSequenceNumber && set.chunks[0].isFragmented() { cset = set break } } } // If not found, create a new chunkSet if cset == nil { cset = newChunkSet(chunk.streamSequenceNumber, chunk.payloadType) r.ordered = append(r.ordered, cset) if !chunk.unordered { sortChunksBySSN(r.ordered) } } atomic.AddUint64(&r.nBytes, uint64(len(chunk.userData))) return cset.push(chunk) } func (r *reassemblyQueue) findCompleteUnorderedChunkSet() *chunkSet { startIdx := -1 nChunks := 0 var lastTSN uint32 var found bool for i, chunk := range r.unorderedChunks { // seek beigining if chunk.beginningFragment { startIdx = i nChunks = 1 lastTSN = chunk.tsn if chunk.endingFragment { found = true break } continue } if startIdx < 0 { continue } // Check if contiguous in TSN if chunk.tsn != lastTSN+1 { startIdx = -1 continue } lastTSN = chunk.tsn nChunks++ if chunk.endingFragment { found = true break } } if !found { return nil } // Extract the range of chunks var chunks []*chunkPayloadData chunks = append(chunks, r.unorderedChunks[startIdx:startIdx+nChunks]...) r.unorderedChunks = append( r.unorderedChunks[:startIdx], r.unorderedChunks[startIdx+nChunks:]...) chunkSet := newChunkSet(0, chunks[0].payloadType) chunkSet.chunks = chunks return chunkSet } func (r *reassemblyQueue) isReadable() bool { // Check unordered first if len(r.unordered) > 0 { // The chunk sets in r.unordered should all be complete. return true } // Check ordered sets if len(r.ordered) > 0 { cset := r.ordered[0] if cset.isComplete() { if sna16LTE(cset.ssn, r.nextSSN) { return true } } } return false } func (r *reassemblyQueue) read(buf []byte) (int, PayloadProtocolIdentifier, error) { // nolint: cyclop var ( cset *chunkSet isUnordered bool nTotal int err error ) switch { case len(r.unordered) > 0: cset = r.unordered[0] isUnordered = true case len(r.ordered) > 0: cset = r.ordered[0] if !cset.isComplete() { return 0, 0, errTryAgain } if sna16GT(cset.ssn, r.nextSSN) { return 0, 0, errTryAgain } default: return 0, 0, errTryAgain } for _, c := range cset.chunks { if len(buf)-nTotal < len(c.userData) { err = io.ErrShortBuffer } else { copy(buf[nTotal:], c.userData) } nTotal += len(c.userData) } switch { case err != nil: return nTotal, 0, err case isUnordered: r.unordered = r.unordered[1:] default: r.ordered = r.ordered[1:] if cset.ssn == r.nextSSN { r.nextSSN++ } } r.subtractNumBytes(nTotal) return nTotal, cset.ppi, err } func (r *reassemblyQueue) forwardTSNForOrdered(lastSSN uint16) { // Use lastSSN to locate a chunkSet then remove it if the set has // not been complete keep := []*chunkSet{} for _, set := range r.ordered { if sna16LTE(set.ssn, lastSSN) { if !set.isComplete() { // drop the set for _, c := range set.chunks { r.subtractNumBytes(len(c.userData)) } continue } } keep = append(keep, set) } r.ordered = keep // Finally, forward nextSSN if sna16LTE(r.nextSSN, lastSSN) { r.nextSSN = lastSSN + 1 } } func (r *reassemblyQueue) forwardTSNForUnordered(newCumulativeTSN uint32) { // Remove all fragments in the unordered sets that contains chunks // equal to or older than `newCumulativeTSN`. // We know all sets in the r.unordered are complete ones. 
// Just remove chunks that are equal to or older than newCumulativeTSN // from the unorderedChunks lastIdx := -1 for i, c := range r.unorderedChunks { if sna32GT(c.tsn, newCumulativeTSN) { break } lastIdx = i } if lastIdx >= 0 { for _, c := range r.unorderedChunks[0 : lastIdx+1] { r.subtractNumBytes(len(c.userData)) } r.unorderedChunks = r.unorderedChunks[lastIdx+1:] } } func (r *reassemblyQueue) subtractNumBytes(nBytes int) { cur := atomic.LoadUint64(&r.nBytes) if int(cur) >= nBytes { //nolint:gosec // G115 atomic.AddUint64(&r.nBytes, -uint64(nBytes)) //nolint:gosec // G115 } else { atomic.StoreUint64(&r.nBytes, 0) } } func (r *reassemblyQueue) getNumBytes() int { return int(atomic.LoadUint64(&r.nBytes)) //nolint:gosec // G115 } sctp-1.9.0/reassembly_queue_test.go000066400000000000000000000373211512256410600174420ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "io" "testing" "github.com/stretchr/testify/assert" ) func TestReassemblyQueue(t *testing.T) { //nolint:maintidx t.Run("ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, endingFragment: true, tsn: 2, streamSequenceNumber: 0, userData: []byte("DEFG"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 7, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 7, n, "should received 7 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABCDEFG", "data should match") }) t.Run("ordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 2, streamSequenceNumber: 0, userData: []byte("DEFG"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 7, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, endingFragment: true, tsn: 3, streamSequenceNumber: 0, userData: []byte("H"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 8, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 8, n, "should received 8 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABCDEFGH", "data should match") }) t.Run("ordered and unordered in the mix", func(t *testing.T) { rq := newReassemblyQueue(0) 
orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 1, streamSequenceNumber: 0, userData: []byte("ABC"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, endingFragment: true, tsn: 2, streamSequenceNumber: 1, userData: []byte("DEF"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") // // Now we have two complete chunks ready to read in the reassemblyQueue. // buf := make([]byte, 16) // Should read unordered chunks first n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 3, n, "should received 3 bytes") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "DEF", "data should match") // Next should read ordered chunks n, ppi, err = rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 3, n, "should received 3 bytes") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "ABC", "data should match") }) t.Run("unordered complete skips incomplete", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 10, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, endingFragment: true, tsn: 12, // <- incongiguous streamSequenceNumber: 1, userData: []byte("COMPLETE"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 10, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, endingFragment: true, tsn: 13, streamSequenceNumber: 1, userData: []byte("GOOD"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 14, rq.getNumBytes(), "num bytes mismatch") // // Now we have two complete chunks ready to read in the reassemblyQueue. 
// buf := make([]byte, 16) // Should pick the one that has "GOOD" n, ppi, err := rq.read(buf) assert.Nil(t, err, "read() should succeed") assert.Equal(t, 4, n, "should receive 4 bytes") assert.Equal(t, 10, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, ppi, orgPpi, "should have valid ppi") assert.Equal(t, string(buf[:n]), "GOOD", "data should match") }) t.Run("ignores chunk with wrong SI", func(t *testing.T) { rq := newReassemblyQueue(123) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, streamIdentifier: 124, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk should be ignored") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") }) t.Run("ignores chunk with stale SSN", func(t *testing.T) { rq := newReassemblyQueue(0) rq.nextSSN = 7 // forcibly set expected SSN to 7 orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: 6, // <-- stale userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "chunk should not be ignored") assert.Equal(t, 0, rq.getNumBytes(), "num bytes mismatch") }) t.Run("should fail to read incomplete chunk", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 123, streamSequenceNumber: 0, userData: []byte("IN"), } complete = rq.push(chunk) assert.False(t, complete, "the set should not be complete") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) _, _, err := rq.read(buf) assert.NotNil(t, err, "read() should not succeed") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") }) t.Run("should fail to read if the next SSN is not ready", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 123, streamSequenceNumber: 1, userData: []byte("IN"), } complete = rq.push(chunk) assert.True(t, complete, "the set should be complete") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") buf := make([]byte, 16) _, _, err := rq.read(buf) assert.NotNil(t, err, "read() should not succeed") assert.Equal(t, 2, rq.getNumBytes(), "num bytes mismatch") }) t.Run("detect buffer too short", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary for _, chunk := range []*chunkPayloadData{ { payloadType: orgPpi, beginningFragment: true, tsn: 123, streamSequenceNumber: 0, userData: []byte("0123"), }, { payloadType: orgPpi, tsn: 124, streamSequenceNumber: 0, userData: []byte("456"), }, { payloadType: orgPpi, endingFragment: true, tsn: 125, streamSequenceNumber: 0, userData: []byte("789"), }, } { rq.push(chunk) } assert.Equal(t, 10, rq.getNumBytes()) buf := make([]byte, 6) // <- passing buffer too short n, ppi, err := rq.read(buf) assert.Equal(t, io.ErrShortBuffer, err) assert.Equal(t, PayloadTypeUnknown, ppi) assert.Equal(t, 10, n) buf = make([]byte, n) n, ppi, err = rq.read(buf) assert.NoError(t, err) assert.Equal(t, orgPpi, ppi) assert.Equal(t, 10, n) }) t.Run("forwardTSN for ordered fragments", func(t *testing.T) { rq := 
newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool var ssnComplete uint16 = 5 var ssnDropped uint16 = 6 chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, endingFragment: true, tsn: 10, streamSequenceNumber: ssnComplete, userData: []byte("123"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 11, streamSequenceNumber: ssnDropped, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, tsn: 12, streamSequenceNumber: ssnDropped, userData: []byte("DEF"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 9, rq.getNumBytes(), "num bytes mismatch") rq.forwardTSNForOrdered(ssnDropped) assert.Equal(t, 1, len(rq.ordered), "there should be one chunk left") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") }) t.Run("forwardTSN for unordered fragments", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool var ssnDropped uint16 = 6 var ssnKept uint16 = 7 chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, beginningFragment: true, tsn: 11, streamSequenceNumber: ssnDropped, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 12, streamSequenceNumber: ssnDropped, userData: []byte("DEF"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, unordered: true, tsn: 14, beginningFragment: true, streamSequenceNumber: ssnKept, userData: []byte("SOS"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 9, rq.getNumBytes(), "num bytes mismatch") // At this point, there are 3 chunks in the rq.unorderedChunks. // This call should remove chunks with tsn equals to 13 or older. 
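// forwardTSNForUnordered drops abandoned unordered fragments up to and
// including the new cumulative TSN, so only the chunk with TSN 14 (3 bytes)
// should survive.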
rq.forwardTSNForUnordered(13) // As a result, there should be one chunk (tsn=14) assert.Equal(t, 1, len(rq.unorderedChunks), "there should be one chunk kept") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") }) t.Run("fragmented and unfragmented chunks with the same ssn", func(t *testing.T) { rq := newReassemblyQueue(0) orgPpi := PayloadTypeWebRTCBinary var chunk *chunkPayloadData var complete bool var ssn uint16 = 6 chunk = &chunkPayloadData{ payloadType: orgPpi, tsn: 12, beginningFragment: true, endingFragment: true, streamSequenceNumber: ssn, userData: []byte("DEF"), } complete = rq.push(chunk) assert.True(t, complete, "chunk set should be complete") assert.Equal(t, 3, rq.getNumBytes(), "num bytes mismatch") chunk = &chunkPayloadData{ payloadType: orgPpi, beginningFragment: true, tsn: 11, streamSequenceNumber: ssn, userData: []byte("ABC"), } complete = rq.push(chunk) assert.False(t, complete, "chunk set should not be complete yet") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") assert.Equal(t, 2, len(rq.ordered), "there should be two chunks") assert.Equal(t, 6, rq.getNumBytes(), "num bytes mismatch") }) } func TestChunkSet(t *testing.T) { t.Run("Empty chunkSet", func(t *testing.T) { cset := newChunkSet(0, 0) assert.False(t, cset.isComplete(), "empty chunkSet cannot be complete") }) t.Run("Push dup chunks to chunkSet", func(t *testing.T) { cset := newChunkSet(0, 0) cset.push(&chunkPayloadData{ tsn: 100, beginningFragment: true, }) complete := cset.push(&chunkPayloadData{ tsn: 100, endingFragment: true, }) assert.False(t, complete, "chunk with dup TSN is not complete") nChunks := len(cset.chunks) assert.Equal(t, 1, nChunks, "chunk with dup TSN should be ignored") }) t.Run("Incomplete chunkSet: no beginning", func(t *testing.T) { cset := &chunkSet{ ssn: 0, ppi: 0, chunks: []*chunkPayloadData{{}}, } assert.False(t, cset.isComplete(), "chunkSet not starting with B=1 cannot be complete") }) t.Run("Incomplete chunkSet: no contiguous tsn", func(t *testing.T) { cset := &chunkSet{ ssn: 0, ppi: 0, chunks: []*chunkPayloadData{ { tsn: 100, beginningFragment: true, }, { tsn: 101, }, { tsn: 103, endingFragment: true, }, }, } assert.False(t, cset.isComplete(), "chunkSet not starting with incontiguous tsn cannot be complete") }) } sctp-1.9.0/receive_payload_queue.go000066400000000000000000000120431512256410600173620ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "fmt" "math/bits" ) type receivePayloadQueue struct { tailTSN uint32 chunkSize int tsnBitmask []uint64 dupTSN []uint32 maxTSNOffset uint32 cumulativeTSN uint32 } func newReceivePayloadQueue(maxTSNOffset uint32) *receivePayloadQueue { maxTSNOffset = ((maxTSNOffset + 63) / 64) * 64 return &receivePayloadQueue{ tsnBitmask: make([]uint64, maxTSNOffset/64), maxTSNOffset: maxTSNOffset, } } func (q *receivePayloadQueue) init(cumulativeTSN uint32) { q.cumulativeTSN = cumulativeTSN q.tailTSN = cumulativeTSN q.chunkSize = 0 for i := range q.tsnBitmask { q.tsnBitmask[i] = 0 } q.dupTSN = q.dupTSN[:0] } func (q *receivePayloadQueue) hasChunk(tsn uint32) bool { if q.chunkSize == 0 || sna32LTE(tsn, q.cumulativeTSN) || sna32GT(tsn, q.tailTSN) { return false } index, offset := int(tsn/64)%len(q.tsnBitmask), tsn%64 return q.tsnBitmask[index]&(1<> uint64(start)) //nolint:gosec // G115 return i + start, i+start < end } func getFirstZeroBit(val uint64, start, end int) (int, bool) { return getFirstNonZeroBit(^val, start, end) } 
sctp-1.9.0/receive_payload_queue_test.go000066400000000000000000000135611512256410600204270ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "math" "testing" "github.com/stretchr/testify/assert" ) func TestReceivePayloadQueue(t *testing.T) { maxOffset := uint32(512) payloadQueue := newReceivePayloadQueue(maxOffset) initTSN := uint32(math.MaxUint32 - 10) payloadQueue.init(initTSN - 2) assert.Equal(t, initTSN-2, payloadQueue.getcumulativeTSN()) assert.Zero(t, payloadQueue.size()) _, ok := payloadQueue.getLastTSNReceived() assert.False(t, ok) assert.Empty(t, payloadQueue.getGapAckBlocks()) // force pop empy queue to advance cumulative TSN assert.False(t, payloadQueue.pop(true)) assert.Equal(t, initTSN-1, payloadQueue.getcumulativeTSN()) assert.Zero(t, payloadQueue.size()) assert.Empty(t, payloadQueue.getGapAckBlocks()) nextTSN := initTSN + maxOffset - 1 assert.True(t, payloadQueue.push(nextTSN)) assert.Equal(t, 1, payloadQueue.size()) lastTSN, ok := payloadQueue.getLastTSNReceived() assert.Truef(t, lastTSN == nextTSN && ok, "lastTSN:%d, ok:%t", lastTSN, ok) assert.True(t, payloadQueue.hasChunk(nextTSN)) assert.True(t, payloadQueue.push(initTSN)) assert.False(t, payloadQueue.canPush(initTSN-1)) assert.False(t, payloadQueue.canPush(initTSN+maxOffset)) assert.False(t, payloadQueue.push(initTSN+maxOffset)) assert.True(t, payloadQueue.canPush(nextTSN-1)) assert.Equal(t, 2, payloadQueue.size()) gaps := payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ {start: uint16(1), end: uint16(1)}, {start: uint16(maxOffset), end: uint16(maxOffset)}, }, gaps) assert.True(t, payloadQueue.pop(false)) assert.Equal(t, 1, payloadQueue.size()) assert.Equal(t, initTSN, payloadQueue.cumulativeTSN) assert.False(t, payloadQueue.pop(false)) assert.Equal(t, initTSN, payloadQueue.cumulativeTSN) size := payloadQueue.size() // push tsn with two gap // tsnRange [[start,end]...] 
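// (i.e. two disjoint ranges of received TSNs, leaving holes that must show
// up as separate gap-ack blocks relative to the cumulative TSN)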
tsnRange := [][]uint32{ {initTSN + 5, initTSN + 6}, {initTSN + 9, initTSN + 140}, } range0, range1 := tsnRange[0], tsnRange[1] for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { assert.True(t, payloadQueue.push(tsn)) assert.False(t, payloadQueue.pop(false)) assert.True(t, payloadQueue.hasChunk(tsn)) } size += int(range0[1] - range0[0] + 1) for tsn := range1[0]; sna32LTE(tsn, range1[1]); tsn++ { assert.True(t, payloadQueue.push(tsn)) assert.False(t, payloadQueue.pop(false)) assert.True(t, payloadQueue.hasChunk(tsn)) } size += int(range1[1] - range1[0] + 1) assert.Equal(t, size, payloadQueue.size()) gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ //nolint:gosec // G115 {start: uint16(range0[0] - initTSN), end: uint16(range0[1] - initTSN)}, //nolint:gosec // G115 {start: uint16(range1[0] - initTSN), end: uint16(range1[1] - initTSN)}, //nolint:gosec // G115 {start: uint16(nextTSN - initTSN), end: uint16(nextTSN - initTSN)}, }, gaps) // push duplicate tsns assert.False(t, payloadQueue.push(initTSN-2)) assert.False(t, payloadQueue.push(range0[0])) assert.False(t, payloadQueue.push(range0[0])) assert.False(t, payloadQueue.push(nextTSN)) assert.False(t, payloadQueue.push(initTSN+maxOffset+1)) duplicates := payloadQueue.popDuplicates() assert.EqualValues(t, []uint32{initTSN - 2, range0[0], range0[0], nextTSN}, duplicates) // force pop to advance cumulativeTSN to fill the gap [initTSN, initTSN+4] for tsn := initTSN + 1; sna32LT(tsn, range0[0]); tsn++ { assert.False(t, payloadQueue.pop(true)) assert.Equal(t, size, payloadQueue.size()) assert.Equal(t, tsn, payloadQueue.cumulativeTSN) } for tsn := range0[0]; sna32LTE(tsn, range0[1]); tsn++ { assert.True(t, payloadQueue.pop(false)) assert.Equal(t, tsn, payloadQueue.getcumulativeTSN()) } assert.False(t, payloadQueue.pop(false)) cumulativeTSN := payloadQueue.getcumulativeTSN() assert.Equal(t, range0[1], cumulativeTSN) gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ //nolint:gosec // G115 {start: uint16(range1[0] - range0[1]), end: uint16(range1[1] - range0[1])}, //nolint:gosec // G115 {start: uint16(nextTSN - range0[1]), end: uint16(nextTSN - range0[1])}, }, gaps) // fill the gap with received tsn for tsn := range0[1] + 1; sna32LT(tsn, range1[0]); tsn++ { assert.True(t, payloadQueue.push(tsn), tsn) } for tsn := range0[1] + 1; sna32LTE(tsn, range1[1]); tsn++ { assert.True(t, payloadQueue.pop(false)) assert.Equal(t, tsn, payloadQueue.getcumulativeTSN()) } assert.False(t, payloadQueue.pop(false)) assert.Equal(t, range1[1], payloadQueue.getcumulativeTSN()) gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ //nolint:gosec // G115 {start: uint16(nextTSN - range1[1]), end: uint16(nextTSN - range1[1])}, }, gaps) // gap block cross end tsn endTSN := maxOffset - 1 for tsn := nextTSN + 1; sna32LTE(tsn, endTSN); tsn++ { assert.True(t, payloadQueue.push(tsn)) } gaps = payloadQueue.getGapAckBlocks() assert.EqualValues(t, []gapAckBlock{ //nolint:gosec // G115 {start: uint16(nextTSN - range1[1]), end: uint16(endTSN - range1[1])}, }, gaps) assert.NotEmpty(t, payloadQueue.getGapAckBlocksString()) } func TestBitfunc(t *testing.T) { idx, ok := getFirstNonZeroBit(0xf, 0, 20) assert.True(t, ok) assert.Equal(t, 0, idx) _, ok = getFirstNonZeroBit(0xf<<20, 0, 20) assert.False(t, ok) idx, ok = getFirstNonZeroBit(0xf<<20, 5, 25) assert.True(t, ok) assert.Equal(t, 20, idx) _, ok = getFirstNonZeroBit(0xf<<20, 30, 40) assert.False(t, ok) _, ok = getFirstNonZeroBit(0, 0, 64) assert.False(t, ok) idx, ok = 
getFirstZeroBit(0xf, 0, 20) assert.True(t, ok) assert.Equal(t, 4, idx) idx, ok = getFirstZeroBit(0xf<<20, 0, 20) assert.True(t, ok) assert.Equal(t, 0, idx) _, ok = getFirstZeroBit(0xf<<20, 20, 24) assert.False(t, ok) idx, ok = getFirstZeroBit(0xf<<20, 30, 40) assert.True(t, ok) assert.Equal(t, 30, idx) _, ok = getFirstZeroBit(math.MaxUint64, 0, 64) assert.False(t, ok) } sctp-1.9.0/renovate.json000066400000000000000000000001731512256410600152130ustar00rootroot00000000000000{ "$schema": "https://docs.renovatebot.com/renovate-schema.json", "extends": [ "github>pion/renovate-config" ] } sctp-1.9.0/rtx_timer.go000066400000000000000000000127711512256410600150500ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "math" "sync" "time" ) const ( // RTO.Initial in msec. rtoInitial float64 = 1.0 * 1000 // RTO.Min in msec. rtoMin float64 = 1.0 * 1000 // RTO.Max in msec. defaultRTOMax float64 = 60.0 * 1000 // RTO.Alpha. rtoAlpha float64 = 0.125 // RTO.Beta. rtoBeta float64 = 0.25 // Max.Init.Retransmits. maxInitRetrans uint = 8 // Path.Max.Retrans. pathMaxRetrans uint = 5 noMaxRetrans uint = 0 ) // rtoManager manages Rtx timeout values. // This is an implementation of RFC 4960 sec 6.3.1. type rtoManager struct { srtt float64 rttvar float64 rto float64 noUpdate bool mutex sync.RWMutex rtoMax float64 } // newRTOManager creates a new rtoManager. func newRTOManager(rtoMax float64) *rtoManager { mgr := rtoManager{ rto: rtoInitial, rtoMax: rtoMax, } if mgr.rtoMax == 0 { mgr.rtoMax = defaultRTOMax } return &mgr } // setNewRTT takes a newly measured RTT then adjust the RTO in msec. func (m *rtoManager) setNewRTT(rtt float64) float64 { m.mutex.Lock() defer m.mutex.Unlock() if m.noUpdate { return m.srtt } if m.srtt == 0 { // First measurement m.srtt = rtt m.rttvar = rtt / 2 } else { // Subsequent rtt measurement m.rttvar = (1-rtoBeta)*m.rttvar + rtoBeta*(math.Abs(m.srtt-rtt)) m.srtt = (1-rtoAlpha)*m.srtt + rtoAlpha*rtt } m.rto = math.Min(math.Max(m.srtt+4*m.rttvar, rtoMin), m.rtoMax) return m.srtt } // getRTO simply returns the current RTO in msec. func (m *rtoManager) getRTO() float64 { m.mutex.RLock() defer m.mutex.RUnlock() return m.rto } // reset resets the RTO variables to the initial values. func (m *rtoManager) reset() { m.mutex.Lock() defer m.mutex.Unlock() if m.noUpdate { return } m.srtt = 0 m.rttvar = 0 m.rto = rtoInitial } // set RTO value for testing. func (m *rtoManager) setRTO(rto float64, noUpdate bool) { m.mutex.Lock() defer m.mutex.Unlock() m.rto = rto m.noUpdate = noUpdate } // rtxTimerObserver is the inteface to a timer observer. // NOTE: Observers MUST NOT call start() or stop() method on rtxTimer // from within these callbacks. type rtxTimerObserver interface { onRetransmissionTimeout(timerID int, n uint) onRetransmissionFailure(timerID int) } type rtxTimerState uint8 const ( rtxTimerStopped rtxTimerState = iota rtxTimerStarted rtxTimerClosed ) // rtxTimer provides the retnransmission timer conforms with RFC 4960 Sec 6.3.1. type rtxTimer struct { timer *time.Timer observer rtxTimerObserver id int maxRetrans uint rtoMax float64 mutex sync.Mutex rto float64 nRtos uint state rtxTimerState pending uint8 } // newRTXTimer creates a new retransmission timer. // if maxRetrans is set to 0, it will keep retransmitting until stop() is called. // (it will never make onRetransmissionFailure() callback. 
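// A minimal wiring sketch (myObserver stands in for some rtxTimerObserver
// implementation; values illustrative only):
//
//	rt := newRTXTimer(0, myObserver, pathMaxRetrans, 0) // rtoMax 0 falls back to defaultRTOMax
//	rt.start(rtoInitial) // first expiry after RTO.Initial msec, then doubling per RFC 4960
//	// ... later, once the outstanding data has been acknowledged:
//	rt.stop()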
func newRTXTimer(id int, observer rtxTimerObserver, maxRetrans uint, rtoMax float64, ) *rtxTimer { timer := rtxTimer{ id: id, observer: observer, maxRetrans: maxRetrans, rtoMax: rtoMax, } if timer.rtoMax == 0 { timer.rtoMax = defaultRTOMax } timer.timer = time.AfterFunc(math.MaxInt64, timer.timeout) timer.timer.Stop() return &timer } func (t *rtxTimer) calculateNextTimeout() time.Duration { timeout := calculateNextTimeout(t.rto, t.nRtos, t.rtoMax) return time.Duration(timeout) * time.Millisecond } func (t *rtxTimer) timeout() { t.mutex.Lock() if t.pending--; t.pending == 0 && t.state == rtxTimerStarted { if t.nRtos++; t.maxRetrans == 0 || t.nRtos <= t.maxRetrans { t.timer.Reset(t.calculateNextTimeout()) t.pending++ defer t.observer.onRetransmissionTimeout(t.id, t.nRtos) } else { t.state = rtxTimerStopped defer t.observer.onRetransmissionFailure(t.id) } } t.mutex.Unlock() } // start starts the timer. func (t *rtxTimer) start(rto float64) bool { t.mutex.Lock() defer t.mutex.Unlock() // this timer is already closed or already running if t.state != rtxTimerStopped { return false } // Note: rto value is intentionally not capped by RTO.Min to allow // fast timeout for the tests. Non-test code should pass in the // rto generated by rtoManager getRTO() method which caps the // value at RTO.Min or at RTO.Max. t.rto = rto t.nRtos = 0 t.state = rtxTimerStarted t.pending++ t.timer.Reset(t.calculateNextTimeout()) return true } // stop stops the timer. func (t *rtxTimer) stop() { t.mutex.Lock() defer t.mutex.Unlock() if t.state == rtxTimerStarted { if t.timer.Stop() { t.pending-- } t.state = rtxTimerStopped } } // close closes the timer. This is similar to stop() but a subsequent start() call // will fail (the timer is no longer usable). func (t *rtxTimer) close() { t.mutex.Lock() defer t.mutex.Unlock() if t.state == rtxTimerStarted && t.timer.Stop() { t.pending-- } t.state = rtxTimerClosed } // isRunning tests if the timer is running. // Debug purpose only. func (t *rtxTimer) isRunning() bool { t.mutex.Lock() defer t.mutex.Unlock() return t.state == rtxTimerStarted } func calculateNextTimeout(rto float64, nRtos uint, rtoMax float64) float64 { // RFC 4960 sec 6.3.3. Handle T3-rtx Expiration // E2) For the destination address for which the timer expires, set RTO // <- RTO * 2 ("back off the timer"). The maximum value discussed // in rule C7 above (RTO.max) may be used to provide an upper bound // to this doubling operation. 
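// For example (illustrative): with rto=1000 and rtoMax=60000, successive intervals are // 1000, 2000, 4000, 8000, ... msec, clamped at 60000 once the doubled value would exceed rtoMax. 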
if nRtos < 31 { m := 1 << nRtos return math.Min(rto*float64(m), rtoMax) } return rtoMax } sctp-1.9.0/rtx_timer_test.go000066400000000000000000000252151512256410600161040ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "math" "sync/atomic" "testing" "time" "github.com/stretchr/testify/assert" ) func TestRTOManager(t *testing.T) { t.Run("initial values", func(t *testing.T) { m := newRTOManager(0) assert.Equal(t, rtoInitial, m.rto, "should be rtoInitial") assert.Equal(t, rtoInitial, m.getRTO(), "should be rtoInitial") assert.Equal(t, float64(0), m.srtt, "should be 0") assert.Equal(t, float64(0), m.rttvar, "should be 0") }) t.Run("RTO calculation (small RTT)", func(t *testing.T) { var rto float64 m := newRTOManager(0) exp := []int32{ 1800, 1500, 1275, 1106, 1000, // capped at RTO.Min } for i := 0; i < 5; i++ { m.setNewRTT(600) rto = m.getRTO() assert.Equal(t, exp[i], int32(math.Floor(rto)), "should be equal") } }) t.Run("RTO calculation (large RTT)", func(t *testing.T) { var rto float64 m := newRTOManager(0) exp := []int32{ 60000, // capped at RTO.Max 60000, // capped at RTO.Max 60000, // capped at RTO.Max 55312, 48984, } for i := 0; i < 5; i++ { m.setNewRTT(30000) rto = m.getRTO() assert.Equal(t, exp[i], int32(math.Floor(rto)), "should be equal") } }) t.Run("calculateNextTimeout", func(t *testing.T) { var rto float64 rto = calculateNextTimeout(1.0, 0, defaultRTOMax) assert.Equal(t, float64(1), rto, "should match") rto = calculateNextTimeout(1.0, 1, defaultRTOMax) assert.Equal(t, float64(2), rto, "should match") rto = calculateNextTimeout(1.0, 2, defaultRTOMax) assert.Equal(t, float64(4), rto, "should match") rto = calculateNextTimeout(1.0, 30, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") rto = calculateNextTimeout(1.0, 63, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") rto = calculateNextTimeout(1.0, 64, defaultRTOMax) assert.Equal(t, float64(60000), rto, "should match") }) t.Run("calculateNextTimeout w/ RTOMax", func(t *testing.T) { var rto float64 rto = calculateNextTimeout(1.0, 0, 2.0) assert.Equal(t, 1.0, rto, "should match") rto = calculateNextTimeout(1.5, 1, 2.0) assert.Equal(t, 2.0, rto, "should match") rto = calculateNextTimeout(1.0, 10, 2.0) assert.Equal(t, 2.0, rto, "should match") rto = calculateNextTimeout(1.0, 31, 1000.0) assert.Equal(t, 1000.0, rto, "should match") }) t.Run("reset", func(t *testing.T) { m := newRTOManager(0) for i := 0; i < 10; i++ { m.setNewRTT(200) } m.reset() assert.Equal(t, rtoInitial, m.getRTO(), "should be rtoInitial") assert.Equal(t, float64(0), m.srtt, "should be 0") assert.Equal(t, float64(0), m.rttvar, "should be 0") }) } type ( onRTO func(id int, n uint) onRtxFailure func(id int) ) type testTimerObserver struct { onRTO onRTO onRtxFailure onRtxFailure } func (o *testTimerObserver) onRetransmissionTimeout(id int, n uint) { o.onRTO(id, n) } func (o *testTimerObserver) onRetransmissionFailure(id int) { o.onRtxFailure(id) } func TestRtxTimer(t *testing.T) { //nolint:maintidx t.Run("callback interval", func(t *testing.T) { timerID := 0 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) // 30 : 1 (30) // 60 : 2 (90) // 120: 3 (210) // 240: 4 (550) <== expected in 650 msec assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) assert.False(t, rt.isRunning(), "should not be running") // since := time.Now() ok 
:= rt.start(30) assert.True(t, ok, "should be true") assert.True(t, rt.isRunning(), "should be running") time.Sleep(650 * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(4), atomic.LoadInt32(&nCbs), "should be called 4 times") }) t.Run("last start wins", func(t *testing.T) { timerID := 3 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") ok = rt.start(interval * 99) // should ignored assert.False(t, ok, "should be ignored") ok = rt.start(interval * 99) // should ignored assert.False(t, ok, "should be ignored") time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs), "must be called once") }) t.Run("stop right afeter start", func(t *testing.T) { timerID := 3 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") rt.stop() time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(0), atomic.LoadInt32(&nCbs), "no callback should be made") }) t.Run("start, stop then start", func(t *testing.T) { timerID := 1 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) interval := float64(30.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") ok = rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") time.Sleep(time.Duration(interval*1.5) * time.Millisecond) rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") assert.Equal(t, int32(1), atomic.LoadInt32(&nCbs), "must be called once") }) t.Run("start and stop in a tight loop", func(t *testing.T) { timerID := 2 var nCbs int32 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { atomic.AddInt32(&nCbs, 1) t.Log("onRTO() called") assert.Equalf(t, timerID, id, "unexpted timer ID: %d", id) }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) for i := 0; i < 1000; i++ { ok := rt.start(30) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") rt.stop() assert.False(t, rt.isRunning(), "should NOT be running") } assert.Equal(t, int32(0), atomic.LoadInt32(&nCbs), "no callback should be made") }) t.Run("timer should stop after rtx failure", func(t *testing.T) { timerID := 4 var nCbs int32 doneCh := make(chan bool) since := time.Now() var elapsed float64 // in seconds rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, time.Since(since).Seconds()) atomic.AddInt32(&nCbs, 1) }, onRtxFailure: func(id int) { assert.Equal(t, timerID, id, "unexpted timer 
ID: %d", id) elapsed = time.Since(since).Seconds() t.Logf("onRtxFailure: elapsed=%.03f\n", elapsed) doneCh <- true }, }, pathMaxRetrans, 0) // RTO(msec) Total(msec) // 10 10 1st RTO // 20 30 2nd RTO // 40 70 3rd RTO // 80 150 4th RTO // 160 310 5th RTO (== Path.Max.Retrans) // 320 630 Failure interval := float64(10.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") <-doneCh assert.False(t, rt.isRunning(), "should not be running") assert.Equal(t, int32(5), atomic.LoadInt32(&nCbs), "should be called 5 times") assert.True(t, elapsed > 0.600, "must have taken more than 600 msec") assert.True(t, elapsed < 0.700, "must fail in less than 700 msec") }) t.Run("timer should not stop if maxRetrans is 0", func(t *testing.T) { timerID := 4 maxRtos := uint(6) var nCbs int32 doneCh := make(chan bool) since := time.Now() var elapsed float64 // in seconds rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, nRtos uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) elapsed = time.Since(since).Seconds() t.Logf("onRTO: n=%d elapsed=%.03f\n", nRtos, elapsed) atomic.AddInt32(&nCbs, 1) if nRtos == maxRtos { doneCh <- true } }, onRtxFailure: func(_ int) { assert.Fail(t, "timer should not fail") }, }, 0, 0) // RTO(msec) Total(msec) // 10 10 1st RTO // 20 30 2nd RTO // 40 70 3rd RTO // 80 150 4th RTO // 160 310 5th RTO // 320 630 6th RTO => exit test (timer should still be running) interval := float64(10.0) ok := rt.start(interval) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "should be running") <-doneCh assert.True(t, rt.isRunning(), "should still be running") assert.Equal(t, int32(6), atomic.LoadInt32(&nCbs), "should be called 6 times") assert.True(t, elapsed > 0.600, "must have taken more than 600 msec") assert.True(t, elapsed < 0.700, "must fail in less than 700 msec") rt.stop() }) t.Run("stop timer that is not running is noop", func(t *testing.T) { timerID := 5 doneCh := make(chan bool) rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(id int, _ uint) { assert.Equal(t, timerID, id, "unexpted timer ID: %d", id) doneCh <- true }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) for i := 0; i < 10; i++ { rt.stop() } ok := rt.start(20) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "must be running") <-doneCh rt.stop() assert.False(t, rt.isRunning(), "must be false") }) t.Run("closed timer won't start", func(t *testing.T) { var rtoCount int timerID := 6 rt := newRTXTimer(timerID, &testTimerObserver{ onRTO: func(_ int, _ uint) { rtoCount++ }, onRtxFailure: func(_ int) {}, }, pathMaxRetrans, 0) ok := rt.start(20) assert.True(t, ok, "should be accepted") assert.True(t, rt.isRunning(), "must be running") rt.close() assert.False(t, rt.isRunning(), "must be false") ok = rt.start(20) assert.False(t, ok, "should not start") assert.False(t, rt.isRunning(), "must not be running") time.Sleep(100 * time.Millisecond) assert.Equal(t, 0, rtoCount, "RTO should not occur") }) } sctp-1.9.0/sctp.go000066400000000000000000000002341512256410600137730ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT // Package sctp implements the SCTP spec package sctp sctp-1.9.0/stream.go000066400000000000000000000330631512256410600143230ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "errors" "fmt" "io" "os" "sync" "sync/atomic" "time" 
"github.com/pion/logging" "github.com/pion/transport/v3/deadline" ) const ( // ReliabilityTypeReliable is used for reliable transmission. ReliabilityTypeReliable byte = 0 // ReliabilityTypeRexmit is used for partial reliability by retransmission count. ReliabilityTypeRexmit byte = 1 // ReliabilityTypeTimed is used for partial reliability by retransmission duration. ReliabilityTypeTimed byte = 2 ) // StreamState is an enum for SCTP Stream state field // This field identifies the state of stream. type StreamState int // StreamState enums. const ( StreamStateOpen StreamState = iota // Stream object starts with StreamStateOpen StreamStateClosing // Outgoing stream is being reset StreamStateClosed // Stream has been closed ) func (ss StreamState) String() string { switch ss { case StreamStateOpen: return "open" case StreamStateClosing: return "closing" case StreamStateClosed: return "closed" } return "unknown" } // SCTP stream errors. var ( ErrOutboundPacketTooLarge = errors.New("outbound packet larger than maximum message size") ErrStreamClosed = errors.New("stream closed") ErrReadDeadlineExceeded = fmt.Errorf("read deadline exceeded: %w", os.ErrDeadlineExceeded) ) // Stream represents an SCTP stream. type Stream struct { association *Association lock sync.RWMutex streamIdentifier uint16 defaultPayloadType PayloadProtocolIdentifier reassemblyQueue *reassemblyQueue sequenceNumber uint16 readNotifier *sync.Cond readErr error readTimeoutCancel chan struct{} writeDeadline *deadline.Deadline writeLock sync.Mutex unordered bool reliabilityType byte reliabilityValue uint32 bufferedAmount uint64 bufferedAmountLow uint64 onBufferedAmountLow func() state StreamState log logging.LeveledLogger name string } // StreamIdentifier returns the Stream identifier associated to the stream. func (s *Stream) StreamIdentifier() uint16 { s.lock.RLock() defer s.lock.RUnlock() return s.streamIdentifier } // SetDefaultPayloadType sets the default payload type used by Write. func (s *Stream) SetDefaultPayloadType(defaultPayloadType PayloadProtocolIdentifier) { atomic.StoreUint32((*uint32)(&s.defaultPayloadType), uint32(defaultPayloadType)) } // SetReliabilityParams sets reliability parameters for this stream. func (s *Stream) SetReliabilityParams(unordered bool, relType byte, relVal uint32) { s.lock.Lock() defer s.lock.Unlock() s.setReliabilityParams(unordered, relType, relVal) } // setReliabilityParams sets reliability parameters for this stream. // The caller should hold the lock. func (s *Stream) setReliabilityParams(unordered bool, relType byte, relVal uint32) { s.log.Debugf("[%s] reliability params: ordered=%v type=%d value=%d", s.name, !unordered, relType, relVal) s.unordered = unordered s.reliabilityType = relType s.reliabilityValue = relVal } // Read reads a packet of len(p) bytes, dropping the Payload Protocol Identifier. // Returns EOF when the stream is reset or an error if the stream is closed // otherwise. func (s *Stream) Read(p []byte) (int, error) { n, _, err := s.ReadSCTP(p) return n, err } // ReadSCTP reads a packet of len(payload) bytes and returns the associated Payload // Protocol Identifier. // Returns EOF when the stream is reset or an error if the stream is closed // otherwise. 
func (s *Stream) ReadSCTP(payload []byte) (int, PayloadProtocolIdentifier, error) { s.lock.Lock() defer s.lock.Unlock() defer func() { // close readTimeoutCancel if the current read timeout routine is no longer effective if s.readTimeoutCancel != nil && s.readErr != nil { close(s.readTimeoutCancel) s.readTimeoutCancel = nil } }() for { n, ppi, err := s.reassemblyQueue.read(payload) if err == nil || errors.Is(err, io.ErrShortBuffer) { return n, ppi, err } if s.readErr != nil { return 0, PayloadProtocolIdentifier(0), s.readErr } s.readNotifier.Wait() } } // SetReadDeadline sets the read deadline in an identical way to net.Conn. func (s *Stream) SetReadDeadline(deadline time.Time) error { s.lock.Lock() defer s.lock.Unlock() if s.readTimeoutCancel != nil { close(s.readTimeoutCancel) s.readTimeoutCancel = nil } if s.readErr != nil { if !errors.Is(s.readErr, ErrReadDeadlineExceeded) { return nil } s.readErr = nil } if !deadline.IsZero() { s.readTimeoutCancel = make(chan struct{}) go func(readTimeoutCancel chan struct{}) { t := time.NewTimer(time.Until(deadline)) select { case <-readTimeoutCancel: t.Stop() return case <-t.C: select { case <-readTimeoutCancel: return default: } s.lock.Lock() if s.readErr == nil { s.readErr = ErrReadDeadlineExceeded } s.readTimeoutCancel = nil s.lock.Unlock() s.readNotifier.Signal() } }(s.readTimeoutCancel) } return nil } func (s *Stream) handleData(pd *chunkPayloadData) { s.lock.Lock() defer s.lock.Unlock() var readable bool if s.reassemblyQueue.push(pd) { readable = s.reassemblyQueue.isReadable() s.log.Debugf("[%s] reassemblyQueue readable=%v", s.name, readable) if readable { s.log.Debugf("[%s] readNotifier.signal()", s.name) s.readNotifier.Signal() s.log.Debugf("[%s] readNotifier.signal() done", s.name) } } } func (s *Stream) handleForwardTSNForOrdered(ssn uint16) { var readable bool func() { s.lock.Lock() defer s.lock.Unlock() if s.unordered { return // unordered chunks are handled by handleForwardUnordered method } // Remove all chunks older than or equal to the new TSN from // the reassemblyQueue. s.reassemblyQueue.forwardTSNForOrdered(ssn) readable = s.reassemblyQueue.isReadable() }() // Notify the reader asynchronously if there's a data chunk to read. if readable { s.readNotifier.Signal() } } func (s *Stream) handleForwardTSNForUnordered(newCumulativeTSN uint32) { var readable bool func() { s.lock.Lock() defer s.lock.Unlock() if !s.unordered { return // ordered chunks are handled by handleForwardTSNOrdered method } // Remove all chunks older than or equal to the new TSN from // the reassemblyQueue. s.reassemblyQueue.forwardTSNForUnordered(newCumulativeTSN) readable = s.reassemblyQueue.isReadable() }() // Notify the reader asynchronously if there's a data chunk to read. if readable { s.readNotifier.Signal() } } // Write writes len(payload) bytes from payload with the default Payload Protocol Identifier. func (s *Stream) Write(payload []byte) (n int, err error) { ppi := PayloadProtocolIdentifier(atomic.LoadUint32((*uint32)(&s.defaultPayloadType))) return s.WriteSCTP(payload, ppi) } // WriteSCTP writes len(payload) bytes from payload to the DTLS connection. 
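// Illustrative usage: n, err := stream.WriteSCTP(msg, PayloadTypeWebRTCBinary). // Messages larger than the association's MaxMessageSize() are rejected with ErrOutboundPacketTooLarge. 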
func (s *Stream) WriteSCTP(payload []byte, ppi PayloadProtocolIdentifier) (int, error) { maxMessageSize := s.association.MaxMessageSize() if len(payload) > int(maxMessageSize) { return 0, fmt.Errorf("%w: %v", ErrOutboundPacketTooLarge, maxMessageSize) } if s.State() != StreamStateOpen { return 0, ErrStreamClosed } // The send could fail if the association is blocked for writing (timeout); that would leave a hole // in the stream sequence number space, so we need to lock the write to avoid concurrent sends and to decrement // the sequence number in case of failure if s.association.isBlockWrite() { s.writeLock.Lock() } chunks, unordered := s.packetize(payload, ppi) n := len(payload) err := s.association.sendPayloadData(s.writeDeadline, chunks) if err != nil { s.lock.Lock() s.bufferedAmount -= uint64(n) if !unordered { s.sequenceNumber-- } s.lock.Unlock() n = 0 } if s.association.isBlockWrite() { s.writeLock.Unlock() } return n, err } // SetWriteDeadline sets the write deadline in an identical way to net.Conn; // it only works for blocking writes. func (s *Stream) SetWriteDeadline(deadline time.Time) error { s.writeDeadline.Set(deadline) return nil } // SetDeadline sets the read and write deadlines in an identical way to net.Conn. func (s *Stream) SetDeadline(t time.Time) error { if err := s.SetReadDeadline(t); err != nil { return err } return s.SetWriteDeadline(t) } func (s *Stream) packetize(raw []byte, ppi PayloadProtocolIdentifier) ([]*chunkPayloadData, bool) { s.lock.Lock() defer s.lock.Unlock() offset := uint32(0) remaining := uint32(len(raw)) //nolint:gosec // G115 // From draft-ietf-rtcweb-data-protocol-09, section 6: // All Data Channel Establishment Protocol messages MUST be sent using // ordered delivery and reliable transmission. unordered := ppi != PayloadTypeWebRTCDCEP && s.unordered var chunks []*chunkPayloadData var head *chunkPayloadData for remaining != 0 { fragmentSize := min32(s.association.maxPayloadSize, remaining) // Copy the userdata since we'll have to store it until acked // and the caller may re-use the buffer in the meantime userData := make([]byte, fragmentSize) copy(userData, raw[offset:offset+fragmentSize]) chunk := &chunkPayloadData{ streamIdentifier: s.streamIdentifier, userData: userData, unordered: unordered, beginningFragment: offset == 0, endingFragment: remaining-fragmentSize == 0, immediateSack: false, payloadType: ppi, streamSequenceNumber: s.sequenceNumber, head: head, } if head == nil { head = chunk } chunks = append(chunks, chunk) remaining -= fragmentSize offset += fragmentSize } // RFC 4960 Sec 6.6 // Note: When transmitting ordered and unordered data, an endpoint does // not increment its Stream Sequence Number when transmitting a DATA // chunk with U flag set to 1. if !unordered { s.sequenceNumber++ } s.bufferedAmount += uint64(len(raw)) s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) return chunks, unordered } // Close closes the write-direction of the stream. // Future calls to Write are not permitted after calling Close. 
func (s *Stream) Close() error { if sid, resetOutbound := func() (uint16, bool) { s.lock.Lock() defer s.lock.Unlock() s.log.Debugf("[%s] Close: state=%s", s.name, s.state.String()) if s.state == StreamStateOpen { if s.readErr == nil { s.state = StreamStateClosing } else { s.state = StreamStateClosed } s.log.Debugf("[%s] state change: open => %s", s.name, s.state.String()) return s.streamIdentifier, true } return s.streamIdentifier, false }(); resetOutbound { // Reset the outgoing stream // https://tools.ietf.org/html/rfc6525 return s.association.sendResetRequest(sid) } return nil } // BufferedAmount returns the number of bytes of data currently queued to be sent over this stream. func (s *Stream) BufferedAmount() uint64 { s.lock.RLock() defer s.lock.RUnlock() return s.bufferedAmount } // BufferedAmountLowThreshold returns the number of bytes of buffered outgoing data that is // considered "low." Defaults to 0. func (s *Stream) BufferedAmountLowThreshold() uint64 { s.lock.RLock() defer s.lock.RUnlock() return s.bufferedAmountLow } // SetBufferedAmountLowThreshold is used to update the threshold. // See BufferedAmountLowThreshold(). func (s *Stream) SetBufferedAmountLowThreshold(th uint64) { s.lock.Lock() defer s.lock.Unlock() s.bufferedAmountLow = th } // OnBufferedAmountLow sets the callback handler which is called when the number of // bytes of outgoing data buffered drops to or below the threshold. func (s *Stream) OnBufferedAmountLow(f func()) { s.lock.Lock() defer s.lock.Unlock() s.onBufferedAmountLow = f } // onBufferReleased is called by the association's readLoop (go-)routine to notify this stream // that the specified amount of outgoing data has been delivered to the peer. func (s *Stream) onBufferReleased(nBytesReleased int) { if nBytesReleased <= 0 { return } s.lock.Lock() fromAmount := s.bufferedAmount if s.bufferedAmount < uint64(nBytesReleased) { s.bufferedAmount = 0 s.log.Errorf("[%s] released buffer size %d should be <= %d", s.name, nBytesReleased, fromAmount) } else { s.bufferedAmount -= uint64(nBytesReleased) } s.log.Tracef("[%s] bufferedAmount = %d", s.name, s.bufferedAmount) if s.onBufferedAmountLow != nil && fromAmount > s.bufferedAmountLow && s.bufferedAmount <= s.bufferedAmountLow { f := s.onBufferedAmountLow s.lock.Unlock() f() return } s.lock.Unlock() } func (s *Stream) getNumBytesInReassemblyQueue() int { // No lock is required as it reads the size with atomic load function. return s.reassemblyQueue.getNumBytes() } func (s *Stream) onInboundStreamReset() { s.lock.Lock() defer s.lock.Unlock() s.log.Debugf("[%s] onInboundStreamReset: state=%s", s.name, s.state.String()) // No more inbound data to read. Unblock the read with io.EOF. // This should cause the DCEP layer (datachannel package) to call Close(), which // will reset the outgoing stream as well. // See RFC 8831 section 6.7: // if one side decides to close the data channel, it resets the corresponding // outgoing stream. When the peer sees that an incoming stream was // reset, it also resets its corresponding outgoing stream. Once this // is completed, the data channel is closed. s.readErr = io.EOF s.readNotifier.Broadcast() if s.state == StreamStateClosing { s.log.Debugf("[%s] state change: closing => closed", s.name) s.state = StreamStateClosed } } // State returns the stream state. 
func (s *Stream) State() StreamState { s.lock.RLock() defer s.lock.RUnlock() return s.state } sctp-1.9.0/stream_test.go000066400000000000000000000045721512256410600153650ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/pion/logging" "github.com/stretchr/testify/assert" ) func TestSessionBufferedAmount(t *testing.T) { t.Run("bufferedAmount", func(t *testing.T) { s := &Stream{ log: logging.NewDefaultLoggerFactory().NewLogger("sctp-test"), } assert.Equal(t, uint64(0), s.BufferedAmount()) assert.Equal(t, uint64(0), s.BufferedAmountLowThreshold()) s.bufferedAmount = 8192 s.SetBufferedAmountLowThreshold(2048) assert.Equal(t, uint64(8192), s.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, uint64(2048), s.BufferedAmountLowThreshold(), "unexpected threshold") }) t.Run("OnBufferedAmountLow", func(t *testing.T) { stream := &Stream{ log: logging.NewDefaultLoggerFactory().NewLogger("sctp-test"), } stream.bufferedAmount = 4096 stream.SetBufferedAmountLowThreshold(2048) nCbs := 0 stream.OnBufferedAmountLow(func() { nCbs++ }) // Negative value should be ignored (by design) stream.onBufferReleased(-32) // bufferedAmount = 3072 assert.Equal(t, uint64(4096), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to above, no callback stream.onBufferReleased(1024) // bufferedAmount = 3072 assert.Equal(t, uint64(3072), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 0, nCbs, "callback count mismatch") // Above to equal, callback should be made stream.onBufferReleased(1024) // bufferedAmount = 2048 assert.Equal(t, uint64(2048), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Eaual to below, no callback stream.onBufferReleased(1024) // bufferedAmount = 1024 assert.Equal(t, uint64(1024), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Blow to below, no callback stream.onBufferReleased(1024) // bufferedAmount = 0 assert.Equal(t, uint64(0), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") // Capped at 0, no callback stream.onBufferReleased(1024) // bufferedAmount = 0 assert.Equal(t, uint64(0), stream.BufferedAmount(), "unexpected bufferedAmount") assert.Equal(t, 1, nCbs, "callback count mismatch") }) } sctp-1.9.0/util.go000066400000000000000000000023531512256410600140030ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp const ( paddingMultiple = 4 ) func getPadding(l int) int { return (paddingMultiple - (l % paddingMultiple)) % paddingMultiple } func padByte(in []byte, cnt int) []byte { if cnt < 0 { cnt = 0 } padding := make([]byte, cnt) return append(in, padding...) } // Serial Number Arithmetic (RFC 1982). 
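// Example (illustrative): sna32LT(0xFFFFFFFF, 0x00000001) is true because the distance after // wrap-around is small, while sna32LT(0x00000000, 0x80000000) is false because the forward // distance is not less than 2^31. 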
func sna32LT(i1, i2 uint32) bool { return (i1 < i2 && i2-i1 < 1<<31) || (i1 > i2 && i1-i2 > 1<<31) } func sna32LTE(i1, i2 uint32) bool { return i1 == i2 || sna32LT(i1, i2) } func sna32GT(i1, i2 uint32) bool { return (i1 < i2 && (i2-i1) >= 1<<31) || (i1 > i2 && (i1-i2) <= 1<<31) } func sna32GTE(i1, i2 uint32) bool { return i1 == i2 || sna32GT(i1, i2) } func sna32EQ(i1, i2 uint32) bool { return i1 == i2 } func sna16LT(i1, i2 uint16) bool { return (i1 < i2 && (i2-i1) < 1<<15) || (i1 > i2 && (i1-i2) > 1<<15) } func sna16LTE(i1, i2 uint16) bool { return i1 == i2 || sna16LT(i1, i2) } func sna16GT(i1, i2 uint16) bool { return (i1 < i2 && (i2-i1) >= 1<<15) || (i1 > i2 && (i1-i2) <= 1<<15) } func sna16GTE(i1, i2 uint16) bool { return i1 == i2 || sna16GT(i1, i2) } func sna16EQ(i1, i2 uint16) bool { return i1 == i2 } sctp-1.9.0/util_test.go000066400000000000000000000107131512256410600150410ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "github.com/stretchr/testify/assert" ) func TestPadByte_Success(t *testing.T) { tt := []struct { value []byte padLen int expected []byte }{ {[]byte{0x1, 0x2}, 0, []byte{0x1, 0x2}}, {[]byte{0x1, 0x2}, 1, []byte{0x1, 0x2, 0x0}}, {[]byte{0x1, 0x2}, 2, []byte{0x1, 0x2, 0x0, 0x0}}, {[]byte{0x1, 0x2}, 3, []byte{0x1, 0x2, 0x0, 0x0, 0x0}}, {[]byte{0x1, 0x2}, -1, []byte{0x1, 0x2}}, } for i, tc := range tt { actual := padByte(tc.value, tc.padLen) assert.Equalf(t, tc.expected, actual, "test %d not equal", i) } } func TestSerialNumberArithmetic(t *testing.T) { const div int = 16 t.Run("32-bit", func(t *testing.T) { // nolint:dupl const serialBits uint32 = 32 const interval uint32 = uint32((uint64(1) << uint64(serialBits)) / uint64(div)) const maxForwardDistance uint32 = 1<<(serialBits-1) - 1 const maxBackwardDistance uint32 = 1 << (serialBits - 1) for i := uint32(0); i < uint32(div); i++ { s1 := i * interval s2f := s1 + maxForwardDistance s2b := s1 + maxBackwardDistance assert.Truef(t, sna32LT(s1, s2f), "s1 < s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.Falsef(t, sna32LT(s1, s2b), "s1 < s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.Falsef(t, sna32GT(s1, s2f), "s1 > s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.Truef(t, sna32GT(s1, s2b), "s1 > s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.Truef(t, sna32LTE(s1, s2f), "s1 <= s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.Falsef(t, sna32LTE(s1, s2b), "s1 <= s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.Falsef(t, sna32GTE(s1, s2f), "s1 >= s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.Truef(t, sna32GTE(s1, s2b), "s1 >= s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.Truef(t, sna32EQ(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna32EQ(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.Falsef(t, sna32EQ(s1, s1+1), "s1 == s1+1 should be false: s1=0x%x s1+1=0x%x", s1, s1+1) assert.Falsef(t, sna32EQ(s1, s1-1), "s1 == s1-1 hould be false: s1=0x%x s1-1=0x%x", s1, s1-1) assert.Truef(t, sna32LTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna32LTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.Truef(t, sna32GTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna32GTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) } }) t.Run("16-bit", func(t *testing.T) { // nolint:dupl const serialBits uint16 = 16 const interval uint16 = 
uint16((uint64(1) << uint64(serialBits)) / uint64(div)) const maxForwardDistance uint16 = 1<<(serialBits-1) - 1 const maxBackwardDistance uint16 = 1 << (serialBits - 1) for i := uint16(0); i < uint16(div); i++ { s1 := i * interval s2f := s1 + maxForwardDistance s2b := s1 + maxBackwardDistance assert.Truef(t, sna16LT(s1, s2f), "s1 < s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.Falsef(t, sna16LT(s1, s2b), "s1 < s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.Falsef(t, sna16GT(s1, s2f), "s1 > s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.Truef(t, sna16GT(s1, s2b), "s1 > s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.Truef(t, sna16LTE(s1, s2f), "s1 <= s2 should be true: s1=0x%x s2=0x%x", s1, s2f) assert.Falsef(t, sna16LTE(s1, s2b), "s1 <= s2 should be false: s1=0x%x s2=0x%x", s1, s2b) assert.Falsef(t, sna16GTE(s1, s2f), "s1 >= s2 should be fales: s1=0x%x s2=0x%x", s1, s2f) assert.Truef(t, sna16GTE(s1, s2b), "s1 >= s2 should be true: s1=0x%x s2=0x%x", s1, s2b) assert.Truef(t, sna16EQ(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna16EQ(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.Falsef(t, sna16EQ(s1, s1+1), "s1 == s1+1 should be false: s1=0x%x s1+1=0x%x", s1, s1+1) assert.Falsef(t, sna16EQ(s1, s1-1), "s1 == s1-1 hould be false: s1=0x%x s1-1=0x%x", s1, s1-1) assert.Truef(t, sna16LTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna16LTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) assert.Truef(t, sna16GTE(s1, s1), "s1 == s1 should be true: s1=0x%x s2=0x%x", s1, s1) assert.Truef(t, sna16GTE(s2b, s2b), "s2 == s2 should be true: s2=0x%x s2=0x%x", s2b, s2b) } }) } sctp-1.9.0/vnet_test.go000066400000000000000000000557771512256410600150630ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "bytes" crand "crypto/rand" "fmt" "net" "reflect" "sync/atomic" "testing" "time" "github.com/pion/logging" "github.com/pion/transport/v3/test" "github.com/pion/transport/v3/vnet" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" ) type vNetEnvConfig struct { minDelay time.Duration loggerFactory logging.LoggerFactory log logging.LeveledLogger } type vNetEnv struct { wan *vnet.Router net0 *vnet.Net net1 *vnet.Net numToDropData int numToDropReconfig int numToDropCookieEcho int numToDropCookieAck int } func (venv *vNetEnv) dropNextDataChunk(numToDrop int) { venv.numToDropData = numToDrop } func (venv *vNetEnv) dropNextReconfigChunk(numToDrop int) { // nolint:unused venv.numToDropReconfig = numToDrop } func (venv *vNetEnv) dropNextCookieEchoChunk(numToDrop int) { venv.numToDropCookieEcho = numToDrop } func (venv *vNetEnv) dropNextCookieAckChunk(numToDrop int) { venv.numToDropCookieAck = numToDrop } func buildVNetEnv(t *testing.T, cfg *vNetEnvConfig) (*vNetEnv, error) { //nolint:cyclop t.Helper() log := cfg.log var venv *vNetEnv serverIP := "1.1.1.1" clientIP := "2.2.2.2" wan, err := vnet.NewRouter(&vnet.RouterConfig{ CIDR: "0.0.0.0/0", MinDelay: cfg.minDelay, MaxJitter: 0 * time.Millisecond, LoggerFactory: cfg.loggerFactory, }) if err != nil { return nil, err } tsnAutoLockOnFilter := func() func(vnet.Chunk) bool { var lockedOnTSN bool var tsn uint32 return func(c vnet.Chunk) bool { var toDrop bool p := &packet{} assert.NoError(t, p.unmarshal(true, c.UserData())) loop: for i := 0; i < len(p.chunks); i++ { switch chunk := p.chunks[i].(type) { case *chunkPayloadData: if 
venv.numToDropData > 0 { if !lockedOnTSN { tsn = chunk.tsn lockedOnTSN = true log.Infof("Chunk filter: lock on TSN %d", tsn) } if chunk.tsn == tsn { toDrop = true venv.numToDropData-- log.Infof("Chunk filter: drop TSN %d", tsn) break loop } } case *chunkReconfig: if venv.numToDropReconfig > 0 { toDrop = true venv.numToDropReconfig-- log.Infof("Chunk filter: drop RECONFIG %s", chunk.String()) break loop } case *chunkCookieEcho: if venv.numToDropCookieEcho > 0 { toDrop = true venv.numToDropCookieEcho-- log.Infof("Chunk filter: drop %s", chunk.String()) break loop } case *chunkCookieAck: if venv.numToDropCookieAck > 0 { toDrop = true venv.numToDropCookieAck-- log.Infof("Chunk filter: drop %s", chunk.String()) break loop } } } return !toDrop } } wan.AddChunkFilter(tsnAutoLockOnFilter()) net0, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{serverIP}, }) if err != nil { return nil, err } err = wan.AddNet(net0) if err != nil { return nil, err } net1, err := vnet.NewNet(&vnet.NetConfig{ StaticIPs: []string{clientIP}, }) if err != nil { return nil, err } err = wan.AddNet(net1) if err != nil { return nil, err } err = wan.Start() if err != nil { return nil, err } venv = &vNetEnv{ wan: wan, net0: net0, net1: net1, } return venv, nil } func testRwndFull(t *testing.T, unordered bool) { //nolint:cyclop t.Helper() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(t, &vNetEnvConfig{ minDelay: 200 * time.Millisecond, loggerFactory: loggerFactory, log: log, }) require.NoError(t, err, "should succeed") require.NotNil(t, venv, "should not be nil") defer venv.wan.Stop() // nolint:errcheck serverHandshakeDone := make(chan struct{}) clientHandshakeDone := make(chan struct{}) serverStreamReady := make(chan struct{}) clientStreamReady := make(chan struct{}) clientStartWrite := make(chan struct{}) serverRecvBufFull := make(chan struct{}) serverStartRead := make(chan struct{}) serverReadAll := make(chan struct{}) clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) shutDownClient := make(chan struct{}) shutDownServer := make(chan struct{}) maxReceiveBufferSize := uint32(64 * 1024) msgSize := int(float32(maxReceiveBufferSize)/2) + int(initialMTU) msg := make([]byte, msgSize) n, err := crand.Read(msg) require.NoError(t, err, "failed to read random bytes") require.Equal(t, len(msg), n, "short random read") go func() { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association assoc, err := Server(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") close(serverHandshakeDone) stream, err := assoc.AcceptStream() if !assert.NoError(t, err, "should succeed") { return } defer stream.Close() // nolint:errcheck // Expunge the first HELLO packet buf := make([]byte, 64*1024) n, err := stream.Read(buf) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, "HELLO", string(buf[:n]), "should match") stream.SetReliabilityParams(unordered, ReliabilityTypeReliable, 0) log.Info("server stream ready") close(serverStreamReady) for { assoc.lock.RLock() 
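// getMyReceiverWindowCredit() reaches 0 once the receive buffer is full; keep polling until then. 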
rbufSize := assoc.getMyReceiverWindowCredit() log.Infof("rbufSize = %d", rbufSize) assoc.lock.RUnlock() if rbufSize == 0 { break } time.Sleep(50 * time.Millisecond) } close(serverRecvBufFull) <-serverStartRead for i := 0; i < 2; i++ { n, err = stream.Read(buf) if !assert.NoError(t, err, "should succeed") { return } if !assert.NoError(t, err, "should succeed") { return } log.Infof("server read %d bytes", n) assert.Truef(t, reflect.DeepEqual(msg, buf[:n]), "msg %d should match", i) } close(serverReadAll) <-shutDownServer log.Info("server closing") }() go func() { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return } // client association assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("client handshake complete") close(clientHandshakeDone) stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if !assert.NoError(t, err, "should succeed") { return } defer stream.Close() // nolint:errcheck // Send a message to let server side stream to open _, err = stream.Write([]byte("HELLO")) if !assert.NoError(t, err, "should succeed") { return } stream.SetReliabilityParams(unordered, ReliabilityTypeReliable, 0) log.Info("client stream ready") close(clientStreamReady) <-clientStartWrite // Set the cwnd and rwnd to the size large enough to send the large messages // right away assoc.lock.Lock() assoc.cwnd = 2 * maxReceiveBufferSize assoc.rwnd = 2 * maxReceiveBufferSize assoc.lock.Unlock() // Send two large messages so that the second one will // cause receiver side buffer full for i := 0; i < 2; i++ { _, err = stream.Write(msg) if !assert.NoError(t, err, "should succeed") { return } } <-shutDownClient log.Info("client closing") }() // // Scenario // // wait until both handshake complete <-clientHandshakeDone <-serverHandshakeDone log.Info("handshake complete") // wait until both establish a stream <-clientStreamReady <-serverStreamReady log.Info("stream ready") // drop next 1 DATA chunk sent to the server venv.dropNextDataChunk(1) // let client begin writing log.Info("client start writing") close(clientStartWrite) // wait until the server's receive buffer becomes full <-serverRecvBufFull // let server start reading close(serverStartRead) // wait until the server receives all data log.Info("let server start reading") <-serverReadAll log.Info("server received all data") close(shutDownClient) <-clientShutDown close(shutDownServer) <-serverShutDown log.Info("all done") } func TestRwndFull(t *testing.T) { t.Run("Ordered", func(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 10) defer lim.Stop() testRwndFull(t, false) }) t.Run("Unordered", func(t *testing.T) { // Limit runtime in case of deadlocks lim := test.TimeOut(time.Second * 10) defer lim.Stop() testRwndFull(t, true) }) } func TestStreamClose(t *testing.T) { //nolint:cyclop loopBackTest := func(t *testing.T, dropReconfigChunk bool) { t.Helper() lim := test.TimeOut(time.Second * 10) defer lim.Stop() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(t, &vNetEnvConfig{ loggerFactory: loggerFactory, log: log, }) require.NoError(t, err, 
"should succeed") require.NotNil(t, venv, "should not be nil") defer venv.wan.Stop() // nolint:errcheck clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) const numMessages = 10 const messageSize = 1024 var messages [][]byte var numServerReceived int var numClientReceived int for i := 0; i < numMessages; i++ { bytes := make([]byte, messageSize) messages = append(messages, bytes) } go func() { defer close(serverShutDown) // connected UDP conn for server conn, innerErr := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, innerErr, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association assoc, innerErr := Server(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if !assert.NoError(t, innerErr, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") stream, innerErr := assoc.AcceptStream() if !assert.NoError(t, innerErr, "should succeed") { return } assert.Equal(t, StreamStateOpen, stream.State()) buf := make([]byte, 1500) for { n, errRead := stream.Read(buf) if errRead != nil { log.Infof("server: Read returned %v", errRead) _ = stream.Close() // nolint:errcheck assert.Equal(t, StreamStateClosed, stream.State()) break } log.Infof("server: received %d bytes (%d)", n, numServerReceived) assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numServerReceived]), "should receive HELLO") _, err2 := stream.Write(buf[:n]) assert.NoError(t, err2, "should succeed") numServerReceived++ } // don't close association until the client's stream routine is complete <-clientShutDown }() // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // client association assoc, err := Client(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("client handshake complete") stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if !assert.NoError(t, err, "should succeed") { return } assert.Equal(t, StreamStateOpen, stream.State()) stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0) // begin client read-loop buf := make([]byte, 1500) go func() { defer close(clientShutDown) for { n, err2 := stream.Read(buf) if err2 != nil { log.Infof("client: Read returned %v", err2) assert.Equal(t, StreamStateClosed, stream.State()) break } log.Infof("client: received %d bytes (%d)", n, numClientReceived) assert.Equal(t, 0, bytes.Compare(buf[:n], messages[numClientReceived]), "should receive HELLO") numClientReceived++ } }() // Send messages to the server for i := 0; i < numMessages; i++ { _, err = stream.Write(messages[i]) assert.NoError(t, err, "should succeed") } if dropReconfigChunk { venv.dropNextReconfigChunk(1) } // Immediately close the stream err = stream.Close() assert.NoError(t, err, "should succeed") assert.Equal(t, StreamStateClosing, stream.State()) log.Info("client wait for exit reading..") <-clientShutDown assert.Equal(t, numMessages, numServerReceived, "all messages should be received") assert.Equal(t, numMessages, numClientReceived, "all messages should be received") _, err = stream.Write([]byte{1}) 
assert.Equal(t, err, ErrStreamClosed, "after closed should not allow write") // Check if RECONFIG was actually dropped assert.Equal(t, 0, venv.numToDropReconfig, "should be zero") // Sleep enough time for reconfig response to come back time.Sleep(100 * time.Millisecond) // Verify there's no more pending reconfig assoc.lock.RLock() pendingReconfigs := len(assoc.reconfigs) assoc.lock.RUnlock() assert.Equal(t, 0, pendingReconfigs, "should be zero") } t.Run("without dropping Reconfig", func(t *testing.T) { loopBackTest(t, false) }) t.Run("with dropping Reconfig", func(t *testing.T) { loopBackTest(t, true) }) } // this test case reproduces the issue mentioned in // https://github.com/pion/webrtc/issues/1270#issuecomment-653953743 // and confirmes the fix. // To reproduce the case mentioned above: // * Use simultaneous-open (SCTP) // * Drop both of the first COOKIE-ECHO and COOKIE-ACK. func TestCookieEchoRetransmission(t *testing.T) { lim := test.TimeOut(time.Second * 10) defer lim.Stop() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test") venv, err := buildVNetEnv(t, &vNetEnvConfig{ minDelay: 200 * time.Millisecond, loggerFactory: loggerFactory, log: log, }) require.NoError(t, err, "should succeed") require.NotNil(t, venv, "should not be nil") defer venv.wan.Stop() // nolint:errcheck // To cause the cookie echo retransmission, both COOKIE-ECHO // and COOKIE-ACK chunks need to be dropped at the same time. venv.dropNextCookieEchoChunk(1) venv.dropNextCookieAckChunk(1) serverHandshakeDone := make(chan struct{}) clientHandshakeDone := make(chan struct{}) waitAllHandshakeDone := make(chan struct{}) clientShutDown := make(chan struct{}) serverShutDown := make(chan struct{}) maxReceiveBufferSize := uint32(64 * 1024) // Go routine for Server go func() { defer close(serverShutDown) // connected UDP conn for server conn, err := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return } defer conn.Close() // nolint:errcheck // server association // using Client for simultaneous open assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("server handshake complete") close(serverHandshakeDone) <-waitAllHandshakeDone }() // Go routine for Client go func() { defer close(clientShutDown) // connected UDP conn for client conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if !assert.NoError(t, err, "should succeed") { return } // client association assoc, err := Client(Config{ NetConn: conn, MaxReceiveBufferSize: maxReceiveBufferSize, LoggerFactory: loggerFactory, }) if !assert.NoError(t, err, "should succeed") { return } defer assoc.Close() // nolint:errcheck log.Info("client handshake complete") close(clientHandshakeDone) <-waitAllHandshakeDone }() // // Scenario // // wait until both handshake complete <-clientHandshakeDone <-serverHandshakeDone close(waitAllHandshakeDone) log.Info("handshake complete") <-clientShutDown <-serverShutDown log.Info("all done") } // Simulate an RTT switch (high -> low) by delaying early DATA, then disabling delay so // later DATA arrives before earlier DATA. 
Under a RACK regression, rackMinRTT would never increases, // causing reoWnd to be too small and marking packets sent at high RTT as spuriously lost. func TestRACK_RTTSwitch_Reordering_NoDrop(t *testing.T) { //nolint:gocyclo,cyclop,maintidx lim := test.TimeOut(15 * time.Second) defer lim.Stop() loggerFactory := logging.NewDefaultLoggerFactory() log := loggerFactory.NewLogger("test-rack-rtt-switch") venv, err := buildVNetEnv(t, &vNetEnvConfig{ minDelay: 0, loggerFactory: loggerFactory, log: log, }) require.NoError(t, err) require.NotNil(t, venv) defer venv.wan.Stop() // nolint:errcheck var delayOn atomic.Value delayOn.Store(true) venv.wan.AddChunkFilter(func(c vnet.Chunk) bool { p := &packet{} if err := p.unmarshal(true, c.UserData()); err != nil { return true } v := delayOn.Load() if val, ok := v.(bool); ok && !val { return true } for i := 0; i < len(p.chunks); i++ { if _, ok := p.chunks[i].(*chunkPayloadData); ok { time.Sleep(100 * time.Millisecond) break } } return true }) const ( numMessages = 40 messageSize = 256 ) makeMessages := func() [][]byte { msgs := make([][]byte, numMessages) for i := 0; i < numMessages; i++ { b := bytes.Repeat([]byte{byte(i % 251)}, messageSize) msgs[i] = b } return msgs } type statsResult struct { fr uint64 ok bool } errCh := make(chan error, 16) clientDone := make(chan struct{}) serverDone := make(chan struct{}) clientStatsCh := make(chan statsResult, 1) serverStatsCh := make(chan statsResult, 1) go func() { defer close(serverDone) fail := func(e error) { if e != nil { errCh <- e } } conn, err := venv.net0.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, ) if err != nil { fail(fmt.Errorf("server DialUDP: %w", err)) serverStatsCh <- statsResult{ok: false} return } defer conn.Close() // nolint:errcheck assoc, err := Server(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if err != nil { fail(fmt.Errorf("server assoc: %w", err)) serverStatsCh <- statsResult{ok: false} return } defer func() { var fr uint64 if assoc != nil { fr = assoc.stats.getNumFastRetrans() } serverStatsCh <- statsResult{fr: fr, ok: assoc != nil} _ = assoc.Close() }() stream, err := assoc.AcceptStream() if err != nil { fail(fmt.Errorf("server AcceptStream: %w", err)) return } defer stream.Close() // nolint:errcheck stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0) buf := make([]byte, 1500) for { _ = stream.SetReadDeadline(time.Now().Add(500 * time.Millisecond)) n, rerr := stream.Read(buf) if rerr != nil { return } if n > 0 { _, _ = stream.Write(buf[:n]) } } }() go func() { defer close(clientDone) fail := func(e error) { if e != nil { errCh <- e } } conn, err := venv.net1.DialUDP("udp4", &net.UDPAddr{IP: net.ParseIP("2.2.2.2"), Port: defaultSCTPSrcDstPort}, &net.UDPAddr{IP: net.ParseIP("1.1.1.1"), Port: defaultSCTPSrcDstPort}, ) if err != nil { fail(fmt.Errorf("client DialUDP: %w", err)) clientStatsCh <- statsResult{ok: false} return } defer conn.Close() // nolint:errcheck assoc, err := Client(Config{ NetConn: conn, LoggerFactory: loggerFactory, }) if err != nil { fail(fmt.Errorf("client assoc: %w", err)) clientStatsCh <- statsResult{ok: false} return } defer func() { var fr uint64 if assoc != nil { fr = assoc.stats.getNumFastRetrans() } clientStatsCh <- statsResult{fr: fr, ok: assoc != nil} _ = assoc.Close() }() stream, err := assoc.OpenStream(777, PayloadTypeWebRTCBinary) if err != nil { fail(fmt.Errorf("client OpenStream: %w", err)) return } defer stream.Close() 
// nolint:errcheck stream.SetReliabilityParams(false, ReliabilityTypeReliable, 0) msgs := makeMessages() // phase 1: high-RTT emulation we send 25 messages and drop a DATA chunk for one time. delayOn.Store(true) venv.dropNextDataChunk(1) for i := 0; i < 25; i++ { if _, werr := stream.Write(msgs[i]); werr != nil { fail(fmt.Errorf("client write phase1 i=%d: %w", i, werr)) return } } // phase 2 we switch to low-RTT, newer datea should arrive before older. delayOn.Store(false) for i := 25; i < numMessages; i++ { if _, werr := stream.Write(msgs[i]); werr != nil { fail(fmt.Errorf("client write phase2 i=%d: %w", i, werr)) return } } seen := make(map[byte]bool, numMessages) buf := make([]byte, 4096) deadline := time.Now().Add(15 * time.Second) for len(seen) < numMessages && time.Now().Before(deadline) { _ = stream.SetReadDeadline(time.Now().Add(250 * time.Millisecond)) n, rerr := stream.Read(buf) if rerr != nil || n == 0 { continue } if n < messageSize { fail(fmt.Errorf("short echo read: got=%d want=%d", n, messageSize)) //nolint:err113 return } id := buf[0] if seen[id] { // dups are harmless, keep reading continue } expected := bytes.Repeat([]byte{id}, messageSize) if !bytes.Equal(buf[:messageSize], expected) { fail(fmt.Errorf("payload mismatch for id=%d", int(id))) //nolint:err113 return } seen[id] = true } if len(seen) != numMessages { fail(fmt.Errorf("missing echoes: got=%d want=%d", len(seen), numMessages)) //nolint:err113 return } }() <-clientDone <-serverDone // drain and assert errors, well if any :) close(errCh) for e := range errCh { assert.NoError(t, e) } // check FR stats reported. cs := <-clientStatsCh ss := <-serverStatsCh if assert.True(t, cs.ok, "client assoc/stats unavailable") { assert.LessOrEqual(t, cs.fr, uint64(2), "client fast retransmits should be low") } if assert.True(t, ss.ok, "server assoc/stats unavailable") { assert.LessOrEqual(t, ss.fr, uint64(2), "server fast retransmits should be low") } } sctp-1.9.0/windowedmin.go000066400000000000000000000034211512256410600153470ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "sort" "time" ) // windowedMin maintains a monotonic deque of (time,value) to answer // the minimum over a sliding window efficiently. // Not thread-safe; caller must synchronize (Association already does). type windowedMin struct { rackMinRTTWnd time.Duration deque []entry } type entry struct { t time.Time v time.Duration } func newWindowedMin(window time.Duration) *windowedMin { if window <= 0 { window = 30 * time.Second } return &windowedMin{rackMinRTTWnd: window} } // prune removes elements older than (now - wnd). func (window *windowedMin) prune(now time.Time) { if len(window.deque) == 0 { return } cutoff := now.Add(-window.rackMinRTTWnd) firstValidTSAfterCutoff := sort.Search(len(window.deque), func(i int) bool { return !window.deque[i].t.Before(cutoff) // no builtin func for >= cutoff time }) if firstValidTSAfterCutoff > 0 { window.deque = window.deque[firstValidTSAfterCutoff:] } } // Push inserts a new sample and preserves monotonic non-increasing values. // It maintains minimum values by removing larger entries. func (window *windowedMin) Push(now time.Time, v time.Duration) { window.prune(now) for i := len(window.deque); i > 0 && window.deque[i-1].v >= v; i-- { window.deque = window.deque[:i-1] } window.deque = append( window.deque, entry{ t: now, v: v, }, ) } // Min returns the minimum value in the current window or 0 if empty. 
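// Example (illustrative): after Push(t0, 30ms), Push(t0+10ms, 20ms) and Push(t0+20ms, 40ms), // Min reports 20ms until the 20ms sample falls outside the window, after which 40ms becomes the minimum. 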
func (window *windowedMin) Min(now time.Time) time.Duration { window.prune(now) if len(window.deque) == 0 { return 0 } return window.deque[0].v } // Len is only for tests/diagnostics. func (window *windowedMin) Len() int { return len(window.deque) } sctp-1.9.0/windowedmin_test.go000066400000000000000000000037651512256410600164210ustar00rootroot00000000000000// SPDX-FileCopyrightText: 2023 The Pion community // SPDX-License-Identifier: MIT package sctp import ( "testing" "time" "github.com/stretchr/testify/assert" ) func TestWindowedMin_Basic(t *testing.T) { window := newWindowedMin(100 * time.Millisecond) base := time.Unix(0, 0) window.Push(base, 30*time.Millisecond) assert.Equal(t, 30*time.Millisecond, window.Min(base)) window.Push(base.Add(10*time.Millisecond), 20*time.Millisecond) assert.Equal(t, 20*time.Millisecond, window.Min(base.Add(10*time.Millisecond))) // larger value shouldn't change min window.Push(base.Add(20*time.Millisecond), 40*time.Millisecond) assert.Equal(t, 20*time.Millisecond, window.Min(base.Add(20*time.Millisecond))) // decreasing again shrinks min window.Push(base.Add(30*time.Millisecond), 10*time.Millisecond) assert.Equal(t, 10*time.Millisecond, window.Min(base.Add(30*time.Millisecond))) } func TestWindowedMin_WindowExpiry(t *testing.T) { window := newWindowedMin(50 * time.Millisecond) base := time.Unix(0, 0) window.Push(base, 10*time.Millisecond) // t=0 window.Push(base.Add(10*time.Millisecond), 20*time.Millisecond) // t=10ms // at t=60ms, first sample is expired, m becomes 20ms m := window.Min(base.Add(60 * time.Millisecond)) assert.Equal(t, 20*time.Millisecond, m) // at t=200ms, all are expired -> 0 m = window.Min(base.Add(200 * time.Millisecond)) assert.Zero(t, m) } func TestWindowedMin_EqualValues(t *testing.T) { window := newWindowedMin(1 * time.Second) base := time.Unix(0, 0) window.Push(base, 15*time.Millisecond) window.Push(base.Add(1*time.Millisecond), 15*time.Millisecond) assert.Equal(t, 1, window.Len()) assert.Equal(t, 15*time.Millisecond, window.Min(base.Add(2*time.Millisecond))) } func TestWindowedMin_DefaultWindow30s(t *testing.T) { zeroWnd := newWindowedMin(0) negativeWnd := newWindowedMin(-5 * time.Second) assert.Equal(t, 30*time.Second, zeroWnd.rackMinRTTWnd) assert.Equal(t, 30*time.Second, negativeWnd.rackMinRTTWnd) }