diff --git a/.github/codecov.yml b/.github/codecov.yml new file mode 100644 index 000000000..c4e6b3f3b --- /dev/null +++ b/.github/codecov.yml @@ -0,0 +1,11 @@ +github_checks: + annotations: false + +coverage: + status: + project: + default: + informational: true + patch: + default: + informational: true diff --git a/.github/workflows/coverage.yml b/.github/workflows/coverage.yml new file mode 100644 index 000000000..1d03357b7 --- /dev/null +++ b/.github/workflows/coverage.yml @@ -0,0 +1,21 @@ +on: + pull_request: + merge_group: + +name: Coverage + +jobs: + coverage: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v2 + - name: Install `cargo llvm-cov` + uses: taiki-e/install-action@cargo-llvm-cov + - name: Run Coverage + run: ./ci.sh coverage + - name: Upload coverage to Codecov + uses: codecov/codecov-action@v3 + with: + #token: ${{ secrets.CODECOV_TOKEN }} # not required for public repos + files: lcov.info + fail_ci_if_error: false diff --git a/.github/workflows/fuzz.yml b/.github/workflows/fuzz.yml new file mode 100644 index 000000000..e7a787afa --- /dev/null +++ b/.github/workflows/fuzz.yml @@ -0,0 +1,17 @@ +on: + pull_request: + merge_group: + +name: Fuzz + +jobs: + fuzz: + runs-on: ubuntu-22.04 + env: + RUSTUP_TOOLCHAIN: nightly + steps: + - uses: actions/checkout@v2 + - name: Install cargo-fuzz + run: cargo install cargo-fuzz + - name: Fuzz + run: cargo fuzz run packet_parser -- -max_len=1536 -max_total_time=30 diff --git a/.github/workflows/matrix-bot.yml b/.github/workflows/matrix-bot.yml new file mode 100644 index 000000000..ca51045ba --- /dev/null +++ b/.github/workflows/matrix-bot.yml @@ -0,0 +1,44 @@ +name: Matrix bot +on: + pull_request_target: + types: [opened, closed] + +jobs: + new-pr: + if: github.event.action == 'opened' && github.repository == 'smoltcp-rs/smoltcp' + runs-on: ubuntu-latest + continue-on-error: true + steps: + - name: send message + uses: s3krit/matrix-message-action@v0.0.3 + with: + room_id: ${{ secrets.MATRIX_ROOM_ID }} + access_token: ${{ secrets.MATRIX_ACCESS_TOKEN }} + message: "New PR: [${{ github.event.pull_request.title }}](${{ github.event.pull_request.html_url }})" + server: "matrix.org" + + merged-pr: + if: github.event.action == 'closed' && github.event.pull_request.merged == true && github.repository == 'smoltcp-rs/smoltcp' + runs-on: ubuntu-latest + continue-on-error: true + steps: + - name: send message + uses: s3krit/matrix-message-action@v0.0.3 + with: + room_id: ${{ secrets.MATRIX_ROOM_ID }} + access_token: ${{ secrets.MATRIX_ACCESS_TOKEN }} + message: "PR merged: [${{ github.event.pull_request.title }}](${{ github.event.pull_request.html_url }})" + server: "matrix.org" + + abandoned-pr: + if: github.event.action == 'closed' && github.event.pull_request.merged == false && github.repository == 'smoltcp-rs/smoltcp' + runs-on: ubuntu-latest + continue-on-error: true + steps: + - name: send message + uses: s3krit/matrix-message-action@v0.0.3 + with: + room_id: ${{ secrets.MATRIX_ROOM_ID }} + access_token: ${{ secrets.MATRIX_ACCESS_TOKEN }} + message: "PR closed without merging: [${{ github.event.pull_request.title }}](${{ github.event.pull_request.html_url }})" + server: "matrix.org" diff --git a/.github/workflows/rustfmt.yaml b/.github/workflows/rustfmt.yaml new file mode 100644 index 000000000..107164af5 --- /dev/null +++ b/.github/workflows/rustfmt.yaml @@ -0,0 +1,12 @@ +on: + pull_request: + merge_group: + +name: Rustfmt check +jobs: + fmt: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v2 + - name: Check fmt + run: cargo fmt -- --check diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml new file mode 100644 index 000000000..ee6774c98 --- /dev/null +++ b/.github/workflows/test.yml @@ -0,0 +1,64 @@ +on: + pull_request: + merge_group: + +name: Test + +jobs: + tests: + runs-on: ubuntu-22.04 + needs: [check-msrv, test-msrv, test-stable, clippy] + steps: + - name: Done + run: exit 0 + + check-msrv: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v2 + - name: Run Checks MSRV + run: ./ci.sh check msrv + + test-msrv: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v2 + - name: Run Tests MSRV + run: ./ci.sh test msrv + + clippy: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v2 + - name: Run Clippy + run: ./ci.sh clippy + + test-stable: + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v2 + - name: Run Tests stable + run: ./ci.sh test stable + + test-nightly: + runs-on: ubuntu-22.04 + continue-on-error: true + steps: + - uses: actions/checkout@v2 + - name: Run Tests nightly + run: ./ci.sh test nightly + + #check-stable: + #runs-on: ubuntu-22.04 + #steps: + #- uses: actions/checkout@v2 + #- name: Run Tests + #run: ./ci.sh check stable + + #check-nightly: + #runs-on: ubuntu-22.04 + #continue-on-error: true + #steps: + #- uses: actions/checkout@v2 + #- name: Run Tests + #run: ./ci.sh check nightly diff --git a/.test_like_travis.rb b/.test_like_travis.rb deleted file mode 100755 index e75d156f1..000000000 --- a/.test_like_travis.rb +++ /dev/null @@ -1,19 +0,0 @@ -#!/usr/bin/ruby - -require 'yaml' - -travis_config = YAML.load_file('.travis.yml') -travis_config['matrix']['include'].each do |env| - ENV['RUSTUP_TOOLCHAIN'] = env['rust'] - env['env'].scan(/(\w+)=\'(.+?)\'/) do - ENV[$1] = $2 - end - travis_config['script'].each do |cmd| - $stderr.puts('+ ' + cmd.gsub(/\$(\w+)/) { ENV[$1] }) - system(cmd) - $?.success? or exit 1 - end - env['env'].scan(/(\w+)=\'(.+?)\'/) do - ENV.delete $1 - end -end diff --git a/.travis.yml b/.travis.yml deleted file mode 100644 index 1490e63c9..000000000 --- a/.travis.yml +++ /dev/null @@ -1,69 +0,0 @@ -language: rust -matrix: - include: - ### Litmus check that we work on stable/beta - # (we don't, not until slice_rotate lands) - # - rust: stable - # env: FEATURES='default' MODE='test' - # - rust: beta - # env: FEATURES='default' MODE='test' - ### Test default configurations - - rust: nightly - env: FEATURES='default' MODE='test' - ### Test select feature permutations, chosen to be as orthogonal as possible - - rust: nightly - env: FEATURES='std ethernet phy-raw_socket proto-ipv6 socket-udp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet phy-tap_interface proto-ipv6 socket-udp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv4 proto-igmp socket-raw' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv4 socket-udp socket-tcp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv4 proto-dhcpv4 socket-udp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv6 socket-udp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv6 socket-tcp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv4 socket-icmp socket-tcp' MODE='test' - - rust: nightly - env: FEATURES='std ethernet proto-ipv6 socket-icmp socket-tcp' MODE='test' - ### Test select feature permutations, chosen to be as aggressive as possible - - rust: nightly - env: FEATURES='ethernet proto-ipv4 proto-ipv6 socket-raw socket-udp socket-tcp socket-icmp std' - MODE='test' - - rust: nightly - env: FEATURES='ethernet proto-ipv4 proto-ipv6 socket-raw socket-udp socket-tcp socket-icmp alloc' - MODE='test' - - rust: nightly - env: FEATURES='proto-ipv4 proto-ipv6 socket-raw socket-udp socket-tcp socket-icmp alloc' - MODE='test' - - rust: nightly - env: FEATURES='ethernet proto-ipv4 proto-ipv6 proto-igmp proto-dhcpv4 socket-raw socket-udp socket-tcp socket-icmp' - MODE='build' - - rust: nightly - env: MODE='fuzz run' ARGS='packet_parser -- -max_len=1536 -max_total_time=30' - - rust: nightly - env: FEATURES='default' MODE='clippy' - - rust: nightly - env: FEATURES='default' MODE='check --bench bench' - - os: osx - rust: nightly - env: FEATURES='default' MODE='build' - allow_failures: - # something's screwy with Travis (as usual) and cargo-fuzz dies with a LeakSanitizer error - # even when it's successful. Keep this in .travis.yml in case it starts working some day. - - rust: nightly - env: MODE='fuzz run' ARGS='packet_parser -- -max_len=1536 -max_total_time=30' - # Clippy sometimes fails to compile and this shouldn't gate merges - - rust: nightly - env: FEATURES='default' MODE='clippy' - # See if the bench actually breaks - - rust: nightly - env: FEATURES='default' MODE='check --bench bench' -before_script: - - if [ "$MODE" == "fuzz run" ]; then cargo install cargo-fuzz; fi - - if [ "$MODE" == "clippy" ]; then cargo install clippy; fi -script: - - cargo $MODE --no-default-features --features "$FEATURES" $ARGS diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 000000000..29ae5ea4a --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,215 @@ +# Changelog +All notable changes to this project will be documented in this file. + +The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/), +and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html). + +## [Unreleased] + +## [0.9.1] - 2023-02-08 + +- iface: make MulticastError public. (#747) +- Fix parsing of ieee802154 link layer address for NDISC options (#746) + +## [0.9.0] - 2023-02-06 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.56 to 1.65 +- Added DNS client support. + - Add DnsSocket (#465) + - Add support for one-shot mDNS resolution (#669) +- Added support for packet fragmentation and reassembly, both for IPv4 and 6LoWPAN. (#591, #580, #624, #634, #645, #653, #684) +- Major error handling overhaul. + - Previously, _smoltcp_ had a single `Error` enum that all methods returned. Now methods that can fail have their own error enums, with only the actual errors they can return. (#617, #667, #730) + - Consuming `phy::Device` tokens is now infallible. + - In the case of "buffer full", `phy::Device` implementations must return `None` from the `transmit`/`receive` methods. (Previously, they could either do that, or return tokens and then return `Error::Exhausted` when consuming them. The latter wasted computation since it'd make _smoltcp_ pointlessly spend effort preparing the packet, and is now disallowed). + - For all other phy errors, `phy::Device` implementations should drop the packet and handle the error themselves. (Either log it and forget it, or buffer/count it and offer methods to let the user retrieve the error queue/counts.) Returning the error to have it bubble up to `Interface::poll()` is no longer supported. +- phy: the `trait Device` now uses Generic Associated Types (GAT) for the TX and RX tokens. The main impact of this is `Device` impls can now borrow data (because previously, the`for<'a> T: Device<'a>` bounds required to workaround the lack of GATs essentially implied `T: 'static`.) (#572) +- iface: The `Interface` API has been significantly simplified and cleaned up. + - The builder has been removed (#736) + - SocketSet and Device are now borrowed in methods that need them, instead of owning them. (#619) + - `Interface` now owns the list of addresses (#719), routes, neighbor cache (#722), 6LoWPAN address contexts, and fragmentation buffers (#736) instead of borrowing them with `managed`. + - A new compile-time configuration mechanism has been added, to configure the size of the (now owned) buffers (#742) +- iface: Change neighbor discovery timeout from 3s to 1s, to match Linux's behavior. (#620) +- iface: Remove implicit sized bound on device generics (#679) +- iface/6lowpan: Add address context information for resolving 6LoWPAN addresses (#687) +- iface/6lowpan: fix incorrect SAM value in IPHC when address is not compressed (#630) +- iface/6lowpan: packet parsing fuzz fixes (#636) +- socket: Add send_with to udp, raw, and icmp sockets. These methods enable reserving a packet buffer with a greater size than you need, and then shrinking the size once you know it. (#625) +- socket: Make `trait AnySocket` object-safe (#718) +- socket/dhcpv4: add waker support (#623) +- socket/dhcpv4: indicate new config if there's a packet buffer provided (#685) +- socket/dhcpv4: Use renewal time from DHCP server ACK, if given (#683) +- socket/dhcpv4: allow for extra configuration + - setting arbitrary options in the request. (#650) + - retrieving arbitrary options from the response. (#650) + - setting custom parameter request list. (#650) + - setting custom timing for retries. (#650) + - Allow specifying different server/client DHCP ports (#738) +- socket/raw: Add `peek` and `peek_slice` methods (#734) +- socket/raw: When sending packets, send the source IP address unmodified (it was previously replaced with the interface's address if it was unspecified). (#616) +- socket/tcp: Do not reset socket-level settings, such as keepalive, on reset (#603) +- socket/tcp: ensure we always accept the segment at offset=0 even if the assembler is full. (#735, #452) +- socket/tcp: Refactored assembler, now more robust and faster (#726, #735) +- socket/udp: accept packets with checksum field set to `0`, since that means the checksum is not computed (#632) +- wire: make many functions const (#693) +- wire/dhcpv4: remove Option enum (#656) +- wire/dhcpv4: use heapless Vec for DNS server list (#678) +- wire/icmpv4: add support for TimeExceeded packets (#609) +- wire/ip: Remove `IpRepr::Unspecified`, `IpVersion::Unspecified`, `IpAddress::Unspecified` (#579, #616) +- wire/ip: support parsing unspecified IPv6 IpEndpoints from string (like `[::]:12345`) (#732) +- wire/ipv6: Make Public Ipv6RoutingType (#691) +- wire/ndisc: do not error on unrecognized options. (#737) +- Switch to Rust 2021 edition. (#729) +- Remove obsolete Cargo feature `rust-1_28` (#725) + +## [0.8.2] - 2022-11-27 + +- tcp: Fix return value of nagle_enable ([#642](https://github.com/smoltcp-rs/smoltcp/pull/642)) +- tcp: Only clear retransmit timer when all packets are acked ([#662](https://github.com/smoltcp-rs/smoltcp/pull/662)) +- tcp: Send incomplete fin packets even if nagle enabled ([#665](https://github.com/smoltcp-rs/smoltcp/pull/665)) +- phy: Fix mtu calculation for raw_socket ([#611](https://github.com/smoltcp-rs/smoltcp/pull/611)) +- wire: Fix ipv6 contains_addr function ([#605](https://github.com/smoltcp-rs/smoltcp/pull/605)) + +## [0.8.1] - 2022-05-12 + +- Remove unused `rand_core` dep. ([#589](https://github.com/smoltcp-rs/smoltcp/pull/589)) +- Use socklen_t instead of u32 for RawSocket bind() parameter. Fixes build on 32bit Android. ([#593](https://github.com/smoltcp-rs/smoltcp/pull/593)) +- Propagate phy::RawSocket send errors to caller ([#588](https://github.com/smoltcp-rs/smoltcp/pull/588)) +- Fix Interface set_hardware_addr, get_hardware_addr for ieee802154/6lowpan. ([#584](https://github.com/smoltcp-rs/smoltcp/pull/584)) + +## [0.8.0] - 2021-12-11 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.40 to 1.56 +- Add support for IEEE 802.15.4 + 6LoWPAN medium ([#469](https://github.com/smoltcp-rs/smoltcp/pull/469)) +- Add support for IP medium ([#401](https://github.com/smoltcp-rs/smoltcp/pull/401)) +- Add `defmt` logging supprt ([#455](https://github.com/smoltcp-rs/smoltcp/pull/455)) +- Add RNG infrastructure ([#547](https://github.com/smoltcp-rs/smoltcp/pull/547), [#573](https://github.com/smoltcp-rs/smoltcp/pull/573)) +- Add `Context` struct that must be passed to some socket methods ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) +- Remove `SocketSet`, sockets are owned by `Interface` now. ([#557](https://github.com/smoltcp-rs/smoltcp/pull/557), [#571](https://github.com/smoltcp-rs/smoltcp/pull/571)) +- TCP: Add Nagle's Algorithm. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) +- TCP crash and correctness fixes: + - Add Nagle's Algorithm. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) + - Window scaling fixes. ([#500](https://github.com/smoltcp-rs/smoltcp/pull/500)) + - Fix delayed ack causing ack not to be sent after 3 packets. ([#530](https://github.com/smoltcp-rs/smoltcp/pull/530)) + - Fix RTT estimation for RTTs longer than 1 second ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop when remote side sets a MSS of 0 ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop when retransmit when remote window is 0 ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix crash when receiving a FIN in SYN_SENT state ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix overflow crash when receiving a wrong ACK seq in SYN_RECEIVED state ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix overflow crash when initial sequence number is u32::MAX ([#538](https://github.com/smoltcp-rs/smoltcp/pull/538)) + - Fix infinite loop on challenge ACKs ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Reply with RST to invalid packets in SynReceived state. ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Do not abort socket when receiving some invalid packets. ([#542](https://github.com/smoltcp-rs/smoltcp/pull/542)) + - Make initial sequence number random. ([#547](https://github.com/smoltcp-rs/smoltcp/pull/547)) + - Reply with RST to ACKs with invalid ackno in SYN_SENT. ([#522](https://github.com/smoltcp-rs/smoltcp/pull/522)) +- ARP fixes to deal better with broken networks: + - Fill cache only from ARP packets, not any packets. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Fill cache only from ARP packets directed at us. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Reject ARP packets with a source address not in the local network. ([#536](https://github.com/smoltcp-rs/smoltcp/pull/536), [#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Ignore unknown ARP packets. ([#544](https://github.com/smoltcp-rs/smoltcp/pull/544)) + - Flush neighbor cache on IP change ([#564](https://github.com/smoltcp-rs/smoltcp/pull/564)) +- UDP: Add `close()` method to unbind socket. ([#475](https://github.com/smoltcp-rs/smoltcp/pull/475), [#482](https://github.com/smoltcp-rs/smoltcp/pull/482)) +- DHCP client improvements: + - Refactored implementation to improve reliability and RFC compliance ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Convert to socket ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Added `max_lease_duration` option ([#459](https://github.com/smoltcp-rs/smoltcp/pull/459)) + - Do not set the BROADCAST flag ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) + - Add option to ignore NAKs ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) +- DHCP wire: + - Fix DhcpRepr::buffer_len not accounting for lease time, router and subnet options ([#478](https://github.com/smoltcp-rs/smoltcp/pull/478)) + - Emit DNS servers in DhcpRepr ([#510](https://github.com/smoltcp-rs/smoltcp/pull/510)) + - Fix incorrect bit for BROADCAST flag ([#548](https://github.com/smoltcp-rs/smoltcp/pull/548)) +- Improve resilience against packet ingress processing errors ([#281](https://github.com/smoltcp-rs/smoltcp/pull/281), [#483](https://github.com/smoltcp-rs/smoltcp/pull/483)) +- Implement `std::error::Error` for `smoltcp::Error` ([#485](https://github.com/smoltcp-rs/smoltcp/pull/485)) +- Update `managed` from 0.7 to 0.8 ([442](https://github.com/smoltcp-rs/smoltcp/pull/442)) +- Fix incorrect timestamp in PCAP captures ([#513](https://github.com/smoltcp-rs/smoltcp/pull/513)) +- Use microseconds instead of milliseconds in Instant and Duration ([#514](https://github.com/smoltcp-rs/smoltcp/pull/514)) +- Expose inner `Device` in `PcapWriter` ([#524](https://github.com/smoltcp-rs/smoltcp/pull/524)) +- Fix assert with any_ip + broadcast dst_addr. ([#533](https://github.com/smoltcp-rs/smoltcp/pull/533), [#534](https://github.com/smoltcp-rs/smoltcp/pull/534)) +- Simplify PcapSink trait ([#535](https://github.com/smoltcp-rs/smoltcp/pull/535)) +- Fix wrong operation order in FuzzInjector ([#525](https://github.com/smoltcp-rs/smoltcp/pull/525), [#535](https://github.com/smoltcp-rs/smoltcp/pull/535)) + +## [0.7.5] - 2021-06-28 + +- dhcpv4: emit DNS servers in repr ([#505](https://github.com/smoltcp-rs/smoltcp/pull/505)) + +## [0.7.4] - 2021-06-11 + +- tcp: fix "subtract sequence numbers with underflow" on remote window shrink. ([#490](https://github.com/smoltcp-rs/smoltcp/pull/490)) +- tcp: fix substract with overflow when receiving a SYNACK with unincremented ACK number. ([#491](https://github.com/smoltcp-rs/smoltcp/pull/491)) +- tcp: use nonzero initial sequence number to workaround misbehaving servers. ([#492](https://github.com/smoltcp-rs/smoltcp/pull/492)) + +## [0.7.3] - 2021-05-29 + +- Fix "unused attribute" error in recent nightlies. + +## [0.7.2] - 2021-05-29 + +- iface: check for ipv4 subnet broadcast addrs everywhere ([#462](https://github.com/smoltcp-rs/smoltcp/pull/462)) +- dhcp: always send parameter_request_list. ([#456](https://github.com/smoltcp-rs/smoltcp/pull/456)) +- dhcp: Clear expiration time on reset. ([#456](https://github.com/smoltcp-rs/smoltcp/pull/456)) +- phy: fix FaultInjector returning a too big buffer when simulating a drop on tx ([#463](https://github.com/smoltcp-rs/smoltcp/pull/463)) +- tcp rtte: fix "attempt to multiply with overflow". ([#476](https://github.com/smoltcp-rs/smoltcp/pull/476)) +- tcp: LastAck should only change to Closed on ack of fin. ([#477](https://github.com/smoltcp-rs/smoltcp/pull/477)) +- wire/dhcpv4: account for lease time, router and subnet options in DhcpRepr::buffer_len ([#478](https://github.com/smoltcp-rs/smoltcp/pull/478)) + +## [0.7.1] - 2021-03-27 + +- ndisc: Fix NeighborSolicit incorrectly asking for src addr instead of dst addr ([419](https://github.com/smoltcp-rs/smoltcp/pull/419)) +- dhcpv4: respect lease time from the server instead of renewing every 60 seconds. ([437](https://github.com/smoltcp-rs/smoltcp/pull/437)) +- Fix build errors due to invalid combinations of features ([416](https://github.com/smoltcp-rs/smoltcp/pull/416), [447](https://github.com/smoltcp-rs/smoltcp/pull/447)) +- wire/ipv4: make some functions const ([420](https://github.com/smoltcp-rs/smoltcp/pull/420)) +- phy: fix BPF on OpenBSD ([421](https://github.com/smoltcp-rs/smoltcp/pull/421), [427](https://github.com/smoltcp-rs/smoltcp/pull/427)) +- phy: enable RawSocket, TapInterface on Android ([435](https://github.com/smoltcp-rs/smoltcp/pull/435)) +- phy: fix phy_wait for waits longer than 1 second ([449](https://github.com/smoltcp-rs/smoltcp/pull/449)) + +## [0.7.0] - 2021-01-20 + +- Minimum Supported Rust Version (MSRV) **bumped** from 1.36 to 1.40 + +### New features +- tcp: Allow distinguishing between graceful (FIN) and ungraceful (RST) close. On graceful close, `recv()` now returns `Error::Finished`. On ungraceful close, `Error::Illegal` is returned, as before. ([351](https://github.com/smoltcp-rs/smoltcp/pull/351)) +- sockets: Add support for attaching async/await Wakers to sockets. Wakers are woken on socket state changes. ([394](https://github.com/smoltcp-rs/smoltcp/pull/394)) +- tcp: Set retransmission timeout based on an RTT estimation, instead of the previously fixed 100ms. This improves performance on high-latency links, such as mobile networks. ([406](https://github.com/smoltcp-rs/smoltcp/pull/406)) +- tcp: add Delayed ACK support. On by default, with a 10ms delay. ([404](https://github.com/smoltcp-rs/smoltcp/pull/404)) +- ip: Process broadcast packets directed to the subnet's broadcast address, such as 192.168.1.255. Previously broadcast packets were +only processed when directed to the 255.255.255.255 address. ([377](https://github.com/smoltcp-rs/smoltcp/pull/377)) + +### Fixes +- udp,raw,icmp: Fix packet buffer panic caused by large payload ([332](https://github.com/smoltcp-rs/smoltcp/pull/332)) +- dhcpv4: use offered ip in requested ip option ([310](https://github.com/smoltcp-rs/smoltcp/pull/310)) +- dhcpv4: Re-export dhcp::clientv4::Config +- dhcpv4: Enable `proto-dhcpv4` feature by default. ([327](https://github.com/smoltcp-rs/smoltcp/pull/327)) +- ethernet,arp: Allow for ARP retry during egress ([368](https://github.com/smoltcp-rs/smoltcp/pull/368)) +- ethernet,arp: Only limit the neighbor cache rate after sending a request packet ([369](https://github.com/smoltcp-rs/smoltcp/pull/369)) +- tcp: use provided ip for TcpSocket::connect instead of 0.0.0.0 ([329](https://github.com/smoltcp-rs/smoltcp/pull/329)) +- tcp: Accept data packets in FIN_WAIT_2 state. ([350](https://github.com/smoltcp-rs/smoltcp/pull/350)) +- tcp: Always send updated ack number in `ack_reply()`. ([353](https://github.com/smoltcp-rs/smoltcp/pull/353)) +- tcp: allow sending ACKs in FinWait2 state. ([388](https://github.com/smoltcp-rs/smoltcp/pull/388)) +- tcp: fix racey simultaneous close not sending FIN. ([398](https://github.com/smoltcp-rs/smoltcp/pull/398)) +- tcp: Do not send window updates in states that shouldn't do so ([360](https://github.com/smoltcp-rs/smoltcp/pull/360)) +- tcp: Return RST to unexpected ACK in SYN-SENT state. ([367](https://github.com/smoltcp-rs/smoltcp/pull/367)) +- tcp: Take MTU into account during TcpSocket dispatch. ([384](https://github.com/smoltcp-rs/smoltcp/pull/384)) +- tcp: don't send data outside the remote window ([387](https://github.com/smoltcp-rs/smoltcp/pull/387)) +- phy: Take Ethernet header into account for MTU of RawSocket and TapInterface. ([393](https://github.com/smoltcp-rs/smoltcp/pull/393)) +- phy: add null terminator to c-string passed to libc API ([372](https://github.com/smoltcp-rs/smoltcp/pull/372)) + +### Quality of Life™ improvements +- Update to Rust 2018 edition ([396](https://github.com/smoltcp-rs/smoltcp/pull/396)) +- Migrate CI to Github Actions ([390](https://github.com/smoltcp-rs/smoltcp/pull/390)) +- Fix clippy lints, enforce clippy in CI ([395](https://github.com/smoltcp-rs/smoltcp/pull/395), [402](https://github.com/smoltcp-rs/smoltcp/pull/402), [403](https://github.com/smoltcp-rs/smoltcp/pull/403), [405](https://github.com/smoltcp-rs/smoltcp/pull/405), [407](https://github.com/smoltcp-rs/smoltcp/pull/407)) +- Use #[non_exhaustive] for enums and structs ([409](https://github.com/smoltcp-rs/smoltcp/pull/409), [411](https://github.com/smoltcp-rs/smoltcp/pull/411)) +- Simplify lifetime parameters of sockets, SocketSet, EthernetInterface ([410](https://github.com/smoltcp-rs/smoltcp/pull/410), [413](https://github.com/smoltcp-rs/smoltcp/pull/413)) + +[Unreleased]: https://github.com/smoltcp-rs/smoltcp/compare/v0.9.1...HEAD +[0.9.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.9.0...v0.9.1 +[0.9.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.2...v0.9.0 +[0.8.2]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.1...v0.8.2 +[0.8.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.8.0...v0.8.1 +[0.8.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.0...v0.8.0 +[0.7.5]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.4...v0.7.5 +[0.7.4]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.3...v0.7.4 +[0.7.3]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.2...v0.7.3 +[0.7.2]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.1...v0.7.2 +[0.7.1]: https://github.com/smoltcp-rs/smoltcp/compare/v0.7.0...v0.7.1 +[0.7.0]: https://github.com/smoltcp-rs/smoltcp/compare/v0.6.0...v0.7.0 diff --git a/CODE_STYLE.md b/CODE_STYLE.md deleted file mode 100644 index 902e20f44..000000000 --- a/CODE_STYLE.md +++ /dev/null @@ -1,99 +0,0 @@ -# Code style - -smoltcp does not follow the rustfmt code style because whitequark (the original -author of smoltcp) finds automated formatters annoying and impairing readability -just as much as improving it in different cases. - -In general, format the things like the existing code and it'll be alright. -Here are a few things to watch out for, though: - -## Ordering use statements - -Use statements should be separated into two sections, uses from other crates and uses -from the current crate. The latter would ideally be sorted from most general -to most specific, but it's not very important. - -```rust -use core::cell::RefCell; - -use {Error, Result}; -use phy::{self, DeviceCapabilities, Device}; -``` - -## Wrapping function calls - -Avoid rightwards drift. This is fine: - -```rust -assert_eq!(iface.inner.process_ethernet(&mut socket_set, 0, frame.into_inner()), - Ok(Packet::None)); -``` - -This is also fine: - -```rust -assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, 0, - &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), - &IpAddress::Ipv4(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); -``` - -This is not: - -```rust -assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, 0, - &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), - &IpAddress::Ipv4(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); -``` - -## Wrapping function prototypes - -A function declaration might be wrapped... - - * right after `,`, - * right after `>`, - * right after `)`, - * right after `->`, - * right before and after `where`. - -Here's an artificial example, wrapped at 50 columns: - -```rust -fn dispatch_ethernet - (&mut self, tx_token: Tx, - timestamp: u64, f: F) -> - Result<()> - where Tx: TxToken, - F: FnOnce(EthernetFrame<&mut [u8]>) -{ - // ... -} -``` - -## Visually aligning tokens - -This is fine: - -```rust -struct State { - rng_seed: u32, - refilled_at: u64, - tx_bucket: u64, - rx_bucket: u64, -} -``` - -This is also fine: - -```rust -struct State { - rng_seed: u32, - refilled_at: u64, - tx_bucket: u64, - rx_bucket: u64, -} -``` - -It's OK to change between those if you touch that code anyway, -but avoid reformatting just for the sake of it. diff --git a/Cargo.toml b/Cargo.toml index ed21d0d5a..9167f7f05 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -1,11 +1,13 @@ [package] name = "smoltcp" -version = "0.6.0" +version = "0.9.1" +edition = "2021" +rust-version = "1.65" authors = ["whitequark "] description = "A TCP/IP stack designed for bare-metal, real-time systems without a heap." documentation = "https://docs.rs/smoltcp/" -homepage = "https://github.com/m-labs/smoltcp" -repository = "https://github.com/m-labs/smoltcp.git" +homepage = "https://github.com/smoltcp-rs/smoltcp" +repository = "https://github.com/smoltcp-rs/smoltcp.git" readme = "README.md" keywords = ["ip", "tcp", "udp", "ethernet", "network"] categories = ["embedded", "network-programming"] @@ -15,43 +17,222 @@ license = "0BSD" autoexamples = false [dependencies] -managed = { version = "0.7", default-features = false, features = ["map"] } +managed = { version = "0.8", default-features = false, features = ["map"] } byteorder = { version = "1.0", default-features = false } log = { version = "0.4.4", default-features = false, optional = true } libc = { version = "0.2.18", optional = true } bitflags = { version = "1.0", default-features = false } +defmt = { version = "0.3", optional = true } +cfg-if = "1.0.0" +heapless = "0.7.15" [dev-dependencies] -env_logger = "0.5" +env_logger = "0.10" getopts = "0.2" -rand = "0.3" -url = "1.0" +rand = "0.8" +url = "2.0" +rstest = "0.17" [features] -std = ["managed/std"] -alloc = ["managed/alloc"] +std = ["managed/std", "alloc"] +alloc = ["managed/alloc", "defmt?/alloc"] verbose = [] -ethernet = [] +defmt = [ "dep:defmt", "heapless/defmt", "heapless/defmt-impl" ] +"medium-ethernet" = ["socket"] +"medium-ip" = ["socket"] +"medium-ieee802154" = ["socket", "proto-sixlowpan"] + "phy-raw_socket" = ["std", "libc"] -"phy-tap_interface" = ["std", "libc"] +"phy-tuntap_interface" = ["std", "libc", "medium-ethernet"] + "proto-ipv4" = [] +"proto-ipv4-fragmentation" = ["proto-ipv4", "_proto-fragmentation"] "proto-igmp" = ["proto-ipv4"] -"proto-dhcpv4" = ["proto-ipv4", "socket-raw"] +"proto-dhcpv4" = ["proto-ipv4"] "proto-ipv6" = [] -"socket-raw" = [] -"socket-udp" = [] -"socket-tcp" = [] -"socket-icmp" = [] +"proto-rpl" = [] +"proto-sixlowpan" = ["proto-ipv6"] +"proto-sixlowpan-fragmentation" = ["proto-sixlowpan", "_proto-fragmentation"] +"proto-dns" = [] + +"socket" = [] +"socket-raw" = ["socket"] +"socket-udp" = ["socket"] +"socket-tcp" = ["socket"] +"socket-icmp" = ["socket"] +"socket-dhcpv4" = ["socket", "medium-ethernet", "proto-dhcpv4"] +"socket-dns" = ["socket", "proto-dns"] +"socket-mdns" = ["socket-dns"] + +"packetmeta-id" = [] + +"async" = [] + default = [ "std", "log", # needed for `cargo test --no-default-features --features default` :/ - "ethernet", - "phy-raw_socket", "phy-tap_interface", - "proto-ipv4", "proto-igmp", "proto-ipv6", - "socket-raw", "socket-icmp", "socket-udp", "socket-tcp" + "medium-ethernet", "medium-ip", "medium-ieee802154", + "phy-raw_socket", "phy-tuntap_interface", + "proto-ipv4", "proto-igmp", "proto-dhcpv4", "proto-ipv6", "proto-dns", + "proto-ipv4-fragmentation", "proto-sixlowpan-fragmentation", + "socket-raw", "socket-icmp", "socket-udp", "socket-tcp", "socket-dhcpv4", "socket-dns", "socket-mdns", + "packetmeta-id", "async" ] -# experimental; do not use; no guarantees provided that this feature will be kept -"rust-1_28" = [] +# Private features +# Features starting with "_" are considered private. They should not be enabled by +# other crates, and they are not considered semver-stable. + +"_proto-fragmentation" = [] + +# BEGIN AUTOGENERATED CONFIG FEATURES +# Generated by gen_config.py. DO NOT EDIT. +iface-max-addr-count-1 = [] +iface-max-addr-count-2 = [] # Default +iface-max-addr-count-3 = [] +iface-max-addr-count-4 = [] +iface-max-addr-count-5 = [] +iface-max-addr-count-6 = [] +iface-max-addr-count-7 = [] +iface-max-addr-count-8 = [] + +iface-max-multicast-group-count-1 = [] +iface-max-multicast-group-count-2 = [] +iface-max-multicast-group-count-3 = [] +iface-max-multicast-group-count-4 = [] # Default +iface-max-multicast-group-count-5 = [] +iface-max-multicast-group-count-6 = [] +iface-max-multicast-group-count-7 = [] +iface-max-multicast-group-count-8 = [] +iface-max-multicast-group-count-16 = [] +iface-max-multicast-group-count-32 = [] +iface-max-multicast-group-count-64 = [] +iface-max-multicast-group-count-128 = [] +iface-max-multicast-group-count-256 = [] +iface-max-multicast-group-count-512 = [] +iface-max-multicast-group-count-1024 = [] + +iface-max-sixlowpan-address-context-count-1 = [] +iface-max-sixlowpan-address-context-count-2 = [] +iface-max-sixlowpan-address-context-count-3 = [] +iface-max-sixlowpan-address-context-count-4 = [] # Default +iface-max-sixlowpan-address-context-count-5 = [] +iface-max-sixlowpan-address-context-count-6 = [] +iface-max-sixlowpan-address-context-count-7 = [] +iface-max-sixlowpan-address-context-count-8 = [] +iface-max-sixlowpan-address-context-count-16 = [] +iface-max-sixlowpan-address-context-count-32 = [] +iface-max-sixlowpan-address-context-count-64 = [] +iface-max-sixlowpan-address-context-count-128 = [] +iface-max-sixlowpan-address-context-count-256 = [] +iface-max-sixlowpan-address-context-count-512 = [] +iface-max-sixlowpan-address-context-count-1024 = [] + +iface-neighbor-cache-count-1 = [] +iface-neighbor-cache-count-2 = [] +iface-neighbor-cache-count-3 = [] +iface-neighbor-cache-count-4 = [] # Default +iface-neighbor-cache-count-5 = [] +iface-neighbor-cache-count-6 = [] +iface-neighbor-cache-count-7 = [] +iface-neighbor-cache-count-8 = [] +iface-neighbor-cache-count-16 = [] +iface-neighbor-cache-count-32 = [] +iface-neighbor-cache-count-64 = [] +iface-neighbor-cache-count-128 = [] +iface-neighbor-cache-count-256 = [] +iface-neighbor-cache-count-512 = [] +iface-neighbor-cache-count-1024 = [] + +iface-max-route-count-1 = [] +iface-max-route-count-2 = [] # Default +iface-max-route-count-3 = [] +iface-max-route-count-4 = [] +iface-max-route-count-5 = [] +iface-max-route-count-6 = [] +iface-max-route-count-7 = [] +iface-max-route-count-8 = [] +iface-max-route-count-16 = [] +iface-max-route-count-32 = [] +iface-max-route-count-64 = [] +iface-max-route-count-128 = [] +iface-max-route-count-256 = [] +iface-max-route-count-512 = [] +iface-max-route-count-1024 = [] + +fragmentation-buffer-size-256 = [] +fragmentation-buffer-size-512 = [] +fragmentation-buffer-size-1024 = [] +fragmentation-buffer-size-1500 = [] # Default +fragmentation-buffer-size-2048 = [] +fragmentation-buffer-size-4096 = [] +fragmentation-buffer-size-8192 = [] +fragmentation-buffer-size-16384 = [] +fragmentation-buffer-size-32768 = [] +fragmentation-buffer-size-65536 = [] + +assembler-max-segment-count-1 = [] +assembler-max-segment-count-2 = [] +assembler-max-segment-count-3 = [] +assembler-max-segment-count-4 = [] # Default +assembler-max-segment-count-8 = [] +assembler-max-segment-count-16 = [] +assembler-max-segment-count-32 = [] + +reassembly-buffer-size-256 = [] +reassembly-buffer-size-512 = [] +reassembly-buffer-size-1024 = [] +reassembly-buffer-size-1500 = [] # Default +reassembly-buffer-size-2048 = [] +reassembly-buffer-size-4096 = [] +reassembly-buffer-size-8192 = [] +reassembly-buffer-size-16384 = [] +reassembly-buffer-size-32768 = [] +reassembly-buffer-size-65536 = [] + +reassembly-buffer-count-1 = [] # Default +reassembly-buffer-count-2 = [] +reassembly-buffer-count-3 = [] +reassembly-buffer-count-4 = [] +reassembly-buffer-count-8 = [] +reassembly-buffer-count-16 = [] +reassembly-buffer-count-32 = [] + +dns-max-result-count-1 = [] # Default +dns-max-result-count-2 = [] +dns-max-result-count-3 = [] +dns-max-result-count-4 = [] +dns-max-result-count-8 = [] +dns-max-result-count-16 = [] +dns-max-result-count-32 = [] + +dns-max-server-count-1 = [] # Default +dns-max-server-count-2 = [] +dns-max-server-count-3 = [] +dns-max-server-count-4 = [] +dns-max-server-count-8 = [] +dns-max-server-count-16 = [] +dns-max-server-count-32 = [] + +dns-max-name-size-64 = [] +dns-max-name-size-128 = [] +dns-max-name-size-255 = [] # Default + +rpl-relations-buffer-count-1 = [] +rpl-relations-buffer-count-2 = [] +rpl-relations-buffer-count-4 = [] +rpl-relations-buffer-count-8 = [] +rpl-relations-buffer-count-16 = [] # Default +rpl-relations-buffer-count-32 = [] +rpl-relations-buffer-count-64 = [] +rpl-relations-buffer-count-128 = [] + +rpl-parents-buffer-count-2 = [] +rpl-parents-buffer-count-4 = [] +rpl-parents-buffer-count-8 = [] # Default +rpl-parents-buffer-count-16 = [] +rpl-parents-buffer-count-32 = [] + +# END AUTOGENERATED CONFIG FEATURES [[example]] name = "packet2pcap" @@ -64,35 +245,47 @@ required-features = ["std", "phy-raw_socket", "proto-ipv4"] [[example]] name = "httpclient" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "proto-ipv6", "socket-tcp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-ipv6", "socket-tcp"] [[example]] name = "ping" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "proto-ipv6", "socket-icmp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-ipv6", "socket-icmp"] [[example]] name = "server" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] [[example]] name = "client" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-tcp", "socket-udp"] [[example]] name = "loopback" -required-features = ["log", "proto-ipv4", "socket-tcp"] +required-features = ["log", "medium-ethernet", "proto-ipv4", "socket-tcp"] [[example]] name = "multicast" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "proto-igmp", "socket-udp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-igmp", "socket-udp"] [[example]] name = "benchmark" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "socket-raw", "socket-udp"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-raw", "socket-udp"] [[example]] name = "dhcp_client" -required-features = ["std", "phy-tap_interface", "proto-ipv4", "proto-dhcpv4", "socket-raw"] +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "proto-dhcpv4", "socket-raw"] + +[[example]] +name = "sixlowpan" +required-features = ["std", "medium-ieee802154", "phy-raw_socket", "proto-sixlowpan", "proto-sixlowpan-fragmentation", "socket-udp"] + +[[example]] +name = "sixlowpan_benchmark" +required-features = ["std", "medium-ieee802154", "phy-raw_socket", "proto-sixlowpan", "proto-sixlowpan-fragmentation", "socket-udp"] + +[[example]] +name = "dns" +required-features = ["std", "medium-ethernet", "medium-ip", "phy-tuntap_interface", "proto-ipv4", "socket-dns"] [profile.release] debug = 2 diff --git a/README.md b/README.md index bc2cecc8b..65f7a29eb 100644 --- a/README.md +++ b/README.md @@ -1,12 +1,18 @@ # smoltcp +[![docs.rs](https://docs.rs/smoltcp/badge.svg)](https://docs.rs/smoltcp) +[![crates.io](https://img.shields.io/crates/v/smoltcp.svg)](https://crates.io/crates/smoltcp) +[![crates.io](https://img.shields.io/crates/d/smoltcp.svg)](https://crates.io/crates/smoltcp) +[![crates.io](https://img.shields.io/matrix/smoltcp:matrix.org)](https://matrix.to/#/#smoltcp:matrix.org) +[![codecov](https://codecov.io/github/smoltcp-rs/smoltcp/branch/master/graph/badge.svg?token=3KbAR9xH1t)](https://codecov.io/github/smoltcp-rs/smoltcp) + _smoltcp_ is a standalone, event-driven TCP/IP stack that is designed for bare-metal, real-time systems. Its design goals are simplicity and robustness. Its design anti-goals include complicated compile-time computations, such as macro or type tricks, even at cost of performance degradation. _smoltcp_ does not need heap allocation *at all*, is [extensively documented][docs], -and compiles on stable Rust 1.28 and later. +and compiles on stable Rust 1.65 and later. _smoltcp_ achieves [~Gbps of throughput](#examplesbenchmarkrs) when tested against the Linux TCP stack in loopback mode. @@ -20,8 +26,9 @@ To set expectations right, both implemented and omitted features are listed. ### Media layer -The only supported medium is Ethernet. +There are 3 supported mediums. +* Ethernet * Regular Ethernet II frames are supported. * Unicast, broadcast and multicast packets are supported. * ARP packets (including gratuitous requests and replies) are supported. @@ -29,6 +36,11 @@ The only supported medium is Ethernet. * Cached ARP entries expire after one minute. * 802.3 frames and 802.1Q are **not** supported. * Jumbo frames are **not** supported. +* IP + * Unicast, broadcast and multicast packets are supported. +* IEEE 802.15.4 + 6LoWPAN (experimental) + * Unicast, broadcast and multicast packets are supported. + * ONLY UDP packets are supported. ### IP layer @@ -38,7 +50,7 @@ The only supported medium is Ethernet. * IPv4 time-to-live value is configurable per socket, set to 64 by default. * IPv4 default gateway is supported. * Routing outgoing IPv4 packets is supported, through a default gateway or a CIDR route table. - * IPv4 fragmentation is **not** supported. + * IPv4 fragmentation and reassembly is supported. * IPv4 options are **not** supported and are silently ignored. #### IPv6 @@ -75,7 +87,7 @@ The ICMPv4 protocol is supported, and ICMP sockets are available. #### ICMPv6 -The ICMPv6 protocol is supported, but is **not** available via ICMP sockets. +The ICMPv6 protocol is supported, and ICMP sockets are available. * ICMPv6 header checksum is supported. * ICMPv6 echo replies are generated in response to echo requests. @@ -106,13 +118,13 @@ The TCP protocol is supported over IPv4 and IPv6, and server and client TCP sock * Multiple packets are transmitted without waiting for an acknowledgement. * Reassembly of out-of-order segments is supported, with no more than 4 or 32 gaps in sequence space. * Keep-alive packets may be sent at a configurable interval. - * Retransmission timeout starts at a fixed interval of 100 ms and doubles every time. + * Retransmission timeout starts at at an estimate of RTT, and doubles every time. * Time-wait timeout has a fixed interval of 10 s. * User timeout has a configurable interval. + * Delayed acknowledgements are supported, with configurable delay. + * Nagle's algorithm is implemented. * Selective acknowledgements are **not** implemented. - * Delayed acknowledgements are **not** implemented. * Silly window syndrome avoidance is **not** implemented. - * Nagle's algorithm is **not** implemented. * Congestion control is **not** implemented. * Timestamping is **not** supported. * Urgent pointer is **ignored**. @@ -125,7 +137,7 @@ To use the _smoltcp_ library in your project, add the following to `Cargo.toml`: ```toml [dependencies] -smoltcp = "0.5" +smoltcp = "0.8.0" ``` The default configuration assumes a hosted environment, for ease of evaluation. @@ -133,9 +145,11 @@ You probably want to disable default features and configure them one by one: ```toml [dependencies] -smoltcp = { version = "0.5", default-features = false, features = ["log"] } +smoltcp = { version = "0.8.0", default-features = false, features = ["log"] } ``` +## Feature flags + ### Feature `std` The `std` feature enables use of objects and slices owned by the networking stack through a @@ -161,6 +175,14 @@ the DEBUG log level. This feature is enabled by default. +### Feature `defmt` + +The `defmt` feature enables logging of events with the [defmt crate][defmt]. + +[defmt]: https://crates.io/crates/defmt + +This feature is disabled by default, and cannot be used at the same time as `log`. + ### Feature `verbose` The `verbose` feature enables logging of events where the logging itself may incur very high @@ -170,16 +192,15 @@ or `BufWriter` is used, which are of course not available on heap-less systems. This feature is disabled by default. -### Features `phy-raw_socket` and `phy-tap_interface` +### Features `phy-raw_socket` and `phy-tuntap_interface` -Enable `smoltcp::phy::RawSocket` and `smoltcp::phy::TapInterface`, respectively. +Enable `smoltcp::phy::RawSocket` and `smoltcp::phy::TunTapInterface`, respectively. These features are enabled by default. -### Features `socket-raw`, `socket-udp`, and `socket-tcp` +### Features `socket-raw`, `socket-udp`, `socket-tcp`, `socket-icmp`, `socket-dhcpv4`, `socket-dns` -Enable `smoltcp::socket::RawSocket`, `smoltcp::socket::UdpSocket`, -and `smoltcp::socket::TcpSocket`, respectively. +Enable the corresponding socket type. These features are enabled by default. @@ -190,13 +211,80 @@ Enable [IPv4] and [IPv6] respectively. [IPv4]: https://tools.ietf.org/rfc/rfc791.txt [IPv6]: https://tools.ietf.org/rfc/rfc8200.txt +## Configuration + +_smoltcp_ has some configuration settings that are set at compile time, affecting sizes +and counts of buffers. + +They can be set in two ways: + +- Via Cargo features: enable a feature like `-`. `name` must be in lowercase and +use dashes instead of underscores. For example. `iface-max-addr-count-3`. Only a selection of values +is available, check `Cargo.toml` for the list. +- Via environment variables at build time: set the variable named `SMOLTCP_`. For example +`SMOLTCP_IFACE_MAX_ADDR_COUNT=3 cargo build`. You can also set them in the `[env]` section of `.cargo/config.toml`. +Any value can be set, unlike with Cargo features. + +Environment variables take precedence over Cargo features. If two Cargo features are enabled for the same setting +with different values, compilation fails. + +### `IFACE_MAX_ADDR_COUNT` + +Max amount of IP addresses that can be assigned to one interface (counting both IPv4 and IPv6 addresses). Default: 2. + +### `IFACE_MAX_MULTICAST_GROUP_COUNT` + +Max amount of multicast groups that can be joined by one interface. Default: 4. + +### `IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT` + +Max amount of 6LoWPAN address contexts that can be assigned to one interface. Default: 4. + +### `IFACE_NEIGHBOR_CACHE_COUNT` + +Amount of "IP address -> hardware address" entries the neighbor cache (also known as the "ARP cache" or the "ARP table") holds. Default: 4. + +### `IFACE_MAX_ROUTE_COUNT` + +Max amount of routes that can be added to one interface. Includes the default route. Includes both IPv4 and IPv6. Default: 2. + +### `FRAGMENTATION_BUFFER_SIZE` + +Size of the buffer used for fragmenting outgoing packets larger than the MTU. Packets larger than this setting will be dropped instead of fragmented. Default: 1500. + +### `ASSEMBLER_MAX_SEGMENT_COUNT` + +Maximum number of non-contiguous segments the assembler can hold. Used for both packet reassembly and TCP stream reassembly. Default: 4. + +### `REASSEMBLY_BUFFER_SIZE` + +Size of the buffer used for reassembling (de-fragmenting) incoming packets. If the reassembled packet is larger than this setting, it will be dropped instead of reassembled. Default: 1500. + +### `REASSEMBLY_BUFFER_COUNT` + +Number of reassembly buffers, i.e how many different incoming packets can be reassembled at the same time. Default: 1. + +### `DNS_MAX_RESULT_COUNT` + +Maximum amount of address results for a given DNS query that will be kept. For example, if this is set to 2 and the queried name has 4 `A` records, only the first 2 will be returned. Default: 1. + +### `DNS_MAX_SERVER_COUNT` + +Maximum amount of DNS servers that can be configured in one DNS socket. Default: 1. + +### `DNS_MAX_NAME_SIZE` + +Maximum length of DNS names that can be queried. Default: 255. + + + ## Hosted usage examples _smoltcp_, being a freestanding networking stack, needs to be able to transmit and receive raw frames. For testing purposes, we will use a regular OS, and run _smoltcp_ in a userspace process. Only Linux is supported (right now). -On \*nix OSes, transmiting and receiving raw frames normally requires superuser privileges, but +On \*nix OSes, transmitting and receiving raw frames normally requires superuser privileges, but on Linux it is possible to create a _persistent tap interface_ that can be manipulated by a specific user: @@ -217,6 +305,53 @@ sudo iptables -t nat -A POSTROUTING -s 192.168.69.0/24 -j MASQUERADE sudo sysctl net.ipv4.ip_forward=1 sudo ip6tables -t nat -A POSTROUTING -s fdaa::/64 -j MASQUERADE sudo sysctl -w net.ipv6.conf.all.forwarding=1 + +# Some distros have a default policy of DROP. This allows the traffic. +sudo iptables -A FORWARD -i tap0 -s 192.168.69.0/24 -j ACCEPT +sudo iptables -A FORWARD -o tap0 -d 192.168.69.0/24 -j ACCEPT +``` + +### Bridged connection + +Instead of the routed connection above, you may also set up a bridged (switched) +connection. This will make smoltcp speak directly to your LAN, with real ARP, etc. +It is needed to run the DHCP example. + +NOTE: In this case, the examples' IP configuration must match your LAN's! + +NOTE: this ONLY works with actual wired Ethernet connections. It +will NOT work on a WiFi connection. + +```sh +# Replace with your wired Ethernet interface name +ETH=enp0s20f0u1u1 + +sudo modprobe bridge +sudo modprobe br_netfilter + +sudo sysctl -w net.bridge.bridge-nf-call-arptables=0 +sudo sysctl -w net.bridge.bridge-nf-call-ip6tables=0 +sudo sysctl -w net.bridge.bridge-nf-call-iptables=0 + +sudo ip tuntap add name tap0 mode tap user $USER +sudo brctl addbr br0 +sudo brctl addif br0 tap0 +sudo brctl addif br0 $ETH +sudo ip link set tap0 up +sudo ip link set $ETH up +sudo ip link set br0 up + +# This connects your host system to the internet, so you can use it +# at the same time you run the examples. +sudo dhcpcd br0 +``` + +To tear down: + +``` +sudo killall dhcpcd +sudo ip link set br0 down +sudo brctl delbr br0 ``` ### Fault injection @@ -270,19 +405,19 @@ The host is assigned the hardware address `02-00-00-00-00-02`, IPv4 address `192 Read its [source code](/examples/httpclient.rs), then run it as: ```sh -cargo run --example httpclient -- tap0 ADDRESS URL +cargo run --example httpclient -- --tap tap0 ADDRESS URL ``` For example: ```sh -cargo run --example httpclient -- tap0 93.184.216.34 http://example.org/ +cargo run --example httpclient -- --tap tap0 93.184.216.34 http://example.org/ ``` or: ```sh -cargo run --example httpclient -- tap0 2606:2800:220:1:248:1893:25c8:1946 http://example.org/ +cargo run --example httpclient -- --tap tap0 2606:2800:220:1:248:1893:25c8:1946 http://example.org/ ``` It connects to the given address (not a hostname) and URL, and prints any returned response data. @@ -297,7 +432,7 @@ The host is assigned the hardware address `02-00-00-00-00-02` and IPv4 address ` Read its [source code](/examples/ping.rs), then run it as: ```sh -cargo run --example ping -- tap0 ADDRESS +cargo run --example ping -- --tap tap0 ADDRESS ``` It sends a series of 4 ICMP ECHO\_REQUEST packets to the given address at one second intervals and @@ -319,14 +454,14 @@ The host is assigned the hardware address `02-00-00-00-00-01` and IPv4 address ` Read its [source code](/examples/server.rs), then run it as: ```sh -cargo run --example server -- tap0 +cargo run --example server -- --tap tap0 ``` It responds to: * pings (`ping 192.168.69.1`); * UDP packets on port 6969 (`socat stdio udp4-connect:192.168.69.1:6969 <<<"abcdefg"`), - where it will respond "hello" to any incoming packet; + where it will respond with reversed chunks of the input indefinitely; * TCP connections on port 6969 (`socat stdio tcp4-connect:192.168.69.1:6969`), where it will respond "hello" to any incoming connection and immediately close it; * TCP connections on port 6970 (`socat stdio tcp4-connect:192.168.69.1:6970 <<<"abcdefg"`), @@ -349,7 +484,7 @@ The host is assigned the hardware address `02-00-00-00-00-02` and IPv4 address ` Read its [source code](/examples/client.rs), then run it as: ```sh -cargo run --example client -- tap0 ADDRESS PORT +cargo run --example client -- --tap tap0 ADDRESS PORT ``` It connects to the given address (not a hostname) and port (e.g. `socat stdio tcp4-listen:1234`), @@ -362,7 +497,7 @@ _examples/benchmark.rs_ implements a simple throughput benchmark. Read its [source code](/examples/benchmark.rs), then run it as: ```sh -cargo run --release --example benchmark -- tap0 [reader|writer] +cargo run --release --example benchmark -- --tap tap0 [reader|writer] ``` It establishes a connection to itself from a different thread and reads or writes a large amount @@ -372,9 +507,9 @@ A typical result (achieved on a Intel Core i7-7500U CPU and a Linux 4.9.65 x86_6 on a Dell XPS 13 9360 laptop) is as follows: ``` -$ cargo run -q --release --example benchmark tap0 reader +$ cargo run -q --release --example benchmark -- --tap tap0 reader throughput: 2.556 Gbps -$ cargo run -q --release --example benchmark tap0 writer +$ cargo run -q --release --example benchmark -- --tap tap0 writer throughput: 5.301 Gbps ``` @@ -391,7 +526,7 @@ Although it does not require `std`, this example still requires the `alloc` feat Read its [source code](/examples/loopback.rs), then run it without `std`: ```sh -cargo run --example loopback --no-default-features --features="log proto-ipv4 socket-tcp alloc" +cargo run --example loopback --no-default-features --features="log proto-ipv4 socket-tcp alloc" ``` ... or with `std` (in this case the features don't have to be explicitly listed): diff --git a/benches/bench.rs b/benches/bench.rs index 4aed8ce22..2738840c1 100644 --- a/benches/bench.rs +++ b/benches/bench.rs @@ -1,25 +1,25 @@ #![feature(test)] -extern crate test; -extern crate smoltcp; - mod wire { - use test; - #[cfg(feature = "proto-ipv6")] - use smoltcp::wire::{Ipv6Address, Ipv6Repr, Ipv6Packet}; - #[cfg(feature = "proto-ipv4")] - use smoltcp::wire::{Ipv4Address, Ipv4Repr, Ipv4Packet}; - use smoltcp::phy::{ChecksumCapabilities}; + use smoltcp::phy::ChecksumCapabilities; use smoltcp::wire::{IpAddress, IpProtocol}; - use smoltcp::wire::{TcpRepr, TcpPacket, TcpSeqNumber, TcpControl}; - use smoltcp::wire::{UdpRepr, UdpPacket}; + #[cfg(feature = "proto-ipv4")] + use smoltcp::wire::{Ipv4Address, Ipv4Packet, Ipv4Repr}; + #[cfg(feature = "proto-ipv6")] + use smoltcp::wire::{Ipv6Address, Ipv6Packet, Ipv6Repr}; + use smoltcp::wire::{TcpControl, TcpPacket, TcpRepr, TcpSeqNumber}; + use smoltcp::wire::{UdpPacket, UdpRepr}; + + extern crate test; #[cfg(feature = "proto-ipv6")] - const SRC_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1])); + const SRC_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ])); #[cfg(feature = "proto-ipv6")] - const DST_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 2])); + const DST_ADDR: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ])); #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] const SRC_ADDR: IpAddress = IpAddress::Ipv4(Ipv4Address([192, 168, 1, 1])); @@ -29,42 +29,53 @@ mod wire { #[bench] #[cfg(any(feature = "proto-ipv6", feature = "proto-ipv4"))] fn bench_emit_tcp(b: &mut test::Bencher) { - static PAYLOAD_BYTES: [u8; 400] = - [0x2a; 400]; + static PAYLOAD_BYTES: [u8; 400] = [0x2a; 400]; let repr = TcpRepr { - src_port: 48896, - dst_port: 80, - seq_number: TcpSeqNumber(0x01234567), - ack_number: None, - window_len: 0x0123, - control: TcpControl::Syn, - max_seg_size: None, + src_port: 48896, + dst_port: 80, + control: TcpControl::Syn, + seq_number: TcpSeqNumber(0x01234567), + ack_number: None, + window_len: 0x0123, window_scale: None, - payload: &PAYLOAD_BYTES + max_seg_size: None, + sack_permitted: false, + sack_ranges: [None, None, None], + payload: &PAYLOAD_BYTES, }; let mut bytes = vec![0xa5; repr.buffer_len()]; b.iter(|| { - let mut packet = TcpPacket::new(&mut bytes); - repr.emit(&mut packet, &SRC_ADDR, &DST_ADDR, &ChecksumCapabilities::default()); + let mut packet = TcpPacket::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR, + &DST_ADDR, + &ChecksumCapabilities::default(), + ); }); } #[bench] #[cfg(any(feature = "proto-ipv6", feature = "proto-ipv4"))] fn bench_emit_udp(b: &mut test::Bencher) { - static PAYLOAD_BYTES: [u8; 400] = - [0x2a; 400]; + static PAYLOAD_BYTES: [u8; 400] = [0x2a; 400]; let repr = UdpRepr { src_port: 48896, dst_port: 80, - payload: &PAYLOAD_BYTES }; - let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()]; b.iter(|| { - let mut packet = UdpPacket::new(&mut bytes); - repr.emit(&mut packet, &SRC_ADDR, &DST_ADDR, &ChecksumCapabilities::default()); + let mut packet = UdpPacket::new_unchecked(&mut bytes); + repr.emit( + &mut packet, + &SRC_ADDR, + &DST_ADDR, + PAYLOAD_BYTES.len(), + |buf| buf.copy_from_slice(&PAYLOAD_BYTES), + &ChecksumCapabilities::default(), + ); }); } @@ -72,16 +83,16 @@ mod wire { #[cfg(feature = "proto-ipv4")] fn bench_emit_ipv4(b: &mut test::Bencher) { let repr = Ipv4Repr { - src_addr: Ipv4Address([192, 168, 1, 1]), - dst_addr: Ipv4Address([192, 168, 1, 2]), - protocol: IpProtocol::Tcp, + src_addr: Ipv4Address([192, 168, 1, 1]), + dst_addr: Ipv4Address([192, 168, 1, 2]), + next_header: IpProtocol::Tcp, payload_len: 100, - hop_limit: 64 + hop_limit: 64, }; let mut bytes = vec![0xa5; repr.buffer_len()]; b.iter(|| { - let mut packet = Ipv4Packet::new(&mut bytes); + let mut packet = Ipv4Packet::new_unchecked(&mut bytes); repr.emit(&mut packet, &ChecksumCapabilities::default()); }); } @@ -90,18 +101,16 @@ mod wire { #[cfg(feature = "proto-ipv6")] fn bench_emit_ipv6(b: &mut test::Bencher) { let repr = Ipv6Repr { - src_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1]), - dst_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 2]), + src_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + dst_addr: Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]), next_header: IpProtocol::Tcp, payload_len: 100, - hop_limit: 64 + hop_limit: 64, }; let mut bytes = vec![0xa5; repr.buffer_len()]; b.iter(|| { - let mut packet = Ipv6Packet::new(&mut bytes); + let mut packet = Ipv6Packet::new_unchecked(&mut bytes); repr.emit(&mut packet); }); } diff --git a/build.rs b/build.rs new file mode 100644 index 000000000..568713925 --- /dev/null +++ b/build.rs @@ -0,0 +1,103 @@ +use std::collections::HashMap; +use std::fmt::Write; +use std::path::PathBuf; +use std::{env, fs}; + +static CONFIGS: &[(&str, usize)] = &[ + // BEGIN AUTOGENERATED CONFIG FEATURES + // Generated by gen_config.py. DO NOT EDIT. + ("IFACE_MAX_ADDR_COUNT", 2), + ("IFACE_MAX_MULTICAST_GROUP_COUNT", 4), + ("IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT", 4), + ("IFACE_NEIGHBOR_CACHE_COUNT", 4), + ("IFACE_MAX_ROUTE_COUNT", 2), + ("FRAGMENTATION_BUFFER_SIZE", 1500), + ("ASSEMBLER_MAX_SEGMENT_COUNT", 4), + ("REASSEMBLY_BUFFER_SIZE", 1500), + ("REASSEMBLY_BUFFER_COUNT", 1), + ("DNS_MAX_RESULT_COUNT", 1), + ("DNS_MAX_SERVER_COUNT", 1), + ("DNS_MAX_NAME_SIZE", 255), + ("RPL_RELATIONS_BUFFER_COUNT", 16), + ("RPL_PARENTS_BUFFER_COUNT", 8), + // END AUTOGENERATED CONFIG FEATURES +]; + +struct ConfigState { + value: usize, + seen_feature: bool, + seen_env: bool, +} + +fn main() { + // only rebuild if build.rs changed. Otherwise Cargo will rebuild if any + // other file changed. + println!("cargo:rerun-if-changed=build.rs"); + + // Rebuild if config envvar changed. + for (name, _) in CONFIGS { + println!("cargo:rerun-if-env-changed=SMOLTCP_{name}"); + } + + let mut configs = HashMap::new(); + for (name, default) in CONFIGS { + configs.insert( + *name, + ConfigState { + value: *default, + seen_env: false, + seen_feature: false, + }, + ); + } + + for (var, value) in env::vars() { + if let Some(name) = var.strip_prefix("SMOLTCP_") { + let Some(cfg) = configs.get_mut(name) else { + panic!("Unknown env var {name}") + }; + + let Ok(value) = value.parse::() else { + panic!("Invalid value for env var {name}: {value}") + }; + + cfg.value = value; + cfg.seen_env = true; + } + + if let Some(feature) = var.strip_prefix("CARGO_FEATURE_") { + if let Some(i) = feature.rfind('_') { + let name = &feature[..i]; + let value = &feature[i + 1..]; + if let Some(cfg) = configs.get_mut(name) { + let Ok(value) = value.parse::() else { + panic!("Invalid value for feature {name}: {value}") + }; + + // envvars take priority. + if !cfg.seen_env { + if cfg.seen_feature { + panic!( + "multiple values set for feature {}: {} and {}", + name, cfg.value, value + ); + } + + cfg.value = value; + cfg.seen_feature = true; + } + } + } + } + } + + let mut data = String::new(); + + for (name, cfg) in &configs { + writeln!(&mut data, "pub const {}: usize = {};", name, cfg.value).unwrap(); + } + + let out_dir = PathBuf::from(env::var_os("OUT_DIR").unwrap()); + let out_file = out_dir.join("config.rs").to_string_lossy().to_string(); + fs::write(out_file, data).unwrap(); +} diff --git a/ci.sh b/ci.sh new file mode 100755 index 000000000..3c45f5d7c --- /dev/null +++ b/ci.sh @@ -0,0 +1,118 @@ +#!/usr/bin/env bash + +set -eox pipefail + +export DEFMT_LOG=trace + +MSRV="1.65.0" + +RUSTC_VERSIONS=( + $MSRV + "stable" + "nightly" +) + +FEATURES_TEST=( + "default" + "std,proto-ipv4" + "std,medium-ethernet,phy-raw_socket,proto-ipv6,socket-udp,socket-dns" + "std,medium-ethernet,phy-tuntap_interface,proto-ipv6,socket-udp" + "std,medium-ethernet,proto-ipv4,proto-ipv4-fragmentation,socket-raw,socket-dns" + "std,medium-ethernet,proto-ipv4,proto-igmp,socket-raw,socket-dns" + "std,medium-ethernet,proto-ipv4,socket-udp,socket-tcp,socket-dns" + "std,medium-ethernet,proto-ipv4,proto-dhcpv4,socket-udp" + "std,medium-ethernet,medium-ip,medium-ieee802154,proto-ipv6,socket-udp,socket-dns" + "std,medium-ethernet,proto-ipv6,socket-tcp" + "std,medium-ethernet,medium-ip,proto-ipv4,socket-icmp,socket-tcp" + "std,medium-ip,proto-ipv6,socket-icmp,socket-tcp" + "std,medium-ieee802154,proto-sixlowpan,socket-udp" + "std,medium-ieee802154,proto-sixlowpan,proto-sixlowpan-fragmentation,socket-udp" + "std,medium-ieee802154,proto-rpl,proto-sixlowpan,proto-sixlowpan-fragmentation,socket-udp" + "std,medium-ip,proto-ipv4,proto-ipv6,socket-tcp,socket-udp" + "std,medium-ethernet,medium-ip,medium-ieee802154,proto-ipv4,proto-ipv6,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" +) + +FEATURES_TEST_NIGHTLY=( + "alloc,medium-ethernet,proto-ipv4,proto-ipv6,socket-raw,socket-udp,socket-tcp,socket-icmp" +) + +FEATURES_CHECK=( + "medium-ip,medium-ethernet,medium-ieee802154,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" + "defmt,medium-ip,medium-ethernet,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" + "defmt,alloc,medium-ip,medium-ethernet,proto-ipv6,proto-ipv6,proto-igmp,proto-dhcpv4,socket-raw,socket-udp,socket-tcp,socket-icmp,socket-dns,async" +) + +test() { + local version=$1 + rustup toolchain install $version + + for features in ${FEATURES_TEST[@]}; do + cargo +$version test --no-default-features --features "$features" + done + + if [[ $version == "nightly" ]]; then + for features in ${FEATURES_TEST_NIGHTLY[@]}; do + cargo +$version test --no-default-features --features "$features" + done + fi +} + +check() { + local version=$1 + rustup toolchain install $version + + export DEFMT_LOG="trace" + + for features in ${FEATURES_CHECK[@]}; do + cargo +$version check --no-default-features --features "$features" + done +} + +clippy() { + rustup toolchain install $MSRV + rustup component add clippy --toolchain=$MSRV + cargo +$MSRV clippy --tests --examples -- -D warnings +} + +coverage() { + for features in ${FEATURES_TEST[@]}; do + cargo llvm-cov --no-report --no-default-features --features "$features" + done + cargo llvm-cov report --lcov --output-path lcov.info +} + +if [[ $1 == "test" || $1 == "all" ]]; then + if [[ -n $2 ]]; then + if [[ $2 == "msrv" ]]; then + test $MSRV + else + test $2 + fi + else + for version in ${RUSTC_VERSIONS[@]}; do + test $version + done + fi +fi + +if [[ $1 == "check" || $1 == "all" ]]; then + if [[ -n $2 ]]; then + if [[ $2 == "msrv" ]]; then + check $MSRV + else + check $2 + fi + else + for version in ${RUSTC_VERSIONS[@]}; do + check $version + done + fi +fi + +if [[ $1 == "clippy" || $1 == "all" ]]; then + clippy +fi + +if [[ $1 == "coverage" || $1 == "all" ]]; then + coverage +fi diff --git a/examples/benchmark.rs b/examples/benchmark.rs index 3e78ac234..ad2c6e142 100644 --- a/examples/benchmark.rs +++ b/examples/benchmark.rs @@ -1,33 +1,32 @@ -#[cfg(feature = "log")] -#[macro_use] -extern crate log; -#[cfg(feature = "log")] -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; +#![allow(clippy::collapsible_if)] mod utils; use std::cmp; -use std::collections::BTreeMap; -use std::sync::atomic::{Ordering, AtomicBool}; -use std::thread; use std::io::{Read, Write}; use std::net::TcpStream; use std::os::unix::io::AsRawFd; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder}; -use smoltcp::socket::SocketSet; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer}; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; const AMOUNT: usize = 1_000_000_000; -enum Client { Reader, Writer } +enum Client { + Reader, + Writer, +} fn client(kind: Client) { - let port = match kind { Client::Reader => 1234, Client::Writer => 1235 }; + let port = match kind { + Client::Reader => 1234, + Client::Writer => 1235, + }; let mut stream = TcpStream::connect(("192.168.69.1", port)).unwrap(); let mut buffer = vec![0; 1_000_000]; @@ -46,7 +45,7 @@ fn client(kind: Client) { // print!("(P:{})", result); processed += result } - Err(err) => panic!("cannot process: {}", err) + Err(err) => panic!("cannot process: {err}"), } } @@ -66,96 +65,96 @@ fn main() { utils::setup_logging("info"); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); free.push("MODE"); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); let mode = match matches.free[0].as_ref() { "reader" => Client::Reader, "writer" => Client::Writer, - _ => panic!("invalid mode") + _ => panic!("invalid mode"), }; - thread::spawn(move || client(mode)); + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); - let neighbor_cache = NeighborCache::new(BTreeMap::new()); + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); - let tcp1_rx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp1_tx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp1_socket = TcpSocket::new(tcp1_rx_buffer, tcp1_tx_buffer); - - let tcp2_rx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); - let ip_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)]; - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + }); let mut sockets = SocketSet::new(vec![]); let tcp1_handle = sockets.add(tcp1_socket); let tcp2_handle = sockets.add(tcp2_socket); let default_timeout = Some(Duration::from_millis(1000)); + thread::spawn(move || client(mode)); let mut processed = 0; while !CLIENT_DONE.load(Ordering::SeqCst) { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}",e); - } - } - + iface.poll(timestamp, &mut device, &mut sockets); // tcp:1234: emit data - { - let mut socket = sockets.get::(tcp1_handle); - if !socket.is_open() { - socket.listen(1234).unwrap(); - } + let socket = sockets.get_mut::(tcp1_handle); + if !socket.is_open() { + socket.listen(1234).unwrap(); + } - if socket.can_send() { - if processed < AMOUNT { - let length = socket.send(|buffer| { + if socket.can_send() { + if processed < AMOUNT { + let length = socket + .send(|buffer| { let length = cmp::min(buffer.len(), AMOUNT - processed); (length, length) - }).unwrap(); - processed += length; - } + }) + .unwrap(); + processed += length; } } // tcp:1235: sink data - { - let mut socket = sockets.get::(tcp2_handle); - if !socket.is_open() { - socket.listen(1235).unwrap(); - } + let socket = sockets.get_mut::(tcp2_handle); + if !socket.is_open() { + socket.listen(1235).unwrap(); + } - if socket.can_recv() { - if processed < AMOUNT { - let length = socket.recv(|buffer| { + if socket.can_recv() { + if processed < AMOUNT { + let length = socket + .recv(|buffer| { let length = cmp::min(buffer.len(), AMOUNT - processed); (length, length) - }).unwrap(); - processed += length; - } + }) + .unwrap(); + processed += length; } } - match iface.poll_at(&sockets, timestamp) { + match iface.poll_at(timestamp, &sockets) { Some(poll_at) if timestamp < poll_at => { phy_wait(fd, Some(poll_at - timestamp)).expect("wait error"); - }, + } Some(_) => (), None => { phy_wait(fd, default_timeout).expect("wait error"); diff --git a/examples/client.rs b/examples/client.rs index a9aaf653d..c18c08ff7 100644 --- a/examples/client.rs +++ b/examples/client.rs @@ -1,106 +1,118 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; - mod utils; -use std::str::{self, FromStr}; -use std::collections::BTreeMap; +use log::debug; use std::os::unix::io::AsRawFd; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, Ipv4Address, IpAddress, IpCidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder, Routes}; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; +use std::str::{self, FromStr}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; fn main() { utils::setup_logging(""); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); free.push("ADDRESS"); free.push("PORT"); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); let address = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); let port = u16::from_str(&matches.free[1]).expect("invalid port format"); - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - - let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 64]); - let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 128]); - let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]); - let ip_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 2), 24)]; - let default_v4_gw = Ipv4Address::new(192, 168, 69, 100); - let mut routes_storage = [None; 1]; - let mut routes = Routes::new(&mut routes_storage[..]); - routes.add_default_ipv4_route(default_v4_gw).unwrap(); - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .routes(routes) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + // Create sockets + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 1500]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 1500]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); let mut sockets = SocketSet::new(vec![]); let tcp_handle = sockets.add(tcp_socket); - { - let mut socket = sockets.get::(tcp_handle); - socket.connect((address, port), 49500).unwrap(); - } + let socket = sockets.get_mut::(tcp_handle); + socket + .connect(iface.context(), (address, port), 49500) + .unwrap(); let mut tcp_active = false; loop { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}", e); - } - } + iface.poll(timestamp, &mut device, &mut sockets); - { - let mut socket = sockets.get::(tcp_handle); - if socket.is_active() && !tcp_active { - debug!("connected"); - } else if !socket.is_active() && tcp_active { - debug!("disconnected"); - break - } - tcp_active = socket.is_active(); + let socket = sockets.get_mut::(tcp_handle); + if socket.is_active() && !tcp_active { + debug!("connected"); + } else if !socket.is_active() && tcp_active { + debug!("disconnected"); + break; + } + tcp_active = socket.is_active(); - if socket.may_recv() { - let data = socket.recv(|data| { + if socket.may_recv() { + let data = socket + .recv(|data| { let mut data = data.to_owned(); - if data.len() > 0 { - debug!("recv data: {:?}", - str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")); + if !data.is_empty() { + debug!( + "recv data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); data = data.split(|&b| b == b'\n').collect::>().concat(); data.reverse(); data.extend(b"\n"); } (data.len(), data) - }).unwrap(); - if socket.can_send() && data.len() > 0 { - debug!("send data: {:?}", - str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")); - socket.send_slice(&data[..]).unwrap(); - } - } else if socket.may_send() { - debug!("close"); - socket.close(); + }) + .unwrap(); + if socket.can_send() && !data.is_empty() { + debug!( + "send data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + socket.send_slice(&data[..]).unwrap(); } + } else if socket.may_send() { + debug!("close"); + socket.close(); } - phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error"); + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); } } diff --git a/examples/dhcp_client.rs b/examples/dhcp_client.rs index 0695c40e1..9ef46c27b 100644 --- a/examples/dhcp_client.rs +++ b/examples/dhcp_client.rs @@ -1,106 +1,94 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; - +#![allow(clippy::option_map_unit_fn)] mod utils; -use std::collections::BTreeMap; +use log::*; use std::os::unix::io::AsRawFd; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, Ipv4Address, IpCidr, Ipv4Cidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder, Routes}; -use smoltcp::socket::{SocketSet, RawSocketBuffer, RawPacketMetadata}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::socket::dhcpv4; use smoltcp::time::Instant; -use smoltcp::dhcp::Dhcpv4Client; +use smoltcp::wire::{EthernetAddress, IpCidr, Ipv4Address, Ipv4Cidr}; +use smoltcp::{ + phy::{wait as phy_wait, Device, Medium}, + time::Duration, +}; fn main() { #[cfg(feature = "log")] utils::setup_logging(""); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); - - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); - let ip_addrs = [IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 0)]; - let mut routes_storage = [None; 1]; - let routes = Routes::new(&mut routes_storage[..]); - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .routes(routes) - .finalize(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + + // Create sockets + let mut dhcp_socket = dhcpv4::Socket::new(); + + // Set a ridiculously short max lease time to show DHCP renews work properly. + // This will cause the DHCP client to start renewing after 5 seconds, and give up the + // lease after 10 seconds if renew hasn't succeeded. + // IMPORTANT: This should be removed in production. + dhcp_socket.set_max_lease_duration(Some(Duration::from_secs(10))); let mut sockets = SocketSet::new(vec![]); - let dhcp_rx_buffer = RawSocketBuffer::new( - [RawPacketMetadata::EMPTY; 1], - vec![0; 900] - ); - let dhcp_tx_buffer = RawSocketBuffer::new( - [RawPacketMetadata::EMPTY; 1], - vec![0; 600] - ); - let mut dhcp = Dhcpv4Client::new(&mut sockets, dhcp_rx_buffer, dhcp_tx_buffer, Instant::now()); - let mut prev_cidr = Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0); + let dhcp_handle = sockets.add(dhcp_socket); + loop { let timestamp = Instant::now(); - iface.poll(&mut sockets, timestamp) - .map(|_| ()) - .unwrap_or_else(|e| println!("Poll: {:?}", e)); - let config = dhcp.poll(&mut iface, &mut sockets, timestamp) - .unwrap_or_else(|e| { - println!("DHCP: {:?}", e); - None - }); - config.map(|config| { - println!("DHCP config: {:?}", config); - match config.address { - Some(cidr) => if cidr != prev_cidr { - iface.update_ip_addrs(|addrs| { - addrs.iter_mut().nth(0) - .map(|addr| { - *addr = IpCidr::Ipv4(cidr); - }); - }); - prev_cidr = cidr; - println!("Assigned a new IPv4 address: {}", cidr); + iface.poll(timestamp, &mut device, &mut sockets); + + let event = sockets.get_mut::(dhcp_handle).poll(); + match event { + None => {} + Some(dhcpv4::Event::Configured(config)) => { + debug!("DHCP config acquired!"); + + debug!("IP address: {}", config.address); + set_ipv4_addr(&mut iface, config.address); + + if let Some(router) = config.router { + debug!("Default gateway: {}", router); + iface.routes_mut().add_default_ipv4_route(router).unwrap(); + } else { + debug!("Default gateway: None"); + iface.routes_mut().remove_default_ipv4_route(); } - _ => {} - } - config.router.map(|router| iface.routes_mut() - .add_default_ipv4_route(router.into()) - .unwrap() - ); - iface.routes_mut() - .update(|routes_map| { - routes_map.get(&IpCidr::new(Ipv4Address::UNSPECIFIED.into(), 0)) - .map(|default_route| { - println!("Default gateway: {}", default_route.via_router); - }); - }); - - if config.dns_servers.iter().any(|s| s.is_some()) { - println!("DNS servers:"); - for dns_server in config.dns_servers.iter().filter_map(|s| *s) { - println!("- {}", dns_server); + for (i, s) in config.dns_servers.iter().enumerate() { + debug!("DNS server {}: {}", i, s); } } - }); + Some(dhcpv4::Event::Deconfigured) => { + debug!("DHCP lost config!"); + set_ipv4_addr(&mut iface, Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)); + iface.routes_mut().remove_default_ipv4_route(); + } + } - let mut timeout = dhcp.next_poll(timestamp); - iface.poll_delay(&sockets, timestamp) - .map(|sockets_timeout| timeout = sockets_timeout); - phy_wait(fd, Some(timeout)) - .unwrap_or_else(|e| println!("Wait: {:?}", e));; + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); } } + +fn set_ipv4_addr(iface: &mut Interface, cidr: Ipv4Cidr) { + iface.update_ip_addrs(|addrs| { + let dest = addrs.iter_mut().next().unwrap(); + *dest = IpCidr::Ipv4(cidr); + }); +} diff --git a/examples/dns.rs b/examples/dns.rs new file mode 100644 index 000000000..977f40546 --- /dev/null +++ b/examples/dns.rs @@ -0,0 +1,92 @@ +mod utils; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::Device; +use smoltcp::phy::{wait as phy_wait, Medium}; +use smoltcp::socket::dns::{self, GetQueryResultError}; +use smoltcp::time::Instant; +use smoltcp::wire::{DnsQueryType, EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; +use std::os::unix::io::AsRawFd; + +fn main() { + utils::setup_logging("warn"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_tuntap_options(&mut opts, &mut free); + utils::add_middleware_options(&mut opts, &mut free); + free.push("ADDRESS"); + + let mut matches = utils::parse_options(&opts, free); + let device = utils::parse_tuntap_options(&mut matches); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + let name = &matches.free[0]; + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let servers = &[ + Ipv4Address::new(8, 8, 4, 4).into(), + Ipv4Address::new(8, 8, 8, 8).into(), + ]; + let dns_socket = dns::Socket::new(servers, vec![]); + + let mut sockets = SocketSet::new(vec![]); + let dns_handle = sockets.add(dns_socket); + + let socket = sockets.get_mut::(dns_handle); + let query = socket + .start_query(iface.context(), name, DnsQueryType::A) + .unwrap(); + + loop { + let timestamp = Instant::now(); + log::debug!("timestamp {:?}", timestamp); + + iface.poll(timestamp, &mut device, &mut sockets); + + match sockets + .get_mut::(dns_handle) + .get_query_result(query) + { + Ok(addrs) => { + println!("Query done: {addrs:?}"); + break; + } + Err(GetQueryResultError::Pending) => {} // not done yet + Err(e) => panic!("query failed: {e:?}"), + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/httpclient.rs b/examples/httpclient.rs index 7ac9ef71d..8f3a53aa7 100644 --- a/examples/httpclient.rs +++ b/examples/httpclient.rs @@ -1,113 +1,123 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate rand; -extern crate url; -extern crate smoltcp; - mod utils; -use std::str::{self, FromStr}; -use std::collections::BTreeMap; +use log::debug; use std::os::unix::io::AsRawFd; +use std::str::{self, FromStr}; use url::Url; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, Ipv4Address, Ipv6Address, IpAddress, IpCidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder, Routes}; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::tcp; use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; fn main() { utils::setup_logging(""); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); free.push("ADDRESS"); free.push("URL"); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); let address = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); let url = Url::parse(&matches.free[1]).expect("invalid url format"); + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - - let tcp_rx_buffer = TcpSocketBuffer::new(vec![0; 1024]); - let tcp_tx_buffer = TcpSocketBuffer::new(vec![0; 1024]); - let tcp_socket = TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]); - let ip_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24), - IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64), - IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)]; - let default_v4_gw = Ipv4Address::new(192, 168, 69, 100); - let default_v6_gw = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100); - let mut routes_storage = [None; 2]; - let mut routes = Routes::new(&mut routes_storage[..]); - routes.add_default_ipv4_route(default_v4_gw).unwrap(); - routes.add_default_ipv6_route(default_v6_gw).unwrap(); - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .routes(routes) - .finalize(); + // Create sockets + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 1024]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 1024]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); let mut sockets = SocketSet::new(vec![]); let tcp_handle = sockets.add(tcp_socket); - enum State { Connect, Request, Response }; + enum State { + Connect, + Request, + Response, + } let mut state = State::Connect; loop { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}",e); - } - } + iface.poll(timestamp, &mut device, &mut sockets); - { - let mut socket = sockets.get::(tcp_handle); + let socket = sockets.get_mut::(tcp_handle); + let cx = iface.context(); - state = match state { - State::Connect if !socket.is_active() => { - debug!("connecting"); - let local_port = 49152 + rand::random::() % 16384; - socket.connect((address, url.port().unwrap_or(80)), local_port).unwrap(); - State::Request - } - State::Request if socket.may_send() => { - debug!("sending request"); - let http_get = "GET ".to_owned() + url.path() + " HTTP/1.1\r\n"; - socket.send_slice(http_get.as_ref()).expect("cannot send"); - let http_host = "Host: ".to_owned() + url.host_str().unwrap() + "\r\n"; - socket.send_slice(http_host.as_ref()).expect("cannot send"); - socket.send_slice(b"Connection: close\r\n").expect("cannot send"); - socket.send_slice(b"\r\n").expect("cannot send"); - State::Response - } - State::Response if socket.can_recv() => { - socket.recv(|data| { + state = match state { + State::Connect if !socket.is_active() => { + debug!("connecting"); + let local_port = 49152 + rand::random::() % 16384; + socket + .connect(cx, (address, url.port().unwrap_or(80)), local_port) + .unwrap(); + State::Request + } + State::Request if socket.may_send() => { + debug!("sending request"); + let http_get = "GET ".to_owned() + url.path() + " HTTP/1.1\r\n"; + socket.send_slice(http_get.as_ref()).expect("cannot send"); + let http_host = "Host: ".to_owned() + url.host_str().unwrap() + "\r\n"; + socket.send_slice(http_host.as_ref()).expect("cannot send"); + socket + .send_slice(b"Connection: close\r\n") + .expect("cannot send"); + socket.send_slice(b"\r\n").expect("cannot send"); + State::Response + } + State::Response if socket.can_recv() => { + socket + .recv(|data| { println!("{}", str::from_utf8(data).unwrap_or("(invalid utf8)")); (data.len(), ()) - }).unwrap(); - State::Response - } - State::Response if !socket.may_recv() => { - debug!("received complete response"); - break - } - _ => state + }) + .unwrap(); + State::Response } - } + State::Response if !socket.may_recv() => { + debug!("received complete response"); + break; + } + _ => state, + }; - phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error"); + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); } } diff --git a/examples/loopback.rs b/examples/loopback.rs index afc33fa20..7ca95b188 100644 --- a/examples/loopback.rs +++ b/examples/loopback.rs @@ -1,33 +1,27 @@ #![cfg_attr(not(feature = "std"), no_std)] #![allow(unused_mut)] - -#[cfg(feature = "std")] -use std as core; -#[macro_use] -extern crate log; -extern crate smoltcp; -#[cfg(feature = "std")] -extern crate env_logger; -#[cfg(feature = "std")] -extern crate getopts; +#![allow(clippy::collapsible_if)] #[cfg(feature = "std")] #[allow(dead_code)] mod utils; use core::str; -use smoltcp::phy::Loopback; -use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder}; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; +use log::{debug, error, info}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{Device, Loopback, Medium}; +use smoltcp::socket::tcp; use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; #[cfg(not(feature = "std"))] mod mock { - use smoltcp::time::{Duration, Instant}; use core::cell::Cell; + use smoltcp::time::{Duration, Instant}; #[derive(Debug)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Clock(Cell); impl Clock { @@ -47,12 +41,13 @@ mod mock { #[cfg(feature = "std")] mod mock { + use smoltcp::time::{Duration, Instant}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::sync::atomic::{Ordering, AtomicUsize}; - use smoltcp::time::{Duration, Instant}; // should be AtomicU64 but that's unstable #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Clock(Arc); impl Clock { @@ -61,7 +56,8 @@ mod mock { } pub fn advance(&self, duration: Duration) { - self.0.fetch_add(duration.total_millis() as usize, Ordering::SeqCst); + self.0 + .fetch_add(duration.total_millis() as usize, Ordering::SeqCst); } pub fn elapsed(&self) -> Instant { @@ -72,10 +68,10 @@ mod mock { fn main() { let clock = mock::Clock::new(); - let device = Loopback::new(); + let device = Loopback::new(Medium::Ethernet); #[cfg(feature = "std")] - let device = { + let mut device = { let clock = clock.clone(); utils::setup_logging_with_clock("", move || clock.elapsed()); @@ -83,21 +79,26 @@ fn main() { utils::add_middleware_options(&mut opts, &mut free); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/true); - - device + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ true) }; - let mut neighbor_cache_entries = [None; 8]; - let mut neighbor_cache = NeighborCache::new(&mut neighbor_cache_entries[..]); + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; - let mut ip_addrs = [IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)]; - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(EthernetAddress::default()) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .finalize(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) + .unwrap(); + }); + // Create sockets let server_socket = { // It is not strictly necessary to use a `static mut` and unsafe code here, but // on embedded systems that smoltcp targets it is far better to allocate the data @@ -105,79 +106,73 @@ fn main() { // when stack overflows. static mut TCP_SERVER_RX_DATA: [u8; 1024] = [0; 1024]; static mut TCP_SERVER_TX_DATA: [u8; 1024] = [0; 1024]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) }; let client_socket = { static mut TCP_CLIENT_RX_DATA: [u8; 1024] = [0; 1024]; static mut TCP_CLIENT_TX_DATA: [u8; 1024] = [0; 1024]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) }; - let mut socket_set_entries: [_; 2] = Default::default(); - let mut socket_set = SocketSet::new(&mut socket_set_entries[..]); - let server_handle = socket_set.add(server_socket); - let client_handle = socket_set.add(client_socket); + let mut sockets: [_; 2] = Default::default(); + let mut sockets = SocketSet::new(&mut sockets[..]); + let server_handle = sockets.add(server_socket); + let client_handle = sockets.add(client_socket); - let mut did_listen = false; + let mut did_listen = false; let mut did_connect = false; let mut done = false; while !done && clock.elapsed() < Instant::from_millis(10_000) { - match iface.poll(&mut socket_set, clock.elapsed()) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}", e); + iface.poll(clock.elapsed(), &mut device, &mut sockets); + + let mut socket = sockets.get_mut::(server_handle); + if !socket.is_active() && !socket.is_listening() { + if !did_listen { + debug!("listening"); + socket.listen(1234).unwrap(); + did_listen = true; } } - { - let mut socket = socket_set.get::(server_handle); - if !socket.is_active() && !socket.is_listening() { - if !did_listen { - debug!("listening"); - socket.listen(1234).unwrap(); - did_listen = true; - } - } - - if socket.can_recv() { - debug!("got {:?}", socket.recv(|buffer| { - (buffer.len(), str::from_utf8(buffer).unwrap()) - })); - socket.close(); - done = true; - } + if socket.can_recv() { + debug!( + "got {:?}", + socket.recv(|buffer| { (buffer.len(), str::from_utf8(buffer).unwrap()) }) + ); + socket.close(); + done = true; } - { - let mut socket = socket_set.get::(client_handle); - if !socket.is_open() { - if !did_connect { - debug!("connecting"); - socket.connect((IpAddress::v4(127, 0, 0, 1), 1234), - (IpAddress::Unspecified, 65000)).unwrap(); - did_connect = true; - } + let mut socket = sockets.get_mut::(client_handle); + let cx = iface.context(); + if !socket.is_open() { + if !did_connect { + debug!("connecting"); + socket + .connect(cx, (IpAddress::v4(127, 0, 0, 1), 1234), 65000) + .unwrap(); + did_connect = true; } + } - if socket.can_send() { - debug!("sending"); - socket.send_slice(b"0123456789abcdef").unwrap(); - socket.close(); - } + if socket.can_send() { + debug!("sending"); + socket.send_slice(b"0123456789abcdef").unwrap(); + socket.close(); } - match iface.poll_delay(&socket_set, clock.elapsed()) { - Some(Duration { millis: 0 }) => debug!("resuming"), + match iface.poll_delay(clock.elapsed(), &sockets) { + Some(Duration::ZERO) => debug!("resuming"), Some(delay) => { debug!("sleeping for {} ms", delay); clock.advance(delay) - }, - None => clock.advance(Duration::from_millis(1)) + } + None => clock.advance(Duration::from_millis(1)), } } diff --git a/examples/multicast.rs b/examples/multicast.rs index 81a606f44..ea89a2e93 100644 --- a/examples/multicast.rs +++ b/examples/multicast.rs @@ -1,22 +1,15 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; -extern crate byteorder; - mod utils; -use std::collections::BTreeMap; use std::os::unix::io::AsRawFd; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, IpVersion, IpProtocol, IpAddress, IpCidr, Ipv4Address, - Ipv4Packet, IgmpPacket, IgmpRepr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder}; -use smoltcp::socket::{SocketSet, - RawSocket, RawSocketBuffer, RawPacketMetadata, - UdpSocket, UdpSocketBuffer, UdpPacketMetadata}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::{raw, udp}; use smoltcp::time::Instant; +use smoltcp::wire::{ + EthernetAddress, IgmpPacket, IgmpRepr, IpAddress, IpCidr, IpProtocol, IpVersion, Ipv4Address, + Ipv4Packet, Ipv6Address, +}; const MDNS_PORT: u16 = 5353; const MDNS_GROUP: [u8; 4] = [224, 0, 0, 251]; @@ -25,89 +18,112 @@ fn main() { utils::setup_logging("warn"); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, - device, - /*loopback=*/ - false); - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - - let local_addr = Ipv4Address::new(192, 168, 69, 2); - - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]); - let ip_addr = IpCidr::new(IpAddress::from(local_addr), 24); - let mut ipv4_multicast_storage = [None; 1]; - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs([ip_addr]) - .ipv4_multicast_groups(&mut ipv4_multicast_storage[..]) - .finalize(); - - let now = Instant::now(); - // Join a multicast group to receive mDNS traffic - iface.join_multicast_group(Ipv4Address::from_bytes(&MDNS_GROUP), now).unwrap(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets let mut sockets = SocketSet::new(vec![]); // Must fit at least one IGMP packet - let raw_rx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; 2], vec![0; 512]); + let raw_rx_buffer = raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; 2], vec![0; 512]); // Will not send IGMP - let raw_tx_buffer = RawSocketBuffer::new(vec![], vec![]); - let raw_socket = RawSocket::new( - IpVersion::Ipv4, IpProtocol::Igmp, - raw_rx_buffer, raw_tx_buffer + let raw_tx_buffer = raw::PacketBuffer::new(vec![], vec![]); + let raw_socket = raw::Socket::new( + IpVersion::Ipv4, + IpProtocol::Igmp, + raw_rx_buffer, + raw_tx_buffer, ); let raw_handle = sockets.add(raw_socket); // Must fit mDNS payload of at least one packet - let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY; 4], vec![0; 1024]); + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY; 4], vec![0; 1024]); // Will not send mDNS - let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 0]); - let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 0]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); let udp_handle = sockets.add(udp_socket); + // Join a multicast group to receive mDNS traffic + iface + .join_multicast_group( + &mut device, + Ipv4Address::from_bytes(&MDNS_GROUP), + Instant::now(), + ) + .unwrap(); + loop { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}",e); + iface.poll(timestamp, &mut device, &mut sockets); + + let socket = sockets.get_mut::(raw_handle); + + if socket.can_recv() { + // For display purposes only - normally we wouldn't process incoming IGMP packets + // in the application layer + match socket.recv() { + Err(e) => println!("Recv IGMP error: {e:?}"), + Ok(buf) => { + Ipv4Packet::new_checked(buf) + .and_then(|ipv4_packet| IgmpPacket::new_checked(ipv4_packet.payload())) + .and_then(|igmp_packet| IgmpRepr::parse(&igmp_packet)) + .map(|igmp_repr| println!("IGMP packet: {igmp_repr:?}")) + .unwrap_or_else(|e| println!("parse IGMP error: {e:?}")); + } } } - { - let mut socket = sockets.get::(raw_handle); - - if socket.can_recv() { - // For display purposes only - normally we wouldn't process incoming IGMP packets - // in the application layer - socket.recv() - .and_then(|payload| Ipv4Packet::new_checked(payload)) - .and_then(|ipv4_packet| IgmpPacket::new_checked(ipv4_packet.payload())) - .and_then(|igmp_packet| IgmpRepr::parse(&igmp_packet)) - .map(|igmp_repr| println!("IGMP packet: {:?}", igmp_repr)) - .unwrap_or_else(|e| println!("Recv IGMP error: {:?}", e)); - } + let socket = sockets.get_mut::(udp_handle); + if !socket.is_open() { + socket.bind(MDNS_PORT).unwrap() } - { - let mut socket = sockets.get::(udp_handle); - if !socket.is_open() { - socket.bind(MDNS_PORT).unwrap() - } - if socket.can_recv() { - socket.recv() - .map(|(data, sender)| println!("mDNS traffic: {} UDP bytes from {}", data.len(), sender)) - .unwrap_or_else(|e| println!("Recv UDP error: {:?}", e)); - } + if socket.can_recv() { + socket + .recv() + .map(|(data, sender)| { + println!("mDNS traffic: {} UDP bytes from {}", data.len(), sender) + }) + .unwrap_or_else(|e| println!("Recv UDP error: {e:?}")); } - phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error"); + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); } } diff --git a/examples/ping.rs b/examples/ping.rs index bdbf96e6c..341413a3a 100644 --- a/examples/ping.rs +++ b/examples/ping.rs @@ -1,26 +1,24 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; -extern crate byteorder; - mod utils; -use std::str::FromStr; -use std::collections::BTreeMap; +use byteorder::{ByteOrder, NetworkEndian}; +use smoltcp::iface::{Interface, SocketSet}; use std::cmp; +use std::collections::HashMap; use std::os::unix::io::AsRawFd; -use smoltcp::time::{Duration, Instant}; -use smoltcp::phy::Device; +use std::str::FromStr; + +use smoltcp::iface::Config; use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, - Ipv6Address, Icmpv6Repr, Icmpv6Packet, - Ipv4Address, Icmpv4Repr, Icmpv4Packet}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder, Routes}; -use smoltcp::socket::{SocketSet, IcmpSocket, IcmpSocketBuffer, IcmpPacketMetadata, IcmpEndpoint}; -use std::collections::HashMap; -use byteorder::{ByteOrder, NetworkEndian}; +use smoltcp::phy::Device; +use smoltcp::socket::icmp; +use smoltcp::wire::{ + EthernetAddress, Icmpv4Packet, Icmpv4Repr, Icmpv6Packet, Icmpv6Repr, IpAddress, IpCidr, + Ipv4Address, Ipv6Address, +}; +use smoltcp::{ + phy::Medium, + time::{Duration, Instant}, +}; macro_rules! send_icmp_ping { ( $repr_type:ident, $packet_type:ident, $ident:expr, $seq_no:expr, @@ -31,13 +29,11 @@ macro_rules! send_icmp_ping { data: &$echo_payload, }; - let icmp_payload = $socket - .send(icmp_repr.buffer_len(), $remote_addr) - .unwrap(); + let icmp_payload = $socket.send(icmp_repr.buffer_len(), $remote_addr).unwrap(); let icmp_packet = $packet_type::new_unchecked(icmp_payload); (icmp_repr, icmp_packet) - }} + }}; } macro_rules! get_icmp_pong { @@ -46,70 +42,103 @@ macro_rules! get_icmp_pong { if let $repr_type::EchoReply { seq_no, data, .. } = $repr { if let Some(_) = $waiting_queue.get(&seq_no) { let packet_timestamp_ms = NetworkEndian::read_i64(data); - println!("{} bytes from {}: icmp_seq={}, time={}ms", - data.len(), $remote_addr, seq_no, - $timestamp.total_millis() - packet_timestamp_ms); + println!( + "{} bytes from {}: icmp_seq={}, time={}ms", + data.len(), + $remote_addr, + seq_no, + $timestamp.total_millis() - packet_timestamp_ms + ); $waiting_queue.remove(&seq_no); $received += 1; } } - }} + }}; } fn main() { utils::setup_logging("warn"); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); - opts.optopt("c", "count", "Amount of echo request packets to send (default: 4)", "COUNT"); - opts.optopt("i", "interval", - "Interval between successive packets sent (seconds) (default: 1)", "INTERVAL"); - opts.optopt("", "timeout", - "Maximum wait duration for an echo response packet (seconds) (default: 5)", - "TIMEOUT"); + opts.optopt( + "c", + "count", + "Amount of echo request packets to send (default: 4)", + "COUNT", + ); + opts.optopt( + "i", + "interval", + "Interval between successive packets sent (seconds) (default: 1)", + "INTERVAL", + ); + opts.optopt( + "", + "timeout", + "Maximum wait duration for an echo response packet (seconds) (default: 5)", + "TIMEOUT", + ); free.push("ADDRESS"); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); let device_caps = device.capabilities(); - let address = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); - let count = matches.opt_str("count").map(|s| usize::from_str(&s).unwrap()).unwrap_or(4); - let interval = matches.opt_str("interval") + let remote_addr = IpAddress::from_str(&matches.free[0]).expect("invalid address format"); + let count = matches + .opt_str("count") + .map(|s| usize::from_str(&s).unwrap()) + .unwrap_or(4); + let interval = matches + .opt_str("interval") .map(|s| Duration::from_secs(u64::from_str(&s).unwrap())) - .unwrap_or(Duration::from_secs(1)); - let timeout = Duration::from_secs( - matches.opt_str("timeout").map(|s| u64::from_str(&s).unwrap()).unwrap_or(5) + .unwrap_or_else(|| Duration::from_secs(1)); + let timeout = Duration::from_secs( + matches + .opt_str("timeout") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(5), ); - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - - let remote_addr = address; - - let icmp_rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 256]); - let icmp_tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 256]); - let icmp_socket = IcmpSocket::new(icmp_rx_buffer, icmp_tx_buffer); - - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x02]); - let src_ipv6 = IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1); - let ip_addrs = [IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24), - IpCidr::new(src_ipv6, 64), - IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)]; - let default_v4_gw = Ipv4Address::new(192, 168, 69, 100); - let default_v6_gw = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100); - let mut routes_storage = [None; 2]; - let mut routes = Routes::new(&mut routes_storage[..]); - routes.add_default_ipv4_route(default_v4_gw).unwrap(); - routes.add_default_ipv6_route(default_v6_gw).unwrap(); - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .ip_addrs(ip_addrs) - .routes(routes) - .neighbor_cache(neighbor_cache) - .finalize(); + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + config.random_seed = rand::random(); + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let icmp_rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 256]); + let icmp_tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 256]); + let icmp_socket = icmp::Socket::new(icmp_rx_buffer, icmp_tx_buffer); let mut sockets = SocketSet::new(vec![]); let icmp_handle = sockets.add(icmp_socket); @@ -122,89 +151,113 @@ fn main() { loop { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}", e); - } + iface.poll(timestamp, &mut device, &mut sockets); + + let timestamp = Instant::now(); + let socket = sockets.get_mut::(icmp_handle); + if !socket.is_open() { + socket.bind(icmp::Endpoint::Ident(ident)).unwrap(); + send_at = timestamp; } - { - let timestamp = Instant::now(); - let mut socket = sockets.get::(icmp_handle); - if !socket.is_open() { - socket.bind(IcmpEndpoint::Ident(ident)).unwrap(); - send_at = timestamp; - } + if socket.can_send() && seq_no < count as u16 && send_at <= timestamp { + NetworkEndian::write_i64(&mut echo_payload, timestamp.total_millis()); - if socket.can_send() && seq_no < count as u16 && - send_at <= timestamp { - NetworkEndian::write_i64(&mut echo_payload, timestamp.total_millis()); - - match remote_addr { - IpAddress::Ipv4(_) => { - let (icmp_repr, mut icmp_packet) = send_icmp_ping!( - Icmpv4Repr, Icmpv4Packet, ident, seq_no, - echo_payload, socket, remote_addr); - icmp_repr.emit(&mut icmp_packet, &device_caps.checksum); - }, - IpAddress::Ipv6(_) => { - let (icmp_repr, mut icmp_packet) = send_icmp_ping!( - Icmpv6Repr, Icmpv6Packet, ident, seq_no, - echo_payload, socket, remote_addr); - icmp_repr.emit(&src_ipv6, &remote_addr, - &mut icmp_packet, &device_caps.checksum); - }, - _ => unimplemented!() + match remote_addr { + IpAddress::Ipv4(_) => { + let (icmp_repr, mut icmp_packet) = send_icmp_ping!( + Icmpv4Repr, + Icmpv4Packet, + ident, + seq_no, + echo_payload, + socket, + remote_addr + ); + icmp_repr.emit(&mut icmp_packet, &device_caps.checksum); } - - waiting_queue.insert(seq_no, timestamp); - seq_no += 1; - send_at += interval; - } - - if socket.can_recv() { - let (payload, _) = socket.recv().unwrap(); - - match remote_addr { - IpAddress::Ipv4(_) => { - let icmp_packet = Icmpv4Packet::new_checked(&payload).unwrap(); - let icmp_repr = - Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum).unwrap(); - get_icmp_pong!(Icmpv4Repr, icmp_repr, payload, - waiting_queue, remote_addr, timestamp, received); - } - IpAddress::Ipv6(_) => { - let icmp_packet = Icmpv6Packet::new_checked(&payload).unwrap(); - let icmp_repr = Icmpv6Repr::parse(&remote_addr, &src_ipv6, - &icmp_packet, &device_caps.checksum).unwrap(); - get_icmp_pong!(Icmpv6Repr, icmp_repr, payload, - waiting_queue, remote_addr, timestamp, received); - }, - _ => unimplemented!() + IpAddress::Ipv6(_) => { + let (icmp_repr, mut icmp_packet) = send_icmp_ping!( + Icmpv6Repr, + Icmpv6Packet, + ident, + seq_no, + echo_payload, + socket, + remote_addr + ); + icmp_repr.emit( + &iface.ipv6_addr().unwrap().into_address(), + &remote_addr, + &mut icmp_packet, + &device_caps.checksum, + ); } } - waiting_queue.retain(|seq, from| { - if timestamp - *from < timeout { - true - } else { - println!("From {} icmp_seq={} timeout", remote_addr, seq); - false + waiting_queue.insert(seq_no, timestamp); + seq_no += 1; + send_at += interval; + } + + if socket.can_recv() { + let (payload, _) = socket.recv().unwrap(); + + match remote_addr { + IpAddress::Ipv4(_) => { + let icmp_packet = Icmpv4Packet::new_checked(&payload).unwrap(); + let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &device_caps.checksum).unwrap(); + get_icmp_pong!( + Icmpv4Repr, + icmp_repr, + payload, + waiting_queue, + remote_addr, + timestamp, + received + ); } - }); + IpAddress::Ipv6(_) => { + let icmp_packet = Icmpv6Packet::new_checked(&payload).unwrap(); + let icmp_repr = Icmpv6Repr::parse( + &remote_addr, + &iface.ipv6_addr().unwrap().into_address(), + &icmp_packet, + &device_caps.checksum, + ) + .unwrap(); + get_icmp_pong!( + Icmpv6Repr, + icmp_repr, + payload, + waiting_queue, + remote_addr, + timestamp, + received + ); + } + } + } - if seq_no == count as u16 && waiting_queue.is_empty() { - break + waiting_queue.retain(|seq, from| { + if timestamp - *from < timeout { + true + } else { + println!("From {remote_addr} icmp_seq={seq} timeout"); + false } + }); + + if seq_no == count as u16 && waiting_queue.is_empty() { + break; } let timestamp = Instant::now(); - match iface.poll_at(&sockets, timestamp) { + match iface.poll_at(timestamp, &sockets) { Some(poll_at) if timestamp < poll_at => { let resume_at = cmp::min(poll_at, send_at); phy_wait(fd, Some(resume_at - timestamp)).expect("wait error"); - }, + } Some(_) => (), None => { phy_wait(fd, Some(send_at - timestamp)).expect("wait error"); @@ -212,7 +265,11 @@ fn main() { } } - println!("--- {} ping statistics ---", remote_addr); - println!("{} packets transmitted, {} received, {:.0}% packet loss", - seq_no, received, 100.0 * (seq_no - received) as f64 / seq_no as f64); + println!("--- {remote_addr} ping statistics ---"); + println!( + "{} packets transmitted, {} received, {:.0}% packet loss", + seq_no, + received, + 100.0 * (seq_no - received) as f64 / seq_no as f64 + ); } diff --git a/examples/server.rs b/examples/server.rs index 98940cd8e..33d95c5d5 100644 --- a/examples/server.rs +++ b/examples/server.rs @@ -1,71 +1,89 @@ -#[macro_use] -extern crate log; -extern crate env_logger; -extern crate getopts; -extern crate smoltcp; - mod utils; -use std::str; -use std::collections::BTreeMap; +use log::debug; use std::fmt::Write; use std::os::unix::io::AsRawFd; -use smoltcp::phy::wait as phy_wait; -use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder}; -use smoltcp::socket::SocketSet; -use smoltcp::socket::{UdpSocket, UdpSocketBuffer, UdpPacketMetadata}; -use smoltcp::socket::{TcpSocket, TcpSocketBuffer}; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium}; +use smoltcp::socket::{tcp, udp}; use smoltcp::time::{Duration, Instant}; +use smoltcp::wire::{EthernetAddress, IpAddress, IpCidr, Ipv4Address, Ipv6Address}; fn main() { utils::setup_logging(""); let (mut opts, mut free) = utils::create_options(); - utils::add_tap_options(&mut opts, &mut free); + utils::add_tuntap_options(&mut opts, &mut free); utils::add_middleware_options(&mut opts, &mut free); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_tap_options(&mut matches); + let device = utils::parse_tuntap_options(&mut matches); let fd = device.as_raw_fd(); - let device = utils::parse_middleware_options(&mut matches, device, /*loopback=*/false); - - let neighbor_cache = NeighborCache::new(BTreeMap::new()); - - let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 64]); - let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 128]); - let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - - let tcp1_rx_buffer = TcpSocketBuffer::new(vec![0; 64]); - let tcp1_tx_buffer = TcpSocketBuffer::new(vec![0; 128]); - let tcp1_socket = TcpSocket::new(tcp1_rx_buffer, tcp1_tx_buffer); - - let tcp2_rx_buffer = TcpSocketBuffer::new(vec![0; 64]); - let tcp2_tx_buffer = TcpSocketBuffer::new(vec![0; 128]); - let tcp2_socket = TcpSocket::new(tcp2_rx_buffer, tcp2_tx_buffer); - - let tcp3_rx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp3_tx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp3_socket = TcpSocket::new(tcp3_rx_buffer, tcp3_tx_buffer); - - let tcp4_rx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp4_tx_buffer = TcpSocketBuffer::new(vec![0; 65535]); - let tcp4_socket = TcpSocket::new(tcp4_rx_buffer, tcp4_tx_buffer); - - let ethernet_addr = EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]); - let ip_addrs = [ - IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24), - IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64), - IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64) - ]; - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(ethernet_addr) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .finalize(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => todo!(), + }; + + config.random_seed = rand::random(); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(192, 168, 69, 1), 24)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdaa, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + iface + .routes_mut() + .add_default_ipv4_route(Ipv4Address::new(192, 168, 69, 100)) + .unwrap(); + iface + .routes_mut() + .add_default_ipv6_route(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0x100)) + .unwrap(); + + // Create sockets + let udp_rx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY, udp::PacketMetadata::EMPTY], + vec![0; 65535], + ); + let udp_tx_buffer = udp::PacketBuffer::new( + vec![udp::PacketMetadata::EMPTY, udp::PacketMetadata::EMPTY], + vec![0; 65535], + ); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 64]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 128]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); + + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 64]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 128]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); + + let tcp3_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp3_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp3_socket = tcp::Socket::new(tcp3_rx_buffer, tcp3_tx_buffer); + + let tcp4_rx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp4_tx_buffer = tcp::SocketBuffer::new(vec![0; 65535]); + let tcp4_socket = tcp::Socket::new(tcp4_rx_buffer, tcp4_tx_buffer); let mut sockets = SocketSet::new(vec![]); - let udp_handle = sockets.add(udp_socket); + let udp_handle = sockets.add(udp_socket); let tcp1_handle = sockets.add(tcp1_socket); let tcp2_handle = sockets.add(tcp2_socket); let tcp3_handle = sockets.add(tcp3_socket); @@ -74,130 +92,118 @@ fn main() { let mut tcp_6970_active = false; loop { let timestamp = Instant::now(); - match iface.poll(&mut sockets, timestamp) { - Ok(_) => {}, - Err(e) => { - debug!("poll error: {}", e); - } - } + iface.poll(timestamp, &mut device, &mut sockets); // udp:6969: respond "hello" - { - let mut socket = sockets.get::(udp_handle); - if !socket.is_open() { - socket.bind(6969).unwrap() - } + let socket = sockets.get_mut::(udp_handle); + if !socket.is_open() { + socket.bind(6969).unwrap() + } - let client = match socket.recv() { - Ok((data, endpoint)) => { - debug!("udp:6969 recv data: {:?} from {}", - str::from_utf8(data.as_ref()).unwrap(), endpoint); - Some(endpoint) - } - Err(_) => None - }; - if let Some(endpoint) = client { - let data = b"hello\n"; - debug!("udp:6969 send data: {:?}", - str::from_utf8(data.as_ref()).unwrap()); - socket.send_slice(data, endpoint).unwrap(); + let client = match socket.recv() { + Ok((data, endpoint)) => { + debug!("udp:6969 recv data: {:?} from {}", data, endpoint); + let mut data = data.to_vec(); + data.reverse(); + Some((endpoint, data)) } + Err(_) => None, + }; + if let Some((endpoint, data)) = client { + debug!("udp:6969 send data: {:?} to {}", data, endpoint,); + socket.send_slice(&data, endpoint).unwrap(); } // tcp:6969: respond "hello" - { - let mut socket = sockets.get::(tcp1_handle); - if !socket.is_open() { - socket.listen(6969).unwrap(); - } + let socket = sockets.get_mut::(tcp1_handle); + if !socket.is_open() { + socket.listen(6969).unwrap(); + } - if socket.can_send() { - debug!("tcp:6969 send greeting"); - write!(socket, "hello\n").unwrap(); - debug!("tcp:6969 close"); - socket.close(); - } + if socket.can_send() { + debug!("tcp:6969 send greeting"); + writeln!(socket, "hello").unwrap(); + debug!("tcp:6969 close"); + socket.close(); } // tcp:6970: echo with reverse - { - let mut socket = sockets.get::(tcp2_handle); - if !socket.is_open() { - socket.listen(6970).unwrap() - } + let socket = sockets.get_mut::(tcp2_handle); + if !socket.is_open() { + socket.listen(6970).unwrap() + } - if socket.is_active() && !tcp_6970_active { - debug!("tcp:6970 connected"); - } else if !socket.is_active() && tcp_6970_active { - debug!("tcp:6970 disconnected"); - } - tcp_6970_active = socket.is_active(); + if socket.is_active() && !tcp_6970_active { + debug!("tcp:6970 connected"); + } else if !socket.is_active() && tcp_6970_active { + debug!("tcp:6970 disconnected"); + } + tcp_6970_active = socket.is_active(); - if socket.may_recv() { - let data = socket.recv(|buffer| { + if socket.may_recv() { + let data = socket + .recv(|buffer| { let recvd_len = buffer.len(); let mut data = buffer.to_owned(); - if data.len() > 0 { - debug!("tcp:6970 recv data: {:?}", - str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")); + if !data.is_empty() { + debug!("tcp:6970 recv data: {:?}", data); data = data.split(|&b| b == b'\n').collect::>().concat(); data.reverse(); data.extend(b"\n"); } (recvd_len, data) - }).unwrap(); - if socket.can_send() && data.len() > 0 { - debug!("tcp:6970 send data: {:?}", - str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)")); - socket.send_slice(&data[..]).unwrap(); - } - } else if socket.may_send() { - debug!("tcp:6970 close"); - socket.close(); + }) + .unwrap(); + if socket.can_send() && !data.is_empty() { + debug!("tcp:6970 send data: {:?}", data); + socket.send_slice(&data[..]).unwrap(); } + } else if socket.may_send() { + debug!("tcp:6970 close"); + socket.close(); } // tcp:6971: sinkhole - { - let mut socket = sockets.get::(tcp3_handle); - if !socket.is_open() { - socket.listen(6971).unwrap(); - socket.set_keep_alive(Some(Duration::from_millis(1000))); - socket.set_timeout(Some(Duration::from_millis(2000))); - } + let socket = sockets.get_mut::(tcp3_handle); + if !socket.is_open() { + socket.listen(6971).unwrap(); + socket.set_keep_alive(Some(Duration::from_millis(1000))); + socket.set_timeout(Some(Duration::from_millis(2000))); + } - if socket.may_recv() { - socket.recv(|buffer| { - if buffer.len() > 0 { + if socket.may_recv() { + socket + .recv(|buffer| { + if !buffer.is_empty() { debug!("tcp:6971 recv {:?} octets", buffer.len()); } (buffer.len(), ()) - }).unwrap(); - } else if socket.may_send() { - socket.close(); - } + }) + .unwrap(); + } else if socket.may_send() { + socket.close(); } // tcp:6972: fountain - { - let mut socket = sockets.get::(tcp4_handle); - if !socket.is_open() { - socket.listen(6972).unwrap() - } + let socket = sockets.get_mut::(tcp4_handle); + if !socket.is_open() { + socket.listen(6972).unwrap() + } - if socket.may_send() { - socket.send(|data| { - if data.len() > 0 { + if socket.may_send() { + socket + .send(|data| { + if !data.is_empty() { debug!("tcp:6972 send {:?} octets", data.len()); for (i, b) in data.iter_mut().enumerate() { *b = (i % 256) as u8; } } (data.len(), ()) - }).unwrap(); - } + }) + .unwrap(); } - phy_wait(fd, iface.poll_delay(&sockets, timestamp)).expect("wait error"); + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); } } diff --git a/examples/sixlowpan.rs b/examples/sixlowpan.rs new file mode 100644 index 000000000..9d474e3fe --- /dev/null +++ b/examples/sixlowpan.rs @@ -0,0 +1,177 @@ +//! 6lowpan exmaple +//! +//! This example is designed to run using the Linux ieee802154/6lowpan support, +//! using mac802154_hwsim. +//! +//! mac802154_hwsim allows you to create multiple "virtual" radios and specify +//! which is in range with which. This is very useful for testing without +//! needing real hardware. By default it creates two interfaces `wpan0` and +//! `wpan1` that are in range with each other. You can customize this with +//! the `wpan-hwsim` tool. +//! +//! We'll configure Linux to speak 6lowpan on `wpan0`, and leave `wpan1` +//! unconfigured so smoltcp can use it with a raw socket. +//! +//! # Setup +//! +//! modprobe mac802154_hwsim +//! +//! ip link set wpan0 down +//! ip link set wpan1 down +//! iwpan dev wpan0 set pan_id 0xbeef +//! iwpan dev wpan1 set pan_id 0xbeef +//! ip link add link wpan0 name lowpan0 type lowpan +//! ip link set wpan0 up +//! ip link set wpan1 up +//! ip link set lowpan0 up +//! +//! # Running +//! +//! Run it with `sudo ./target/debug/examples/sixlowpan`. +//! +//! You can set wireshark to sniff on interface `wpan0` to see the packets. +//! +//! Ping it with `ping fe80::180b:4242:4242:4242%lowpan0`. +//! +//! Speak UDP with `nc -uv fe80::180b:4242:4242:4242%lowpan0 6969`. +//! +//! # Teardown +//! +//! rmmod mac802154_hwsim +//! + +mod utils; + +use log::debug; +use std::os::unix::io::AsRawFd; +use std::str; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium, RawSocket}; +use smoltcp::socket::tcp; +use smoltcp::socket::udp; +use smoltcp::time::Instant; +use smoltcp::wire::{EthernetAddress, Ieee802154Address, Ieee802154Pan, IpAddress, IpCidr}; + +fn main() { + utils::setup_logging(""); + + let (mut opts, mut free) = utils::create_options(); + utils::add_middleware_options(&mut opts, &mut free); + + let mut matches = utils::parse_options(&opts, free); + + let device = RawSocket::new("wpan1", Medium::Ieee802154).unwrap(); + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => Config::new( + Ieee802154Address::Extended([0x1a, 0x0b, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42]).into(), + ), + }; + config.random_seed = rand::random(); + config.pan_id = Some(Ieee802154Pan(0xbeef)); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242), + 64, + )) + .unwrap(); + }); + + // Create sockets + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1280]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1280]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + + let tcp_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp_socket = tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let udp_handle = sockets.add(udp_socket); + let tcp_handle = sockets.add(tcp_socket); + + let socket = sockets.get_mut::(tcp_handle); + socket.listen(50000).unwrap(); + + let mut tcp_active = false; + + loop { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // udp:6969: respond "hello" + let socket = sockets.get_mut::(udp_handle); + if !socket.is_open() { + socket.bind(6969).unwrap() + } + + let mut buffer = vec![0; 1500]; + let client = match socket.recv() { + Ok((data, endpoint)) => { + debug!( + "udp:6969 recv data: {:?} from {}", + str::from_utf8(data).unwrap(), + endpoint + ); + buffer[..data.len()].copy_from_slice(data); + Some((data.len(), endpoint)) + } + Err(_) => None, + }; + if let Some((len, endpoint)) = client { + debug!( + "udp:6969 send data: {:?}", + str::from_utf8(&buffer[..len]).unwrap() + ); + socket.send_slice(&buffer[..len], endpoint).unwrap(); + } + + let socket = sockets.get_mut::(tcp_handle); + if socket.is_active() && !tcp_active { + debug!("connected"); + } else if !socket.is_active() && tcp_active { + debug!("disconnected"); + } + tcp_active = socket.is_active(); + + if socket.may_recv() { + let data = socket + .recv(|data| { + let data = data.to_owned(); + if !data.is_empty() { + debug!( + "recv data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + } + (data.len(), data) + }) + .unwrap(); + + if socket.can_send() && !data.is_empty() { + debug!( + "send data: {:?}", + str::from_utf8(data.as_ref()).unwrap_or("(invalid utf8)") + ); + socket.send_slice(&data[..]).unwrap(); + } + } else if socket.may_send() { + debug!("close"); + socket.close(); + } + + phy_wait(fd, iface.poll_delay(timestamp, &sockets)).expect("wait error"); + } +} diff --git a/examples/sixlowpan_benchmark.rs b/examples/sixlowpan_benchmark.rs new file mode 100644 index 000000000..b459d8d84 --- /dev/null +++ b/examples/sixlowpan_benchmark.rs @@ -0,0 +1,235 @@ +//! 6lowpan benchmark exmaple +//! +//! This example runs a simple TCP throughput benchmark using the 6lowpan implementation in smoltcp +//! It is designed to run using the Linux ieee802154/6lowpan support, +//! using mac802154_hwsim. +//! +//! mac802154_hwsim allows you to create multiple "virtual" radios and specify +//! which is in range with which. This is very useful for testing without +//! needing real hardware. By default it creates two interfaces `wpan0` and +//! `wpan1` that are in range with each other. You can customize this with +//! the `wpan-hwsim` tool. +//! +//! We'll configure Linux to speak 6lowpan on `wpan0`, and leave `wpan1` +//! unconfigured so smoltcp can use it with a raw socket. +//! +//! +//! +//! +//! +//! # Setup +//! +//! modprobe mac802154_hwsim +//! +//! ip link set wpan0 down +//! ip link set wpan1 down +//! iwpan dev wpan0 set pan_id 0xbeef +//! iwpan dev wpan1 set pan_id 0xbeef +//! ip link add link wpan0 name lowpan0 type lowpan +//! ip link set wpan0 up +//! ip link set wpan1 up +//! ip link set lowpan0 up +//! +//! +//! # Running +//! +//! Compile with `cargo build --release --example sixlowpan_benchmark` +//! Run it with `sudo ./target/release/examples/sixlowpan_benchmark [reader|writer]`. +//! +//! # Teardown +//! +//! rmmod mac802154_hwsim +//! + +mod utils; + +use std::os::unix::io::AsRawFd; +use std::str; + +use smoltcp::iface::{Config, Interface, SocketSet}; +use smoltcp::phy::{wait as phy_wait, Device, Medium, RawSocket}; +use smoltcp::socket::tcp; +use smoltcp::wire::{EthernetAddress, Ieee802154Address, Ieee802154Pan, IpAddress, IpCidr}; + +//For benchmark +use smoltcp::time::{Duration, Instant}; +use std::cmp; +use std::io::{Read, Write}; +use std::net::SocketAddrV6; +use std::net::TcpStream; +use std::sync::atomic::{AtomicBool, Ordering}; +use std::thread; + +use std::fs; + +fn if_nametoindex(ifname: &str) -> u32 { + let contents = fs::read_to_string(format!("/sys/devices/virtual/net/{ifname}/ifindex")) + .expect("couldn't read interface from \"/sys/devices/virtual/net\"") + .replace('\n', ""); + contents.parse::().unwrap() +} + +const AMOUNT: usize = 100_000_000; + +enum Client { + Reader, + Writer, +} + +fn client(kind: Client) { + let port: u16 = match kind { + Client::Reader => 1234, + Client::Writer => 1235, + }; + + let scope_id = if_nametoindex("lowpan0"); + + let socket_addr = SocketAddrV6::new( + "fe80:0:0:0:180b:4242:4242:4242".parse().unwrap(), + port, + 0, + scope_id, + ); + + let mut stream = TcpStream::connect(socket_addr).expect("failed to connect TLKAGMKA"); + let mut buffer = vec![0; 1_000_000]; + + let start = Instant::now(); + + let mut processed = 0; + while processed < AMOUNT { + let length = cmp::min(buffer.len(), AMOUNT - processed); + let result = match kind { + Client::Reader => stream.read(&mut buffer[..length]), + Client::Writer => stream.write(&buffer[..length]), + }; + match result { + Ok(0) => break, + Ok(result) => { + // print!("(P:{})", result); + processed += result + } + Err(err) => panic!("cannot process: {err}"), + } + } + + let end = Instant::now(); + + let elapsed = (end - start).total_millis() as f64 / 1000.0; + + println!("throughput: {:.3} Gbps", AMOUNT as f64 / elapsed / 0.125e9); + + CLIENT_DONE.store(true, Ordering::SeqCst); +} + +static CLIENT_DONE: AtomicBool = AtomicBool::new(false); + +fn main() { + #[cfg(feature = "log")] + utils::setup_logging("info"); + + let (mut opts, mut free) = utils::create_options(); + utils::add_middleware_options(&mut opts, &mut free); + free.push("MODE"); + + let mut matches = utils::parse_options(&opts, free); + + let device = RawSocket::new("wpan1", Medium::Ieee802154).unwrap(); + + let fd = device.as_raw_fd(); + let mut device = + utils::parse_middleware_options(&mut matches, device, /*loopback=*/ false); + + let mode = match matches.free[0].as_ref() { + "reader" => Client::Reader, + "writer" => Client::Writer, + _ => panic!("invalid mode"), + }; + + // Create interface + let mut config = match device.capabilities().medium { + Medium::Ethernet => { + Config::new(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x01]).into()) + } + Medium::Ip => Config::new(smoltcp::wire::HardwareAddress::Ip), + Medium::Ieee802154 => Config::new( + Ieee802154Address::Extended([0x1a, 0x0b, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42]).into(), + ), + }; + config.random_seed = rand::random(); + config.pan_id = Some(Ieee802154Pan(0xbeef)); + + let mut iface = Interface::new(config, &mut device, Instant::now()); + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242), + 64, + )) + .unwrap(); + }); + + let tcp1_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp1_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp1_socket = tcp::Socket::new(tcp1_rx_buffer, tcp1_tx_buffer); + + let tcp2_rx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp2_tx_buffer = tcp::SocketBuffer::new(vec![0; 4096]); + let tcp2_socket = tcp::Socket::new(tcp2_rx_buffer, tcp2_tx_buffer); + + let mut sockets = SocketSet::new(vec![]); + let tcp1_handle = sockets.add(tcp1_socket); + let tcp2_handle = sockets.add(tcp2_socket); + + let default_timeout = Some(Duration::from_millis(1000)); + + thread::spawn(move || client(mode)); + let mut processed = 0; + + while !CLIENT_DONE.load(Ordering::SeqCst) { + let timestamp = Instant::now(); + iface.poll(timestamp, &mut device, &mut sockets); + + // tcp:1234: emit data + let socket = sockets.get_mut::(tcp1_handle); + if !socket.is_open() { + socket.listen(1234).unwrap(); + } + + if socket.can_send() && processed < AMOUNT { + let length = socket + .send(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + + // tcp:1235: sink data + let socket = sockets.get_mut::(tcp2_handle); + if !socket.is_open() { + socket.listen(1235).unwrap(); + } + + if socket.can_recv() && processed < AMOUNT { + let length = socket + .recv(|buffer| { + let length = cmp::min(buffer.len(), AMOUNT - processed); + (length, length) + }) + .unwrap(); + processed += length; + } + + match iface.poll_at(timestamp, &sockets) { + Some(poll_at) if timestamp < poll_at => { + phy_wait(fd, Some(poll_at - timestamp)).expect("wait error"); + } + Some(_) => (), + None => { + phy_wait(fd, default_timeout).expect("wait error"); + } + } + } +} diff --git a/examples/tcpdump.rs b/examples/tcpdump.rs index 166ac1e50..2baf376e1 100644 --- a/examples/tcpdump.rs +++ b/examples/tcpdump.rs @@ -1,21 +1,21 @@ -extern crate smoltcp; - -use std::env; -use std::os::unix::io::AsRawFd; use smoltcp::phy::wait as phy_wait; -use smoltcp::phy::{Device, RxToken, RawSocket}; -use smoltcp::wire::{PrettyPrinter, EthernetFrame}; +use smoltcp::phy::{Device, RawSocket, RxToken}; use smoltcp::time::Instant; +use smoltcp::wire::{EthernetFrame, PrettyPrinter}; +use std::env; +use std::os::unix::io::AsRawFd; fn main() { let ifname = env::args().nth(1).unwrap(); - let mut socket = RawSocket::new(ifname.as_ref()).unwrap(); + let mut socket = RawSocket::new(ifname.as_ref(), smoltcp::phy::Medium::Ethernet).unwrap(); loop { phy_wait(socket.as_raw_fd(), None).unwrap(); - let (rx_token, _) = socket.receive().unwrap(); - rx_token.consume(Instant::now(), |buffer| { - println!("{}", PrettyPrinter::>::new("", &buffer)); - Ok(()) - }).unwrap(); + let (rx_token, _) = socket.receive(Instant::now()).unwrap(); + rx_token.consume(|buffer| { + println!( + "{}", + PrettyPrinter::>::new("", &buffer) + ); + }) } } diff --git a/examples/utils.rs b/examples/utils.rs index f6b94e05a..dbe907615 100644 --- a/examples/utils.rs +++ b/examples/utils.rs @@ -1,56 +1,67 @@ #![allow(dead_code)] -use std::cell::RefCell; -use std::str::{self, FromStr}; -use std::rc::Rc; -use std::io::{self, Write}; -use std::fs::File; -use std::time::{SystemTime, UNIX_EPOCH}; -use std::env; -use std::process; -#[cfg(feature = "log")] -use log::{Level, LevelFilter}; #[cfg(feature = "log")] use env_logger::Builder; -use getopts::{Options, Matches}; +use getopts::{Matches, Options}; +#[cfg(feature = "log")] +use log::{trace, Level, LevelFilter}; +use std::env; +use std::fs::File; +use std::io::{self, Write}; +use std::process; +use std::str::{self, FromStr}; +use std::time::{SystemTime, UNIX_EPOCH}; -use smoltcp::phy::{Device, EthernetTracer, FaultInjector}; -#[cfg(feature = "phy-tap_interface")] -use smoltcp::phy::TapInterface; -use smoltcp::phy::{PcapWriter, PcapSink, PcapMode, PcapLinkType}; -use smoltcp::phy::RawSocket; +#[cfg(feature = "phy-tuntap_interface")] +use smoltcp::phy::TunTapInterface; +use smoltcp::phy::{Device, FaultInjector, Medium, Tracer}; +use smoltcp::phy::{PcapMode, PcapWriter}; use smoltcp::time::{Duration, Instant}; #[cfg(feature = "log")] pub fn setup_logging_with_clock(filter: &str, since_startup: F) - where F: Fn() -> Instant + Send + Sync + 'static { +where + F: Fn() -> Instant + Send + Sync + 'static, +{ Builder::new() .format(move |buf, record| { let elapsed = since_startup(); - let timestamp = format!("[{}]", elapsed); + let timestamp = format!("[{elapsed}]"); if record.target().starts_with("smoltcp::") { - writeln!(buf, "\x1b[0m{} ({}): {}\x1b[0m", timestamp, - record.target().replace("smoltcp::", ""), record.args()) + writeln!( + buf, + "\x1b[0m{} ({}): {}\x1b[0m", + timestamp, + record.target().replace("smoltcp::", ""), + record.args() + ) } else if record.level() == Level::Trace { let message = format!("{}", record.args()); - writeln!(buf, "\x1b[37m{} {}\x1b[0m", timestamp, - message.replace("\n", "\n ")) + writeln!( + buf, + "\x1b[37m{} {}\x1b[0m", + timestamp, + message.replace('\n', "\n ") + ) } else { - writeln!(buf, "\x1b[32m{} ({}): {}\x1b[0m", timestamp, - record.target(), record.args()) + writeln!( + buf, + "\x1b[32m{} ({}): {}\x1b[0m", + timestamp, + record.target(), + record.args() + ) } }) .filter(None, LevelFilter::Trace) - .parse(filter) - .parse(&env::var("RUST_LOG").unwrap_or("".to_owned())) + .parse_filters(filter) + .parse_env("RUST_LOG") .init(); } #[cfg(feature = "log")] pub fn setup_logging(filter: &str) { - setup_logging_with_clock(filter, move || { - Instant::now() - }) + setup_logging_with_clock(filter, Instant::now) } pub fn create_options() -> (Options, Vec<&'static str>) { @@ -62,81 +73,140 @@ pub fn create_options() -> (Options, Vec<&'static str>) { pub fn parse_options(options: &Options, free: Vec<&str>) -> Matches { match options.parse(env::args().skip(1)) { Err(err) => { - println!("{}", err); + println!("{err}"); process::exit(1) } Ok(matches) => { if matches.opt_present("h") || matches.free.len() != free.len() { - let brief = format!("Usage: {} [OPTION]... {}", - env::args().nth(0).unwrap(), free.join(" ")); + let brief = format!( + "Usage: {} [OPTION]... {}", + env::args().next().unwrap(), + free.join(" ") + ); print!("{}", options.usage(&brief)); - process::exit(if matches.free.len() != free.len() { 1 } else { 0 }) + process::exit((matches.free.len() != free.len()) as _); } matches } } } -pub fn add_tap_options(_opts: &mut Options, free: &mut Vec<&str>) { - free.push("INTERFACE"); -} - -#[cfg(feature = "phy-tap_interface")] -pub fn parse_tap_options(matches: &mut Matches) -> TapInterface { - let interface = matches.free.remove(0); - TapInterface::new(&interface).unwrap() +pub fn add_tuntap_options(opts: &mut Options, _free: &mut [&str]) { + opts.optopt("", "tun", "TUN interface to use", "tun0"); + opts.optopt("", "tap", "TAP interface to use", "tap0"); } -pub fn parse_raw_socket_options(matches: &mut Matches) -> RawSocket { - let interface = matches.free.remove(0); - RawSocket::new(&interface).unwrap() +#[cfg(feature = "phy-tuntap_interface")] +pub fn parse_tuntap_options(matches: &mut Matches) -> TunTapInterface { + let tun = matches.opt_str("tun"); + let tap = matches.opt_str("tap"); + match (tun, tap) { + (Some(tun), None) => TunTapInterface::new(&tun, Medium::Ip).unwrap(), + (None, Some(tap)) => TunTapInterface::new(&tap, Medium::Ethernet).unwrap(), + _ => panic!("You must specify exactly one of --tun or --tap"), + } } -pub fn add_middleware_options(opts: &mut Options, _free: &mut Vec<&str>) { +pub fn add_middleware_options(opts: &mut Options, _free: &mut [&str]) { opts.optopt("", "pcap", "Write a packet capture file", "FILE"); - opts.optopt("", "drop-chance", "Chance of dropping a packet (%)", "CHANCE"); - opts.optopt("", "corrupt-chance", "Chance of corrupting a packet (%)", "CHANCE"); - opts.optopt("", "size-limit", "Drop packets larger than given size (octets)", "SIZE"); - opts.optopt("", "tx-rate-limit", "Drop packets after transmit rate exceeds given limit \ - (packets per interval)", "RATE"); - opts.optopt("", "rx-rate-limit", "Drop packets after transmit rate exceeds given limit \ - (packets per interval)", "RATE"); - opts.optopt("", "shaping-interval", "Sets the interval for rate limiting (ms)", "RATE"); + opts.optopt( + "", + "drop-chance", + "Chance of dropping a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "corrupt-chance", + "Chance of corrupting a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "size-limit", + "Drop packets larger than given size (octets)", + "SIZE", + ); + opts.optopt( + "", + "tx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "rx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "shaping-interval", + "Sets the interval for rate limiting (ms)", + "RATE", + ); } -pub fn parse_middleware_options(matches: &mut Matches, device: D, loopback: bool) - -> FaultInjector>>> - where D: for<'a> Device<'a> +pub fn parse_middleware_options( + matches: &mut Matches, + device: D, + loopback: bool, +) -> FaultInjector>>> +where + D: Device, { - let drop_chance = matches.opt_str("drop-chance").map(|s| u8::from_str(&s).unwrap()) - .unwrap_or(0); - let corrupt_chance = matches.opt_str("corrupt-chance").map(|s| u8::from_str(&s).unwrap()) - .unwrap_or(0); - let size_limit = matches.opt_str("size-limit").map(|s| usize::from_str(&s).unwrap()) - .unwrap_or(0); - let tx_rate_limit = matches.opt_str("tx-rate-limit").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); - let rx_rate_limit = matches.opt_str("rx-rate-limit").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); - let shaping_interval = matches.opt_str("shaping-interval").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); + let drop_chance = matches + .opt_str("drop-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let corrupt_chance = matches + .opt_str("corrupt-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let size_limit = matches + .opt_str("size-limit") + .map(|s| usize::from_str(&s).unwrap()) + .unwrap_or(0); + let tx_rate_limit = matches + .opt_str("tx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let rx_rate_limit = matches + .opt_str("rx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let shaping_interval = matches + .opt_str("shaping-interval") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); - let pcap_writer: Box; - if let Some(pcap_filename) = matches.opt_str("pcap") { - pcap_writer = Box::new(File::create(pcap_filename).expect("cannot open file")) - } else { - pcap_writer = Box::new(io::sink()) - } + let pcap_writer: Box = match matches.opt_str("pcap") { + Some(pcap_filename) => Box::new(File::create(pcap_filename).expect("cannot open file")), + None => Box::new(io::sink()), + }; - let seed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().subsec_nanos(); + let seed = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos(); - let device = PcapWriter::new(device, Rc::new(RefCell::new(pcap_writer)) as Rc, - if loopback { PcapMode::TxOnly } else { PcapMode::Both }, - PcapLinkType::Ethernet); - let device = EthernetTracer::new(device, |_timestamp, _printer| { + let device = PcapWriter::new( + device, + pcap_writer, + if loopback { + PcapMode::TxOnly + } else { + PcapMode::Both + }, + ); + + let device = Tracer::new(device, |_timestamp, _printer| { #[cfg(feature = "log")] trace!("{}", _printer); }); + let mut device = FaultInjector::new(device, seed); device.set_drop_chance(drop_chance); device.set_corrupt_chance(corrupt_chance); diff --git a/fuzz/Cargo.toml b/fuzz/Cargo.toml index 895812411..f526d7551 100644 --- a/fuzz/Cargo.toml +++ b/fuzz/Cargo.toml @@ -3,21 +3,16 @@ name = "smoltcp-fuzz" version = "0.0.1" authors = ["Automatically generated"] publish = false +edition = "2018" [package.metadata] cargo-fuzz = true [dependencies] +libfuzzer-sys = "0.4" +arbitrary = { version = "1", features = ["derive"] } getopts = "0.2" - -[dependencies.smoltcp] -path = ".." - -[dependencies.libfuzzer-sys] -git = "https://github.com/rust-fuzz/libfuzzer-sys.git" - -[profile.release] -codegen-units = 1 # needed to prevent weird linker error about sancov guards +smoltcp = { path = ".." } # Prevent this from interfering with workspaces [workspace] @@ -26,7 +21,29 @@ members = ["."] [[bin]] name = "packet_parser" path = "fuzz_targets/packet_parser.rs" +test = false +doc = false [[bin]] name = "tcp_headers" path = "fuzz_targets/tcp_headers.rs" +test = false +doc = false + +[[bin]] +name = "dhcp_header" +path = "fuzz_targets/dhcp_header.rs" +test = false +doc = false + +[[bin]] +name = "ieee802154_header" +path = "fuzz_targets/ieee802154_header.rs" +test = false +doc = false + +[[bin]] +name = "sixlowpan_packet" +path = "fuzz_targets/sixlowpan_packet.rs" +test = false +doc = false diff --git a/fuzz/fuzz_targets/dhcp_header.rs b/fuzz/fuzz_targets/dhcp_header.rs new file mode 100644 index 000000000..f56efd091 --- /dev/null +++ b/fuzz/fuzz_targets/dhcp_header.rs @@ -0,0 +1,19 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; +use smoltcp::wire::{DhcpPacket, DhcpRepr}; + +fuzz_target!(|data: &[u8]| { + let _ = match DhcpPacket::new_checked(data) { + Ok(packet) => match DhcpRepr::parse(packet) { + Ok(dhcp_repr) => { + let mut dhcp_payload = vec![0; dhcp_repr.buffer_len()]; + match DhcpPacket::new_checked(&mut dhcp_payload[..]) { + Ok(mut dhcp_packet) => Some(dhcp_repr.emit(&mut dhcp_packet)), + Err(_) => None, + } + } + Err(_) => None, + }, + Err(_) => None, + }; +}); diff --git a/fuzz/fuzz_targets/ieee802154_header.rs b/fuzz/fuzz_targets/ieee802154_header.rs new file mode 100644 index 000000000..88f52f63e --- /dev/null +++ b/fuzz/fuzz_targets/ieee802154_header.rs @@ -0,0 +1,19 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; +use smoltcp::wire::{Ieee802154Frame, Ieee802154Repr}; + +fuzz_target!(|data: &[u8]| { + if let Ok(frame) = Ieee802154Frame::new_checked(data) { + if let Ok(repr) = Ieee802154Repr::parse(frame) { + // The buffer len returns only the length required for emitting the header + // and does not take into account the length of the payload. + let mut buffer = vec![0; repr.buffer_len()]; + + // NOTE: unchecked because the checked version checks if the addressing mode field + // is valid or not. The addressing mode field is required for calculating the length of + // the header, which is used in `check_len`. + let mut frame = Ieee802154Frame::new_unchecked(&mut buffer[..]); + repr.emit(&mut frame); + } + }; +}); diff --git a/fuzz/fuzz_targets/packet_parser.rs b/fuzz/fuzz_targets/packet_parser.rs index 357e1f333..e9e58bffb 100644 --- a/fuzz/fuzz_targets/packet_parser.rs +++ b/fuzz/fuzz_targets/packet_parser.rs @@ -1,8 +1,10 @@ #![no_main] -#[macro_use] extern crate libfuzzer_sys; -extern crate smoltcp; +use libfuzzer_sys::fuzz_target; +use smoltcp::wire::*; fuzz_target!(|data: &[u8]| { - use smoltcp::wire::*; - format!("{}", PrettyPrinter::>::new("", &data)); + format!( + "{}", + PrettyPrinter::>::new("", &data) + ); }); diff --git a/fuzz/fuzz_targets/sixlowpan_packet.rs b/fuzz/fuzz_targets/sixlowpan_packet.rs new file mode 100644 index 000000000..1cb287029 --- /dev/null +++ b/fuzz/fuzz_targets/sixlowpan_packet.rs @@ -0,0 +1,242 @@ +#![no_main] +use libfuzzer_sys::fuzz_target; +use smoltcp::{phy::ChecksumCapabilities, wire::*}; + +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, arbitrary::Arbitrary)] +pub enum AddressFuzzer { + Absent, + Short([u8; 2]), + Extended([u8; 8]), +} + +impl From for Ieee802154Address { + fn from(val: AddressFuzzer) -> Self { + match val { + AddressFuzzer::Absent => Ieee802154Address::Absent, + AddressFuzzer::Short(b) => Ieee802154Address::Short(b), + AddressFuzzer::Extended(b) => Ieee802154Address::Extended(b), + } + } +} + +#[derive(Debug, arbitrary::Arbitrary)] +struct SixlowpanPacketFuzzer<'a> { + data: &'a [u8], + ll_src_addr: Option, + ll_dst_addr: Option, +} + +fuzz_target!(|fuzz: SixlowpanPacketFuzzer| { + match SixlowpanPacket::dispatch(fuzz.data) { + Ok(SixlowpanPacket::FragmentHeader) => { + if let Ok(frame) = SixlowpanFragPacket::new_checked(fuzz.data) { + if let Ok(repr) = SixlowpanFragRepr::parse(&frame) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut frame = SixlowpanFragPacket::new_unchecked(&mut buffer[..]); + repr.emit(&mut frame); + } + } + } + Ok(SixlowpanPacket::IphcHeader) => { + if let Ok(frame) = SixlowpanIphcPacket::new_checked(fuzz.data) { + if let Ok(iphc_repr) = SixlowpanIphcRepr::parse( + &frame, + fuzz.ll_src_addr.map(Into::into), + fuzz.ll_dst_addr.map(Into::into), + &[], + ) { + let mut buffer = vec![0; iphc_repr.buffer_len()]; + let mut iphc_frame = SixlowpanIphcPacket::new_unchecked(&mut buffer[..]); + iphc_repr.emit(&mut iphc_frame); + + let payload = frame.payload(); + match iphc_repr.next_header { + SixlowpanNextHeader::Compressed => { + if let Ok(p) = SixlowpanNhcPacket::dispatch(payload) { + match p { + SixlowpanNhcPacket::ExtHeader => { + if let Ok(frame) = + SixlowpanExtHeaderPacket::new_checked(payload) + { + if let Ok(repr) = SixlowpanExtHeaderRepr::parse(&frame) + { + let mut buffer = vec![0; repr.buffer_len()]; + let mut ext_header_frame = + SixlowpanExtHeaderPacket::new_unchecked( + &mut buffer[..], + ); + repr.emit(&mut ext_header_frame); + } + } + } + SixlowpanNhcPacket::UdpHeader => { + if let Ok(frame) = + SixlowpanUdpNhcPacket::new_checked(payload) + { + if let Ok(repr) = SixlowpanUdpNhcRepr::parse( + &frame, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + &Default::default(), + ) { + let mut buffer = vec![ + 0; + repr.header_len() + + frame.payload().len() + ]; + let mut udp_packet = + SixlowpanUdpNhcPacket::new_unchecked( + &mut buffer[..], + ); + repr.emit( + &mut udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + frame.payload().len(), + |b| b.copy_from_slice(frame.payload()), + ); + } + } + } + } + } + } + SixlowpanNextHeader::Uncompressed(proto) => match proto { + IpProtocol::HopByHop => { + if let Ok(frame) = Ipv6HopByHopHeader::new_checked(payload) { + if let Ok(repr) = Ipv6HopByHopRepr::parse(&frame) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut hop_by_hop_frame = + Ipv6HopByHopHeader::new_unchecked(&mut buffer[..]); + repr.emit(&mut hop_by_hop_frame); + } + } + } + IpProtocol::Icmp => { + if let Ok(frame) = Icmpv4Packet::new_checked(payload) { + if let Ok(repr) = + Icmpv4Repr::parse(&frame, &ChecksumCapabilities::default()) + { + let mut buffer = vec![0; repr.buffer_len()]; + let mut icmpv4_packet = + Icmpv4Packet::new_unchecked(&mut buffer[..]); + repr.emit( + &mut icmpv4_packet, + &ChecksumCapabilities::default(), + ); + } + } + } + IpProtocol::Igmp => { + if let Ok(frame) = IgmpPacket::new_checked(payload) { + if let Ok(repr) = IgmpRepr::parse(&frame) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut frame = IgmpPacket::new_unchecked(&mut buffer[..]); + repr.emit(&mut frame); + } + } + } + IpProtocol::Tcp => { + if let Ok(frame) = TcpPacket::new_checked(payload) { + if let Ok(repr) = TcpRepr::parse( + &frame, + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + &ChecksumCapabilities::default(), + ) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut frame = TcpPacket::new_unchecked(&mut buffer[..]); + repr.emit( + &mut frame, + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + &ChecksumCapabilities::default(), + ); + } + } + } + IpProtocol::Udp => { + if let Ok(frame) = UdpPacket::new_checked(payload) { + if let Ok(repr) = UdpRepr::parse( + &frame, + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + &ChecksumCapabilities::default(), + ) { + let mut buffer = + vec![0; repr.header_len() + frame.payload().len()]; + let mut packet = UdpPacket::new_unchecked(&mut buffer[..]); + repr.emit( + &mut packet, + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + frame.payload().len(), + |b| b.copy_from_slice(frame.payload()), + &ChecksumCapabilities::default(), + ); + } + } + } + IpProtocol::Ipv6Route => { + if let Ok(frame) = Ipv6RoutingHeader::new_checked(payload) { + if let Ok(repr) = Ipv6RoutingRepr::parse(&frame) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut packet = Ipv6RoutingHeader::new(&mut buffer[..]); + repr.emit(&mut packet); + } + } + } + IpProtocol::Ipv6Frag => { + if let Ok(frame) = Ipv6FragmentHeader::new_checked(payload) { + if let Ok(repr) = Ipv6FragmentRepr::parse(&frame) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut frame = + Ipv6FragmentHeader::new_unchecked(&mut buffer[..]); + repr.emit(&mut frame); + } + } + } + IpProtocol::Icmpv6 => { + if let Ok(packet) = Icmpv6Packet::new_checked(payload) { + if let Ok(repr) = Icmpv6Repr::parse( + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + &packet, + &ChecksumCapabilities::default(), + ) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut packet = + Icmpv6Packet::new_unchecked(&mut buffer[..]); + repr.emit( + &iphc_repr.src_addr.into_address(), + &iphc_repr.dst_addr.into_address(), + &mut packet, + &ChecksumCapabilities::default(), + ); + } + } + } + IpProtocol::Ipv6NoNxt => (), + IpProtocol::Ipv6Opts => { + if let Ok(packet) = Ipv6Option::new_checked(payload) { + if let Ok(repr) = Ipv6OptionRepr::parse(&packet) { + let mut buffer = vec![0; repr.buffer_len()]; + let mut packet = Ipv6Option::new_unchecked(&mut buffer[..]); + repr.emit(&mut packet); + } + } + } + IpProtocol::Unknown(_) => (), + }, + }; + + let mut buffer = vec![0; iphc_repr.buffer_len()]; + + let mut frame = SixlowpanIphcPacket::new_unchecked(&mut buffer[..]); + iphc_repr.emit(&mut frame); + } + }; + } + Err(_) => (), + } +}); diff --git a/fuzz/fuzz_targets/tcp_headers.rs b/fuzz/fuzz_targets/tcp_headers.rs index fec277294..7d4d4eaa2 100644 --- a/fuzz/fuzz_targets/tcp_headers.rs +++ b/fuzz/fuzz_targets/tcp_headers.rs @@ -1,29 +1,24 @@ #![no_main] -#[macro_use] extern crate libfuzzer_sys; -extern crate smoltcp; - -use std as core; -extern crate getopts; - -use core::cmp; -use smoltcp::phy::Loopback; +use libfuzzer_sys::fuzz_target; +use smoltcp::iface::{InterfaceBuilder, NeighborCache}; +use smoltcp::phy::{Loopback, Medium}; +use smoltcp::socket::tcp; +use smoltcp::time::{Duration, Instant}; use smoltcp::wire::{EthernetAddress, EthernetFrame, EthernetProtocol}; use smoltcp::wire::{IpAddress, IpCidr, Ipv4Packet, Ipv6Packet, TcpPacket}; -use smoltcp::iface::{NeighborCache, EthernetInterfaceBuilder}; -use smoltcp::socket::{SocketSet, TcpSocket, TcpSocketBuffer}; -use smoltcp::time::{Duration, Instant}; +use std::cmp; -mod utils { - include!("../utils.rs"); -} +#[path = "../utils.rs"] +mod utils; mod mock { + use smoltcp::time::{Duration, Instant}; + use std::sync::atomic::{AtomicUsize, Ordering}; use std::sync::Arc; - use std::sync::atomic::{Ordering, AtomicUsize}; - use smoltcp::time::{Duration, Instant}; // should be AtomicU64 but that's unstable #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Clock(Arc); impl Clock { @@ -32,7 +27,8 @@ mod mock { } pub fn advance(&self, duration: Duration) { - self.0.fetch_add(duration.total_millis() as usize, Ordering::SeqCst); + self.0 + .fetch_add(duration.total_millis() as usize, Ordering::SeqCst); } pub fn elapsed(&self) -> Instant { @@ -51,7 +47,10 @@ impl TcpHeaderFuzzer { // // Otherwise, it replaces the entire rest of the TCP header with the fuzzer's output. pub fn new(data: &[u8]) -> TcpHeaderFuzzer { - let copy_len = cmp::min(data.len(), 56 /* max TCP header length without port numbers*/); + let copy_len = cmp::min( + data.len(), + 56, /* max TCP header length without port numbers*/ + ); let mut fuzzer = TcpHeaderFuzzer([0; 56], copy_len); fuzzer.0[..copy_len].copy_from_slice(&data[..copy_len]); @@ -67,13 +66,16 @@ impl smoltcp::phy::Fuzzer for TcpHeaderFuzzer { let tcp_packet_offset = { let eth_frame = EthernetFrame::new_unchecked(&frame_data); - EthernetFrame::<&mut [u8]>::header_len() + match eth_frame.ethertype() { - EthernetProtocol::Ipv4 => - Ipv4Packet::new_unchecked(eth_frame.payload()).header_len() as usize, - EthernetProtocol::Ipv6 => - Ipv6Packet::new_unchecked(eth_frame.payload()).header_len() as usize, - _ => return - } + EthernetFrame::<&mut [u8]>::header_len() + + match eth_frame.ethertype() { + EthernetProtocol::Ipv4 => { + Ipv4Packet::new_unchecked(eth_frame.payload()).header_len() as usize + } + EthernetProtocol::Ipv6 => { + Ipv6Packet::new_unchecked(eth_frame.payload()).header_len() as usize + } + _ => return, + } }; let tcp_is_syn = { @@ -94,7 +96,7 @@ impl smoltcp::phy::Fuzzer for TcpHeaderFuzzer { (tcp_packet[12] as usize >> 4) * 4 }; - let tcp_packet = &mut frame_data[tcp_packet_offset+4..]; + let tcp_packet = &mut frame_data[tcp_packet_offset + 4..]; let replacement_data = &self.0[..self.1]; let copy_len = cmp::min(replacement_data.len(), tcp_header_len); @@ -113,28 +115,28 @@ fuzz_target!(|data: &[u8]| { let clock = mock::Clock::new(); let device = { - let (mut opts, mut free) = utils::create_options(); utils::add_middleware_options(&mut opts, &mut free); let mut matches = utils::parse_options(&opts, free); - let device = utils::parse_middleware_options(&mut matches, Loopback::new(), - /*loopback=*/true); + let device = utils::parse_middleware_options( + &mut matches, + Loopback::new(Medium::Ethernet), + /*loopback=*/ true, + ); - smoltcp::phy::FuzzInjector::new(device, - EmptyFuzzer(), - TcpHeaderFuzzer::new(data)) + smoltcp::phy::FuzzInjector::new(device, EmptyFuzzer(), TcpHeaderFuzzer::new(data)) }; let mut neighbor_cache_entries = [None; 8]; let neighbor_cache = NeighborCache::new(&mut neighbor_cache_entries[..]); let ip_addrs = [IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)]; - let mut iface = EthernetInterfaceBuilder::new(device) - .ethernet_addr(EthernetAddress::default()) - .neighbor_cache(neighbor_cache) - .ip_addrs(ip_addrs) - .finalize(); + let mut iface = InterfaceBuilder::new() + .ethernet_addr(EthernetAddress::default()) + .neighbor_cache(neighbor_cache) + .ip_addrs(ip_addrs) + .finalize(&mut device); let server_socket = { // It is not strictly necessary to use a `static mut` and unsafe code here, but @@ -143,17 +145,17 @@ fuzz_target!(|data: &[u8]| { // when stack overflows. static mut TCP_SERVER_RX_DATA: [u8; 1024] = [0; 1024]; static mut TCP_SERVER_TX_DATA: [u8; 1024] = [0; 1024]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_SERVER_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) }; let client_socket = { static mut TCP_CLIENT_RX_DATA: [u8; 1024] = [0; 1024]; static mut TCP_CLIENT_TX_DATA: [u8; 1024] = [0; 1024]; - let tcp_rx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); - let tcp_tx_buffer = TcpSocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); - TcpSocket::new(tcp_rx_buffer, tcp_tx_buffer) + let tcp_rx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_RX_DATA[..] }); + let tcp_tx_buffer = tcp::SocketBuffer::new(unsafe { &mut TCP_CLIENT_TX_DATA[..] }); + tcp::Socket::new(tcp_rx_buffer, tcp_tx_buffer) }; let mut socket_set_entries: [_; 2] = Default::default(); @@ -161,14 +163,14 @@ fuzz_target!(|data: &[u8]| { let server_handle = socket_set.add(server_socket); let client_handle = socket_set.add(client_socket); - let mut did_listen = false; + let mut did_listen = false; let mut did_connect = false; let mut done = false; while !done && clock.elapsed() < Instant::from_millis(4_000) { let _ = iface.poll(&mut socket_set, clock.elapsed()); { - let mut socket = socket_set.get::(server_handle); + let mut socket = socket_set.get::(server_handle); if !socket.is_active() && !socket.is_listening() { if !did_listen { socket.listen(1234).unwrap(); @@ -183,27 +185,31 @@ fuzz_target!(|data: &[u8]| { } { - let mut socket = socket_set.get::(client_handle); + let mut socket = socket_set.get::(client_handle); if !socket.is_open() { if !did_connect { - socket.connect((IpAddress::v4(127, 0, 0, 1), 1234), - (IpAddress::Unspecified, 65000)).unwrap(); + socket + .connect( + (IpAddress::v4(127, 0, 0, 1), 1234), + (IpAddress::Unspecified, 65000), + ) + .unwrap(); did_connect = true; } } if socket.can_send() { - socket.send_slice(b"0123456789abcdef0123456789abcdef0123456789abcdef").unwrap(); + socket + .send_slice(b"0123456789abcdef0123456789abcdef0123456789abcdef") + .unwrap(); socket.close(); } } match iface.poll_delay(&socket_set, clock.elapsed()) { - Some(Duration { millis: 0 }) => {}, - Some(delay) => { - clock.advance(delay) - }, - None => clock.advance(Duration::from_millis(1)) + Some(Duration::ZERO) => {} + Some(delay) => clock.advance(delay), + None => clock.advance(Duration::from_millis(1)), } } }); diff --git a/fuzz/utils.rs b/fuzz/utils.rs index 26763c23a..206443420 100644 --- a/fuzz/utils.rs +++ b/fuzz/utils.rs @@ -1,18 +1,17 @@ // TODO: this is literally a copy of examples/utils.rs, but without an allow dead code attribute. // The include logic does not allow having attributes in included files. -use std::cell::RefCell; -use std::str::{self, FromStr}; -use std::rc::Rc; -use std::io; -use std::fs::File; -use std::time::{SystemTime, UNIX_EPOCH}; +use getopts::{Matches, Options}; use std::env; +use std::fs::File; +use std::io; +use std::io::Write; use std::process; -use getopts::{Options, Matches}; +use std::str::{self, FromStr}; +use std::time::{SystemTime, UNIX_EPOCH}; -use smoltcp::phy::{Device, EthernetTracer, FaultInjector}; -use smoltcp::phy::{PcapWriter, PcapSink, PcapMode, PcapLinkType}; +use smoltcp::phy::{Device, FaultInjector, Tracer}; +use smoltcp::phy::{PcapMode, PcapWriter}; use smoltcp::time::Duration; pub fn create_options() -> (Options, Vec<&'static str>) { @@ -29,10 +28,17 @@ pub fn parse_options(options: &Options, free: Vec<&str>) -> Matches { } Ok(matches) => { if matches.opt_present("h") || matches.free.len() != free.len() { - let brief = format!("Usage: {} [OPTION]... {}", - env::args().nth(0).unwrap(), free.join(" ")); + let brief = format!( + "Usage: {} [OPTION]... {}", + env::args().nth(0).unwrap(), + free.join(" ") + ); print!("{}", options.usage(&brief)); - process::exit(if matches.free.len() != free.len() { 1 } else { 0 }) + process::exit(if matches.free.len() != free.len() { + 1 + } else { + 0 + }) } matches } @@ -41,46 +47,102 @@ pub fn parse_options(options: &Options, free: Vec<&str>) -> Matches { pub fn add_middleware_options(opts: &mut Options, _free: &mut Vec<&str>) { opts.optopt("", "pcap", "Write a packet capture file", "FILE"); - opts.optopt("", "drop-chance", "Chance of dropping a packet (%)", "CHANCE"); - opts.optopt("", "corrupt-chance", "Chance of corrupting a packet (%)", "CHANCE"); - opts.optopt("", "size-limit", "Drop packets larger than given size (octets)", "SIZE"); - opts.optopt("", "tx-rate-limit", "Drop packets after transmit rate exceeds given limit \ - (packets per interval)", "RATE"); - opts.optopt("", "rx-rate-limit", "Drop packets after transmit rate exceeds given limit \ - (packets per interval)", "RATE"); - opts.optopt("", "shaping-interval", "Sets the interval for rate limiting (ms)", "RATE"); + opts.optopt( + "", + "drop-chance", + "Chance of dropping a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "corrupt-chance", + "Chance of corrupting a packet (%)", + "CHANCE", + ); + opts.optopt( + "", + "size-limit", + "Drop packets larger than given size (octets)", + "SIZE", + ); + opts.optopt( + "", + "tx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "rx-rate-limit", + "Drop packets after transmit rate exceeds given limit \ + (packets per interval)", + "RATE", + ); + opts.optopt( + "", + "shaping-interval", + "Sets the interval for rate limiting (ms)", + "RATE", + ); } -pub fn parse_middleware_options(matches: &mut Matches, device: D, loopback: bool) - -> FaultInjector>>> - where D: for<'a> Device<'a> +pub fn parse_middleware_options( + matches: &mut Matches, + device: D, + loopback: bool, +) -> FaultInjector>>> +where + D: Device, { - let drop_chance = matches.opt_str("drop-chance").map(|s| u8::from_str(&s).unwrap()) - .unwrap_or(0); - let corrupt_chance = matches.opt_str("corrupt-chance").map(|s| u8::from_str(&s).unwrap()) - .unwrap_or(0); - let size_limit = matches.opt_str("size-limit").map(|s| usize::from_str(&s).unwrap()) - .unwrap_or(0); - let tx_rate_limit = matches.opt_str("tx-rate-limit").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); - let rx_rate_limit = matches.opt_str("rx-rate-limit").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); - let shaping_interval = matches.opt_str("shaping-interval").map(|s| u64::from_str(&s).unwrap()) - .unwrap_or(0); + let drop_chance = matches + .opt_str("drop-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let corrupt_chance = matches + .opt_str("corrupt-chance") + .map(|s| u8::from_str(&s).unwrap()) + .unwrap_or(0); + let size_limit = matches + .opt_str("size-limit") + .map(|s| usize::from_str(&s).unwrap()) + .unwrap_or(0); + let tx_rate_limit = matches + .opt_str("tx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let rx_rate_limit = matches + .opt_str("rx-rate-limit") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); + let shaping_interval = matches + .opt_str("shaping-interval") + .map(|s| u64::from_str(&s).unwrap()) + .unwrap_or(0); - let pcap_writer: Box; + let pcap_writer: Box; if let Some(pcap_filename) = matches.opt_str("pcap") { pcap_writer = Box::new(File::create(pcap_filename).expect("cannot open file")) } else { pcap_writer = Box::new(io::sink()) } - let seed = SystemTime::now().duration_since(UNIX_EPOCH).unwrap().subsec_nanos(); + let seed = SystemTime::now() + .duration_since(UNIX_EPOCH) + .unwrap() + .subsec_nanos(); + + let device = PcapWriter::new( + device, + pcap_writer, + if loopback { + PcapMode::TxOnly + } else { + PcapMode::Both + }, + ); - let device = PcapWriter::new(device, Rc::new(RefCell::new(pcap_writer)) as Rc, - if loopback { PcapMode::TxOnly } else { PcapMode::Both }, - PcapLinkType::Ethernet); - let device = EthernetTracer::new(device, |_timestamp, _printer| { + let device = Tracer::new(device, |_timestamp, _printer| { #[cfg(feature = "log")] trace!("{}", _printer); }); diff --git a/gen_config.py b/gen_config.py new file mode 100644 index 000000000..5d327a6db --- /dev/null +++ b/gen_config.py @@ -0,0 +1,85 @@ +import os + +abspath = os.path.abspath(__file__) +dname = os.path.dirname(abspath) +os.chdir(dname) + +features = [] + + +def feature(name, default, min, max, pow2=None): + vals = set() + val = min + while val <= max: + vals.add(val) + if pow2 == True or (isinstance(pow2, int) and val >= pow2): + val *= 2 + else: + val += 1 + vals.add(default) + + features.append( + { + "name": name, + "default": default, + "vals": sorted(list(vals)), + } + ) + + +feature("iface_max_addr_count", default=2, min=1, max=8) +feature("iface_max_multicast_group_count", default=4, min=1, max=1024, pow2=8) +feature("iface_max_sixlowpan_address_context_count", default=4, min=1, max=1024, pow2=8) +feature("iface_neighbor_cache_count", default=4, min=1, max=1024, pow2=8) +feature("iface_max_route_count", default=2, min=1, max=1024, pow2=8) +feature("fragmentation_buffer_size", default=1500, min=256, max=65536, pow2=True) +feature("assembler_max_segment_count", default=4, min=1, max=32, pow2=4) +feature("reassembly_buffer_size", default=1500, min=256, max=65536, pow2=True) +feature("reassembly_buffer_count", default=1, min=1, max=32, pow2=4) +feature("dns_max_result_count", default=1, min=1, max=32, pow2=4) +feature("dns_max_server_count", default=1, min=1, max=32, pow2=4) +feature("dns_max_name_size", default=255, min=64, max=255, pow2=True) +feature("rpl_relations_buffer_count", default=16, min=1, max=128, pow2=True) +feature("rpl_parents_buffer_count", default=8, min=2, max=32, pow2=True) + +# ========= Update Cargo.toml + +things = "" +for f in features: + name = f["name"].replace("_", "-") + for val in f["vals"]: + things += f"{name}-{val} = []" + if val == f["default"]: + things += " # Default" + things += "\n" + things += "\n" + +SEPARATOR_START = "# BEGIN AUTOGENERATED CONFIG FEATURES\n" +SEPARATOR_END = "# END AUTOGENERATED CONFIG FEATURES\n" +HELP = "# Generated by gen_config.py. DO NOT EDIT.\n" +with open("Cargo.toml", "r") as f: + data = f.read() +before, data = data.split(SEPARATOR_START, maxsplit=1) +_, after = data.split(SEPARATOR_END, maxsplit=1) +data = before + SEPARATOR_START + HELP + things + SEPARATOR_END + after +with open("Cargo.toml", "w") as f: + f.write(data) + + +# ========= Update build.rs + +things = "" +for f in features: + name = f["name"].upper() + things += f' ("{name}", {f["default"]}),\n' + +SEPARATOR_START = "// BEGIN AUTOGENERATED CONFIG FEATURES\n" +SEPARATOR_END = "// END AUTOGENERATED CONFIG FEATURES\n" +HELP = " // Generated by gen_config.py. DO NOT EDIT.\n" +with open("build.rs", "r") as f: + data = f.read() +before, data = data.split(SEPARATOR_START, maxsplit=1) +_, after = data.split(SEPARATOR_END, maxsplit=1) +data = before + SEPARATOR_START + HELP + things + " " + SEPARATOR_END + after +with open("build.rs", "w") as f: + f.write(data) diff --git a/src/dhcp/clientv4.rs b/src/dhcp/clientv4.rs deleted file mode 100644 index 498a7afb4..000000000 --- a/src/dhcp/clientv4.rs +++ /dev/null @@ -1,435 +0,0 @@ -use {Result, Error}; -use wire::{IpVersion, IpProtocol, IpEndpoint, IpAddress, - Ipv4Cidr, Ipv4Address, Ipv4Packet, Ipv4Repr, - UdpPacket, UdpRepr, - DhcpPacket, DhcpRepr, DhcpMessageType}; -use wire::dhcpv4::field as dhcpv4_field; -use socket::{SocketSet, SocketHandle, RawSocket, RawSocketBuffer}; -use phy::{Device, ChecksumCapabilities}; -use iface::EthernetInterface as Interface; -use time::{Instant, Duration}; -use super::{UDP_SERVER_PORT, UDP_CLIENT_PORT}; - -const DISCOVER_TIMEOUT: u64 = 10; -const REQUEST_TIMEOUT: u64 = 1; -const REQUEST_RETRIES: u16 = 15; -const RENEW_INTERVAL: u64 = 60; -const RENEW_RETRIES: u16 = 3; -const PARAMETER_REQUEST_LIST: &[u8] = &[ - dhcpv4_field::OPT_SUBNET_MASK, - dhcpv4_field::OPT_ROUTER, - dhcpv4_field::OPT_DOMAIN_NAME_SERVER, -]; - -/// IPv4 configuration data returned by `client.poll()` -#[derive(Debug)] -pub struct Config { - pub address: Option, - pub router: Option, - pub dns_servers: [Option; 3], -} - -#[derive(Debug)] -struct RequestState { - retry: u16, - endpoint_ip: Ipv4Address, - server_identifier: Ipv4Address, -} - -#[derive(Debug)] -struct RenewState { - retry: u16, - endpoint_ip: Ipv4Address, - server_identifier: Ipv4Address, -} - -#[derive(Debug)] -enum ClientState { - /// Discovering the DHCP server - Discovering, - /// Requesting an address - Requesting(RequestState), - /// Having an address, refresh it periodically - Renew(RenewState), -} - -pub struct Client { - state: ClientState, - raw_handle: SocketHandle, - /// When to send next request - next_egress: Instant, - transaction_id: u32, -} - -/// DHCP client with a RawSocket. -/// -/// To provide memory for the dynamic IP address, configure your -/// `Interface` with one of `ip_addrs` and the `ipv4_gateway` being -/// `Ipv4Address::UNSPECIFIED`. You must also assign this `0.0.0.0/0` -/// while the client's state is `Discovering`. Hence, the `poll()` -/// method returns a corresponding `Config` struct in this case. -/// -/// You must call `dhcp_client.poll()` after `iface.poll()` to send -/// and receive DHCP packets. -impl Client { - /// # Usage - /// ```rust - /// use smoltcp::socket::{SocketSet, RawSocketBuffer, RawPacketMetadata}; - /// use smoltcp::dhcp::Dhcpv4Client; - /// use smoltcp::time::Instant; - /// - /// let mut sockets = SocketSet::new(vec![]); - /// let dhcp_rx_buffer = RawSocketBuffer::new( - /// [RawPacketMetadata::EMPTY; 1], - /// vec![0; 600] - /// ); - /// let dhcp_tx_buffer = RawSocketBuffer::new( - /// [RawPacketMetadata::EMPTY; 1], - /// vec![0; 600] - /// ); - /// let mut dhcp = Dhcpv4Client::new( - /// &mut sockets, - /// dhcp_rx_buffer, dhcp_tx_buffer, - /// Instant::now() - /// ); - /// ``` - pub fn new<'a, 'b, 'c>(sockets: &mut SocketSet<'a, 'b, 'c>, rx_buffer: RawSocketBuffer<'b, 'c>, tx_buffer: RawSocketBuffer<'b, 'c>, now: Instant) -> Self - where 'b: 'c, - { - let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); - let raw_handle = sockets.add(raw_socket); - - Client { - state: ClientState::Discovering, - raw_handle, - next_egress: now, - transaction_id: 1, - } - } - - /// When to send next packet - /// - /// Useful for suspending execution after polling. - pub fn next_poll(&self, now: Instant) -> Duration { - self.next_egress - now - } - - /// Process incoming packets on the contained RawSocket, and send - /// DHCP requests when timeouts are ready. - /// - /// Applying the obtained network configuration is left to the - /// user. You must configure the new IPv4 address from the - /// returned `Config`. Otherwise, DHCP will not work. - /// - /// A Config can be returned from any valid DHCP reply. The client - /// performs no bookkeeping on configuration or their changes. - pub fn poll(&mut self, - iface: &mut Interface, sockets: &mut SocketSet, - now: Instant - ) -> Result> - where - DeviceT: for<'d> Device<'d>, - { - let checksum_caps = ChecksumCapabilities::default(); // ? - let mut raw_socket = sockets.get::(self.raw_handle); - - // Process incoming - let config = { - match raw_socket.recv() - .and_then(|packet| parse_udp(packet, &checksum_caps)) { - Ok((IpEndpoint { - addr: IpAddress::Ipv4(src_ip), - port: UDP_SERVER_PORT, - }, IpEndpoint { - addr: _, - port: UDP_CLIENT_PORT, - }, payload)) => - self.ingress(iface, now, payload, &src_ip), - Ok(_) => - return Err(Error::Unrecognized), - Err(Error::Exhausted) => - None, - Err(e) => - return Err(e), - } - }; - - if config.is_some() { - // Return a new config immediately so that addresses can - // be configured that are required by egress(). - Ok(config) - } else { - // Send requests - if raw_socket.can_send() && now >= self.next_egress { - self.egress(iface, &mut *raw_socket, &checksum_caps, now) - } else { - Ok(None) - } - } - } - - fn ingress(&mut self, - iface: &mut Interface, now: Instant, - data: &[u8], src_ip: &Ipv4Address - ) -> Option - where - DeviceT: for<'d> Device<'d>, - { - let dhcp_packet = match DhcpPacket::new_checked(data) { - Ok(dhcp_packet) => dhcp_packet, - Err(e) => { - net_debug!("DHCP invalid pkt from {}: {:?}", src_ip, e); - return None; - } - }; - let dhcp_repr = match DhcpRepr::parse(&dhcp_packet) { - Ok(dhcp_repr) => dhcp_repr, - Err(e) => { - net_debug!("DHCP error parsing pkt from {}: {:?}", src_ip, e); - return None; - } - }; - let mac = iface.ethernet_addr(); - if dhcp_repr.client_hardware_address != mac { return None } - if dhcp_repr.transaction_id != self.transaction_id { return None } - let server_identifier = match dhcp_repr.server_identifier { - Some(server_identifier) => server_identifier, - None => return None, - }; - net_debug!("DHCP recv {:?} from {} ({})", dhcp_repr.message_type, src_ip, server_identifier); - - let config = if (dhcp_repr.message_type == DhcpMessageType::Offer || - dhcp_repr.message_type == DhcpMessageType::Ack) && - dhcp_repr.your_ip != Ipv4Address::UNSPECIFIED { - let address = dhcp_repr.subnet_mask - .and_then(|mask| IpAddress::Ipv4(mask).to_prefix_len()) - .map(|prefix_len| Ipv4Cidr::new(dhcp_repr.your_ip, prefix_len)); - let router = dhcp_repr.router; - let dns_servers = dhcp_repr.dns_servers - .unwrap_or([None; 3]); - Some(Config { address, router, dns_servers }) - } else { - None - }; - - match self.state { - ClientState::Discovering - if dhcp_repr.message_type == DhcpMessageType::Offer => - { - self.next_egress = now; - let r_state = RequestState { - retry: 0, - endpoint_ip: *src_ip, - server_identifier, - }; - Some(ClientState::Requesting(r_state)) - } - ClientState::Requesting(ref r_state) - if dhcp_repr.message_type == DhcpMessageType::Ack && - server_identifier == r_state.server_identifier => - { - self.next_egress = now + Duration::from_secs(RENEW_INTERVAL); - let p_state = RenewState { - retry: 0, - endpoint_ip: *src_ip, - server_identifier, - }; - Some(ClientState::Renew(p_state)) - } - ClientState::Renew(ref mut p_state) - if dhcp_repr.message_type == DhcpMessageType::Ack && - server_identifier == p_state.server_identifier => - { - self.next_egress = now + Duration::from_secs(RENEW_INTERVAL); - p_state.retry = 0; - None - } - _ => None - }.map(|new_state| self.state = new_state); - - config - } - - fn egress Device<'d>>(&mut self, iface: &mut Interface, raw_socket: &mut RawSocket, checksum_caps: &ChecksumCapabilities, now: Instant) -> Result> { - // Reset after maximum amount of retries - let retries_exceeded = match self.state { - ClientState::Requesting(ref mut r_state) if r_state.retry >= REQUEST_RETRIES => { - net_debug!("DHCP request retries exceeded, restarting discovery"); - true - } - ClientState::Renew(ref mut r_state) if r_state.retry >= RENEW_RETRIES => { - net_debug!("DHCP renew retries exceeded, restarting discovery"); - true - } - _ => false - }; - if retries_exceeded { - self.reset(now); - // Return a config now so that user code assigns the - // 0.0.0.0/0 address, which will be used sending a DHCP - // discovery packet in the next call to egress(). - return Ok(Some(Config { - address: Some(Ipv4Cidr::new(Ipv4Address::UNSPECIFIED, 0)), - router: None, - dns_servers: [None; 3], - })); - } - - // Prepare sending next packet - self.transaction_id += 1; - let mac = iface.ethernet_addr(); - - let mut dhcp_repr = DhcpRepr { - message_type: DhcpMessageType::Discover, - transaction_id: self.transaction_id, - client_hardware_address: mac, - client_ip: Ipv4Address::UNSPECIFIED, - your_ip: Ipv4Address::UNSPECIFIED, - server_ip: Ipv4Address::UNSPECIFIED, - router: None, - subnet_mask: None, - relay_agent_ip: Ipv4Address::UNSPECIFIED, - broadcast: true, - requested_ip: None, - client_identifier: Some(mac), - server_identifier: None, - parameter_request_list: None, - max_size: Some(raw_socket.payload_recv_capacity() as u16), - dns_servers: None, - }; - let mut send_packet = |iface, endpoint, dhcp_repr| { - send_packet(iface, raw_socket, &endpoint, &dhcp_repr, checksum_caps) - .map(|()| None) - }; - - - match self.state { - ClientState::Discovering => { - self.next_egress = now + Duration::from_secs(DISCOVER_TIMEOUT); - let endpoint = IpEndpoint { - addr: Ipv4Address::BROADCAST.into(), - port: UDP_SERVER_PORT, - }; - net_trace!("DHCP send discover to {}: {:?}", endpoint, dhcp_repr); - send_packet(iface, endpoint, dhcp_repr) - } - ClientState::Requesting(ref mut r_state) => { - r_state.retry += 1; - self.next_egress = now + Duration::from_secs(REQUEST_TIMEOUT); - - let endpoint = IpEndpoint { - addr: Ipv4Address::BROADCAST.into(), - port: UDP_SERVER_PORT, - }; - let requested_ip = match iface.ipv4_addr() { - Some(addr) if !addr.is_unspecified() => - Some(addr), - _ => - None, - }; - dhcp_repr.message_type = DhcpMessageType::Request; - dhcp_repr.broadcast = false; - dhcp_repr.requested_ip = requested_ip; - dhcp_repr.server_identifier = Some(r_state.server_identifier); - dhcp_repr.parameter_request_list = Some(PARAMETER_REQUEST_LIST); - net_trace!("DHCP send request to {} = {:?}", endpoint, dhcp_repr); - send_packet(iface, endpoint, dhcp_repr) - } - ClientState::Renew(ref mut p_state) => { - p_state.retry += 1; - self.next_egress = now + Duration::from_secs(RENEW_INTERVAL); - - let endpoint = IpEndpoint { - addr: p_state.endpoint_ip.into(), - port: UDP_SERVER_PORT, - }; - let client_ip = iface.ipv4_addr().unwrap_or(Ipv4Address::UNSPECIFIED); - dhcp_repr.message_type = DhcpMessageType::Request; - dhcp_repr.client_ip = client_ip; - dhcp_repr.broadcast = false; - net_trace!("DHCP send renew to {}: {:?}", endpoint, dhcp_repr); - send_packet(iface, endpoint, dhcp_repr) - } - } - } - - /// Reset state and restart discovery phase. - /// - /// Use this to speed up acquisition of an address in a new - /// network if a link was down and it is now back up. - /// - /// You *must* configure a `0.0.0.0` address on your interface - /// before the next call to `poll()`! - pub fn reset(&mut self, now: Instant) { - net_trace!("DHCP reset"); - self.state = ClientState::Discovering; - self.next_egress = now; - } -} - -fn send_packet Device<'d>>(iface: &mut Interface, raw_socket: &mut RawSocket, endpoint: &IpEndpoint, dhcp_repr: &DhcpRepr, checksum_caps: &ChecksumCapabilities) -> Result<()> { - let mut dhcp_payload_buf = [0; 320]; - assert!(dhcp_repr.buffer_len() <= dhcp_payload_buf.len()); - let dhcp_payload = &mut dhcp_payload_buf[0..dhcp_repr.buffer_len()]; - { - let mut dhcp_packet = DhcpPacket::new_checked(&mut dhcp_payload[..])?; - dhcp_repr.emit(&mut dhcp_packet)?; - } - - let udp_repr = UdpRepr { - src_port: UDP_CLIENT_PORT, - dst_port: endpoint.port, - payload: dhcp_payload, - }; - - let src_addr = iface.ipv4_addr().unwrap(); - let dst_addr = match endpoint.addr { - IpAddress::Ipv4(addr) => addr, - _ => return Err(Error::Illegal), - }; - let ipv4_repr = Ipv4Repr { - src_addr, - dst_addr, - protocol: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 64, - }; - - let mut packet = raw_socket.send( - ipv4_repr.buffer_len() + udp_repr.buffer_len() - )?; - { - let mut ipv4_packet = Ipv4Packet::new_unchecked(&mut packet); - ipv4_repr.emit(&mut ipv4_packet, &checksum_caps); - } - { - let mut udp_packet = UdpPacket::new_unchecked( - &mut packet[ipv4_repr.buffer_len()..] - ); - udp_repr.emit(&mut udp_packet, - &src_addr.into(), &dst_addr.into(), - checksum_caps); - } - Ok(()) -} - -fn parse_udp<'a>(data: &'a [u8], checksum_caps: &ChecksumCapabilities) -> Result<(IpEndpoint, IpEndpoint, &'a [u8])> { - let ipv4_packet = Ipv4Packet::new_checked(data)?; - let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &checksum_caps)?; - let udp_packet = UdpPacket::new_checked(ipv4_packet.payload())?; - let udp_repr = UdpRepr::parse( - &udp_packet, - &ipv4_repr.src_addr.into(), &ipv4_repr.dst_addr.into(), - checksum_caps - )?; - let src = IpEndpoint { - addr: ipv4_repr.src_addr.into(), - port: udp_repr.src_port, - }; - let dst = IpEndpoint { - addr: ipv4_repr.dst_addr.into(), - port: udp_repr.dst_port, - }; - let data = udp_repr.payload; - Ok((src, dst, data)) -} diff --git a/src/dhcp/mod.rs b/src/dhcp/mod.rs deleted file mode 100644 index 2c1b47a99..000000000 --- a/src/dhcp/mod.rs +++ /dev/null @@ -1,5 +0,0 @@ -pub const UDP_SERVER_PORT: u16 = 67; -pub const UDP_CLIENT_PORT: u16 = 68; - -mod clientv4; -pub use self::clientv4::{Client as Dhcpv4Client, Config as Dhcpv4Config}; diff --git a/src/iface/ethernet.rs b/src/iface/ethernet.rs deleted file mode 100644 index f879efef9..000000000 --- a/src/iface/ethernet.rs +++ /dev/null @@ -1,2808 +0,0 @@ -// Heads up! Before working on this file you should read the parts -// of RFC 1122 that discuss Ethernet, ARP and IP for any IPv4 work -// and RFCs 8200 and 4861 for any IPv6 and NDISC work. - -use core::cmp; -use managed::{ManagedSlice, ManagedMap}; -#[cfg(not(feature = "proto-igmp"))] -use core::marker::PhantomData; - -use {Error, Result}; -use phy::{Device, DeviceCapabilities, RxToken, TxToken}; -use time::{Duration, Instant}; -use wire::pretty_print::PrettyPrinter; -use wire::{EthernetAddress, EthernetProtocol, EthernetFrame}; -use wire::{IpAddress, IpProtocol, IpRepr, IpCidr}; -#[cfg(feature = "proto-ipv6")] -use wire::{Ipv6Address, Ipv6Packet, Ipv6Repr, IPV6_MIN_MTU}; -#[cfg(feature = "proto-ipv4")] -use wire::{Ipv4Address, Ipv4Packet, Ipv4Repr, IPV4_MIN_MTU}; -#[cfg(feature = "proto-ipv4")] -use wire::{ArpPacket, ArpRepr, ArpOperation}; -#[cfg(feature = "proto-ipv4")] -use wire::{Icmpv4Packet, Icmpv4Repr, Icmpv4DstUnreachable}; -#[cfg(feature = "proto-igmp")] -use wire::{IgmpPacket, IgmpRepr, IgmpVersion}; -#[cfg(feature = "proto-ipv6")] -use wire::{Icmpv6Packet, Icmpv6Repr, Icmpv6ParamProblem}; -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -use wire::IcmpRepr; -#[cfg(feature = "proto-ipv6")] -use wire::{Ipv6HopByHopHeader, Ipv6HopByHopRepr}; -#[cfg(feature = "proto-ipv6")] -use wire::{Ipv6OptionRepr, Ipv6OptionFailureType}; -#[cfg(feature = "proto-ipv6")] -use wire::{NdiscNeighborFlags, NdiscRepr}; -#[cfg(all(feature = "proto-ipv6", feature = "socket-udp"))] -use wire::Icmpv6DstUnreachable; -#[cfg(feature = "socket-udp")] -use wire::{UdpPacket, UdpRepr}; -#[cfg(feature = "socket-tcp")] -use wire::{TcpPacket, TcpRepr, TcpControl}; - -use socket::{Socket, SocketSet, AnySocket, PollAt}; -#[cfg(feature = "socket-raw")] -use socket::RawSocket; -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -use socket::IcmpSocket; -#[cfg(feature = "socket-udp")] -use socket::UdpSocket; -#[cfg(feature = "socket-tcp")] -use socket::TcpSocket; -use super::{NeighborCache, NeighborAnswer}; -use super::Routes; - -/// An Ethernet network interface. -/// -/// The network interface logically owns a number of other data structures; to avoid -/// a dependency on heap allocation, it instead owns a `BorrowMut<[T]>`, which can be -/// a `&mut [T]`, or `Vec` if a heap is available. -pub struct Interface<'b, 'c, 'e, DeviceT: for<'d> Device<'d>> { - device: DeviceT, - inner: InterfaceInner<'b, 'c, 'e>, -} - -/// The device independent part of an Ethernet network interface. -/// -/// Separating the device from the data required for prorcessing and dispatching makes -/// it possible to borrow them independently. For example, the tx and rx tokens borrow -/// the `device` mutably until they're used, which makes it impossible to call other -/// methods on the `Interface` in this time (since its `device` field is borrowed -/// exclusively). However, it is still possible to call methods on its `inner` field. -struct InterfaceInner<'b, 'c, 'e> { - neighbor_cache: NeighborCache<'b>, - ethernet_addr: EthernetAddress, - ip_addrs: ManagedSlice<'c, IpCidr>, - #[cfg(feature = "proto-ipv4")] - any_ip: bool, - routes: Routes<'e>, - #[cfg(feature = "proto-igmp")] - ipv4_multicast_groups: ManagedMap<'e, Ipv4Address, ()>, - #[cfg(not(feature = "proto-igmp"))] - _ipv4_multicast_groups: PhantomData<&'e ()>, - /// When to report for (all or) the next multicast group membership via IGMP - #[cfg(feature = "proto-igmp")] - igmp_report_state: IgmpReportState, - device_capabilities: DeviceCapabilities, -} - -/// A builder structure used for creating a Ethernet network -/// interface. -pub struct InterfaceBuilder <'b, 'c, 'e, DeviceT: for<'d> Device<'d>> { - device: DeviceT, - ethernet_addr: Option, - neighbor_cache: Option>, - ip_addrs: ManagedSlice<'c, IpCidr>, - #[cfg(feature = "proto-ipv4")] - any_ip: bool, - routes: Routes<'e>, - /// Does not share storage with `ipv6_multicast_groups` to avoid IPv6 size overhead. - #[cfg(feature = "proto-igmp")] - ipv4_multicast_groups: ManagedMap<'e, Ipv4Address, ()>, - #[cfg(not(feature = "proto-igmp"))] - _ipv4_multicast_groups: PhantomData<&'e ()>, -} - -impl<'b, 'c, 'e, DeviceT> InterfaceBuilder<'b, 'c, 'e, DeviceT> - where DeviceT: for<'d> Device<'d> { - /// Create a builder used for creating a network interface using the - /// given device and address. - /// - /// # Examples - /// - /// ``` - /// # use std::collections::BTreeMap; - /// use smoltcp::iface::{EthernetInterfaceBuilder, NeighborCache}; - /// # use smoltcp::phy::Loopback; - /// use smoltcp::wire::{EthernetAddress, IpCidr, IpAddress}; - /// - /// let device = // ... - /// # Loopback::new(); - /// let hw_addr = // ... - /// # EthernetAddress::default(); - /// let neighbor_cache = // ... - /// # NeighborCache::new(BTreeMap::new()); - /// let ip_addrs = // ... - /// # []; - /// let iface = EthernetInterfaceBuilder::new(device) - /// .ethernet_addr(hw_addr) - /// .neighbor_cache(neighbor_cache) - /// .ip_addrs(ip_addrs) - /// .finalize(); - /// ``` - pub fn new(device: DeviceT) -> Self { - InterfaceBuilder { - device: device, - ethernet_addr: None, - neighbor_cache: None, - ip_addrs: ManagedSlice::Borrowed(&mut []), - #[cfg(feature = "proto-ipv4")] - any_ip: false, - routes: Routes::new(ManagedMap::Borrowed(&mut [])), - #[cfg(feature = "proto-igmp")] - ipv4_multicast_groups: ManagedMap::Borrowed(&mut []), - #[cfg(not(feature = "proto-igmp"))] - _ipv4_multicast_groups: PhantomData, - } - } - - /// Set the Ethernet address the interface will use. See also - /// [ethernet_addr]. - /// - /// # Panics - /// This function panics if the address is not unicast. - /// - /// [ethernet_addr]: struct.EthernetInterface.html#method.ethernet_addr - pub fn ethernet_addr(mut self, addr: EthernetAddress) -> Self { - InterfaceInner::check_ethernet_addr(&addr); - self.ethernet_addr = Some(addr); - self - } - - /// Set the IP addresses the interface will use. See also - /// [ip_addrs]. - /// - /// # Panics - /// This function panics if any of the addresses are not unicast. - /// - /// [ip_addrs]: struct.EthernetInterface.html#method.ip_addrs - pub fn ip_addrs(mut self, ip_addrs: T) -> Self - where T: Into> - { - let ip_addrs = ip_addrs.into(); - InterfaceInner::check_ip_addrs(&ip_addrs); - self.ip_addrs = ip_addrs; - self - } - - /// Enable or disable the AnyIP capability, allowing packets to be received - /// locally on IPv4 addresses other than the interface's configured [ip_addrs]. - /// When AnyIP is enabled and a route prefix in [routes] specifies one of - /// the interface's [ip_addrs] as its gateway, the interface will accept - /// packets addressed to that prefix. - /// - /// # IPv6 - /// - /// This option is not available or required for IPv6 as packets sent to - /// the interface are not filtered by IPv6 address. - /// - /// [routes]: struct.EthernetInterface.html#method.routes - /// [ip_addrs]: struct.EthernetInterface.html#method.ip_addrs - #[cfg(feature = "proto-ipv4")] - pub fn any_ip(mut self, enabled: bool) -> Self { - self.any_ip = enabled; - self - } - - /// Set the IP routes the interface will use. See also - /// [routes]. - /// - /// [routes]: struct.EthernetInterface.html#method.routes - pub fn routes(mut self, routes: T) -> InterfaceBuilder<'b, 'c, 'e, DeviceT> - where T: Into> - { - self.routes = routes.into(); - self - } - - /// Provide storage for multicast groups. - /// - /// Join multicast groups by calling [`join_multicast_group()`] on an `Interface`. - /// Using [`join_multicast_group()`] will send initial membership reports. - /// - /// A previously destroyed interface can be recreated by reusing the multicast group - /// storage, i.e. providing a non-empty storage to `ipv4_multicast_groups()`. - /// Note that this way initial membership reports are **not** sent. - /// - /// [`join_multicast_group()`]: struct.EthernetInterface.html#method.join_multicast_group - #[cfg(feature = "proto-igmp")] - pub fn ipv4_multicast_groups(mut self, ipv4_multicast_groups: T) -> Self - where T: Into> - { - self.ipv4_multicast_groups = ipv4_multicast_groups.into(); - self - } - - /// Set the Neighbor Cache the interface will use. - pub fn neighbor_cache(mut self, neighbor_cache: NeighborCache<'b>) -> Self { - self.neighbor_cache = Some(neighbor_cache); - self - } - - /// Create a network interface using the previously provided configuration. - /// - /// # Panics - /// If a required option is not provided, this function will panic. Required - /// options are: - /// - /// - [ethernet_addr] - /// - [neighbor_cache] - /// - /// [ethernet_addr]: #method.ethernet_addr - /// [neighbor_cache]: #method.neighbor_cache - pub fn finalize(self) -> Interface<'b, 'c, 'e, DeviceT> { - match (self.ethernet_addr, self.neighbor_cache) { - (Some(ethernet_addr), Some(neighbor_cache)) => { - let device_capabilities = self.device.capabilities(); - - Interface { - device: self.device, - inner: InterfaceInner { - ethernet_addr, device_capabilities, neighbor_cache, - ip_addrs: self.ip_addrs, - #[cfg(feature = "proto-ipv4")] - any_ip: self.any_ip, - routes: self.routes, - #[cfg(feature = "proto-igmp")] - ipv4_multicast_groups: self.ipv4_multicast_groups, - #[cfg(not(feature = "proto-igmp"))] - _ipv4_multicast_groups: PhantomData, - #[cfg(feature = "proto-igmp")] - igmp_report_state: IgmpReportState::Inactive, - } - } - }, - _ => panic!("a required option was not set"), - } - } -} - -#[derive(Debug, PartialEq)] -enum Packet<'a> { - None, - #[cfg(feature = "proto-ipv4")] - Arp(ArpRepr), - #[cfg(feature = "proto-ipv4")] - Icmpv4((Ipv4Repr, Icmpv4Repr<'a>)), - #[cfg(feature = "proto-igmp")] - Igmp((Ipv4Repr, IgmpRepr)), - #[cfg(feature = "proto-ipv6")] - Icmpv6((Ipv6Repr, Icmpv6Repr<'a>)), - #[cfg(feature = "socket-raw")] - Raw((IpRepr, &'a [u8])), - #[cfg(feature = "socket-udp")] - Udp((IpRepr, UdpRepr<'a>)), - #[cfg(feature = "socket-tcp")] - Tcp((IpRepr, TcpRepr<'a>)) -} - -impl<'a> Packet<'a> { - fn neighbor_addr(&self) -> Option { - match self { - &Packet::None => None, - #[cfg(feature = "proto-ipv4")] - &Packet::Arp(_) => None, - #[cfg(feature = "proto-ipv4")] - &Packet::Icmpv4((ref ipv4_repr, _)) => Some(ipv4_repr.dst_addr.into()), - #[cfg(feature = "proto-igmp")] - &Packet::Igmp((ref ipv4_repr, _)) => Some(ipv4_repr.dst_addr.into()), - #[cfg(feature = "proto-ipv6")] - &Packet::Icmpv6((ref ipv6_repr, _)) => Some(ipv6_repr.dst_addr.into()), - #[cfg(feature = "socket-raw")] - &Packet::Raw((ref ip_repr, _)) => Some(ip_repr.dst_addr()), - #[cfg(feature = "socket-udp")] - &Packet::Udp((ref ip_repr, _)) => Some(ip_repr.dst_addr()), - #[cfg(feature = "socket-tcp")] - &Packet::Tcp((ref ip_repr, _)) => Some(ip_repr.dst_addr()) - } - } -} - -#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] -fn icmp_reply_payload_len(len: usize, mtu: usize, header_len: usize) -> usize { - // Send back as much of the original payload as will fit within - // the minimum MTU required by IPv4. See RFC 1812 § 4.3.2.3 for - // more details. - // - // Since the entire network layer packet must fit within the minumum - // MTU supported, the payload must not exceed the following: - // - // - IP Header Size * 2 - ICMPv4 DstUnreachable hdr size - cmp::min(len, mtu - header_len * 2 - 8) -} - -#[cfg(feature = "proto-igmp")] -enum IgmpReportState { - Inactive, - ToGeneralQuery { - version: IgmpVersion, - timeout: Instant, - interval: Duration, - next_index: usize - }, - ToSpecificQuery { - version: IgmpVersion, - timeout: Instant, - group: Ipv4Address - }, -} - -impl<'b, 'c, 'e, DeviceT> Interface<'b, 'c, 'e, DeviceT> - where DeviceT: for<'d> Device<'d> { - /// Get the Ethernet address of the interface. - pub fn ethernet_addr(&self) -> EthernetAddress { - self.inner.ethernet_addr - } - - /// Set the Ethernet address of the interface. - /// - /// # Panics - /// This function panics if the address is not unicast. - pub fn set_ethernet_addr(&mut self, addr: EthernetAddress) { - self.inner.ethernet_addr = addr; - InterfaceInner::check_ethernet_addr(&self.inner.ethernet_addr); - } - - /// Get a reference to the inner device. - pub fn device(&self) -> &DeviceT { - &self.device - } - - /// Get a mutable reference to the inner device. - /// - /// There are no invariants imposed on the device by the interface itself. Furthermore the - /// trait implementations, required for references of all lifetimes, guarantees that the - /// mutable reference can not invalidate the device as such. For some devices, such access may - /// still allow modifications with adverse effects on the usability as a `phy` device. You - /// should not use them this way. - pub fn device_mut(&mut self) -> &mut DeviceT { - &mut self.device - } - - /// Add an address to a list of subscribed multicast IP addresses. - /// - /// Returns `Ok(announce_sent)` if the address was added successfully, where `annouce_sent` - /// indicates whether an initial immediate announcement has been sent. - pub fn join_multicast_group>(&mut self, addr: T, _timestamp: Instant) -> Result { - match addr.into() { - #[cfg(feature = "proto-igmp")] - IpAddress::Ipv4(addr) => { - let is_not_new = self.inner.ipv4_multicast_groups.insert(addr, ()) - .map_err(|_| Error::Exhausted)? - .is_some(); - if is_not_new { - Ok(false) - } else if let Some(pkt) = - self.inner.igmp_report_packet(IgmpVersion::Version2, addr) { - // Send initial membership report - let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch(tx_token, _timestamp, pkt)?; - Ok(true) - } else { - Ok(false) - } - } - // Multicast is not yet implemented for other address families - _ => Err(Error::Unaddressable) - } - } - - /// Remove an address from the subscribed multicast IP addresses. - /// - /// Returns `Ok(leave_sent)` if the address was removed successfully, where `leave_sent` - /// indicates whether an immediate leave packet has been sent. - pub fn leave_multicast_group>(&mut self, addr: T, _timestamp: Instant) -> Result { - match addr.into() { - #[cfg(feature = "proto-igmp")] - IpAddress::Ipv4(addr) => { - let was_not_present = self.inner.ipv4_multicast_groups.remove(&addr) - .is_none(); - if was_not_present { - Ok(false) - } else if let Some(pkt) = self.inner.igmp_leave_packet(addr) { - // Send group leave packet - let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch(tx_token, _timestamp, pkt)?; - Ok(true) - } else { - Ok(false) - } - } - // Multicast is not yet implemented for other address families - _ => Err(Error::Unaddressable) - } - } - - /// Check whether the interface listens to given destination multicast IP address. - pub fn has_multicast_group>(&self, addr: T) -> bool { - self.inner.has_multicast_group(addr) - } - - /// Get the IP addresses of the interface. - pub fn ip_addrs(&self) -> &[IpCidr] { - self.inner.ip_addrs.as_ref() - } - - /// Get the first IPv4 address if present. - #[cfg(feature = "proto-ipv4")] - pub fn ipv4_addr(&self) -> Option { - self.ip_addrs().iter() - .filter_map(|cidr| match cidr.address() { - IpAddress::Ipv4(addr) => Some(addr), - _ => None, - }).next() - } - - /// Update the IP addresses of the interface. - /// - /// # Panics - /// This function panics if any of the addresses are not unicast. - pub fn update_ip_addrs)>(&mut self, f: F) { - f(&mut self.inner.ip_addrs); - InterfaceInner::check_ip_addrs(&self.inner.ip_addrs) - } - - /// Check whether the interface has the given IP address assigned. - pub fn has_ip_addr>(&self, addr: T) -> bool { - self.inner.has_ip_addr(addr) - } - - /// Get the first IPv4 address of the interface. - #[cfg(feature = "proto-ipv4")] - pub fn ipv4_address(&self) -> Option { - self.inner.ipv4_address() - } - - pub fn routes(&self) -> &Routes<'e> { - &self.inner.routes - } - - pub fn routes_mut(&mut self) -> &mut Routes<'e> { - &mut self.inner.routes - } - - /// Transmit packets queued in the given sockets, and receive packets queued - /// in the device. - /// - /// This function returns a boolean value indicating whether any packets were - /// processed or emitted, and thus, whether the readiness of any socket might - /// have changed. - /// - /// # Errors - /// This method will routinely return errors in response to normal network - /// activity as well as certain boundary conditions such as buffer exhaustion. - /// These errors are provided as an aid for troubleshooting, and are meant - /// to be logged and ignored. - /// - /// As a special case, `Err(Error::Unrecognized)` is returned in response to - /// packets containing any unsupported protocol, option, or form, which is - /// a very common occurrence and on a production system it should not even - /// be logged. - pub fn poll(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result { - let mut readiness_may_have_changed = false; - loop { - let processed_any = self.socket_ingress(sockets, timestamp)?; - let emitted_any = self.socket_egress(sockets, timestamp)?; - - #[cfg(feature = "proto-igmp")] - self.igmp_egress(timestamp)?; - - if processed_any || emitted_any { - readiness_may_have_changed = true; - } else { - break - } - } - Ok(readiness_may_have_changed) - } - - /// Return a _soft deadline_ for calling [poll] the next time. - /// The [Instant] returned is the time at which you should call [poll] next. - /// It is harmless (but wastes energy) to call it before the [Instant], and - /// potentially harmful (impacting quality of service) to call it after the - /// [Instant] - /// - /// [poll]: #method.poll - /// [Instant]: struct.Instant.html - pub fn poll_at(&self, sockets: &SocketSet, timestamp: Instant) -> Option { - sockets.iter().filter_map(|socket| { - let socket_poll_at = socket.poll_at(); - match socket.meta().poll_at(socket_poll_at, |ip_addr| - self.inner.has_neighbor(&ip_addr, timestamp)) { - PollAt::Ingress => None, - PollAt::Time(instant) => Some(instant), - PollAt::Now => Some(Instant::from_millis(0)), - } - }).min() - } - - /// Return an _advisory wait time_ for calling [poll] the next time. - /// The [Duration] returned is the time left to wait before calling [poll] next. - /// It is harmless (but wastes energy) to call it before the [Duration] has passed, - /// and potentially harmful (impacting quality of service) to call it after the - /// [Duration] has passed. - /// - /// [poll]: #method.poll - /// [Duration]: struct.Duration.html - pub fn poll_delay(&self, sockets: &SocketSet, timestamp: Instant) -> Option { - match self.poll_at(sockets, timestamp) { - Some(poll_at) if timestamp < poll_at => { - Some(poll_at - timestamp) - } - Some(_) => { - Some(Duration::from_millis(0)) - } - _ => None - } - } - - fn socket_ingress(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result { - let mut processed_any = false; - loop { - let &mut Self { ref mut device, ref mut inner } = self; - let (rx_token, tx_token) = match device.receive() { - None => break, - Some(tokens) => tokens, - }; - rx_token.consume(timestamp, |frame| { - inner.process_ethernet(sockets, timestamp, &frame).map_err(|err| { - net_debug!("cannot process ingress packet: {}", err); - net_debug!("packet dump follows:\n{}", - PrettyPrinter::>::new("", &frame)); - err - }).and_then(|response| { - processed_any = true; - inner.dispatch(tx_token, timestamp, response).map_err(|err| { - net_debug!("cannot dispatch response packet: {}", err); - err - }) - }) - })?; - } - Ok(processed_any) - } - - fn socket_egress(&mut self, sockets: &mut SocketSet, timestamp: Instant) -> Result { - let mut caps = self.device.capabilities(); - caps.max_transmission_unit -= EthernetFrame::<&[u8]>::header_len(); - - let mut emitted_any = false; - for mut socket in sockets.iter_mut() { - if !socket.meta_mut().egress_permitted(|ip_addr| - self.inner.has_neighbor(&ip_addr, timestamp)) { - continue - } - - let mut neighbor_addr = None; - let mut device_result = Ok(()); - let &mut Self { ref mut device, ref mut inner } = self; - - macro_rules! respond { - ($response:expr) => ({ - let response = $response; - neighbor_addr = response.neighbor_addr(); - let tx_token = device.transmit().ok_or(Error::Exhausted)?; - device_result = inner.dispatch(tx_token, timestamp, response); - device_result - }) - } - - let socket_result = - match *socket { - #[cfg(feature = "socket-raw")] - Socket::Raw(ref mut socket) => - socket.dispatch(&caps.checksum, |response| - respond!(Packet::Raw(response))), - #[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] - Socket::Icmp(ref mut socket) => - socket.dispatch(&caps, |response| { - match response { - #[cfg(feature = "proto-ipv4")] - (IpRepr::Ipv4(ipv4_repr), IcmpRepr::Ipv4(icmpv4_repr)) => - respond!(Packet::Icmpv4((ipv4_repr, icmpv4_repr))), - #[cfg(feature = "proto-ipv6")] - (IpRepr::Ipv6(ipv6_repr), IcmpRepr::Ipv6(icmpv6_repr)) => - respond!(Packet::Icmpv6((ipv6_repr, icmpv6_repr))), - _ => Err(Error::Unaddressable) - } - }), - #[cfg(feature = "socket-udp")] - Socket::Udp(ref mut socket) => - socket.dispatch(|response| - respond!(Packet::Udp(response))), - #[cfg(feature = "socket-tcp")] - Socket::Tcp(ref mut socket) => - socket.dispatch(timestamp, &caps, |response| - respond!(Packet::Tcp(response))), - Socket::__Nonexhaustive(_) => unreachable!() - }; - - match (device_result, socket_result) { - (Err(Error::Exhausted), _) => break, // nowhere to transmit - (Ok(()), Err(Error::Exhausted)) => (), // nothing to transmit - (Err(Error::Unaddressable), _) => { - // `NeighborCache` already takes care of rate limiting the neighbor discovery - // requests from the socket. However, without an additional rate limiting - // mechanism, we would spin on every socket that has yet to discover its - // neighboor. - socket.meta_mut().neighbor_missing(timestamp, - neighbor_addr.expect("non-IP response packet")); - break - } - (Err(err), _) | (_, Err(err)) => { - net_debug!("{}: cannot dispatch egress packet: {}", - socket.meta().handle, err); - return Err(err) - } - (Ok(()), Ok(())) => emitted_any = true - } - } - Ok(emitted_any) - } - - /// Depending on `igmp_report_state` and the therein contained - /// timeouts, send IGMP membership reports. - #[cfg(feature = "proto-igmp")] - fn igmp_egress(&mut self, timestamp: Instant) -> Result { - match self.inner.igmp_report_state { - IgmpReportState::ToSpecificQuery { version, timeout, group } - if timestamp >= timeout => { - if let Some(pkt) = self.inner.igmp_report_packet(version, group) { - // Send initial membership report - let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch(tx_token, timestamp, pkt)?; - } - - self.inner.igmp_report_state = IgmpReportState::Inactive; - Ok(true) - } - IgmpReportState::ToGeneralQuery { version, timeout, interval, next_index } - if timestamp >= timeout => { - let addr = self.inner.ipv4_multicast_groups - .iter() - .nth(next_index) - .map(|(addr, ())| *addr); - - match addr { - Some(addr) => { - if let Some(pkt) = self.inner.igmp_report_packet(version, addr) { - // Send initial membership report - let tx_token = self.device.transmit().ok_or(Error::Exhausted)?; - self.inner.dispatch(tx_token, timestamp, pkt)?; - } - - let next_timeout = (timeout + interval).max(timestamp); - self.inner.igmp_report_state = IgmpReportState::ToGeneralQuery { - version, timeout: next_timeout, interval, next_index: next_index + 1 - }; - Ok(true) - } - - None => { - self.inner.igmp_report_state = IgmpReportState::Inactive; - Ok(false) - } - } - } - _ => Ok(false) - } - } -} - -impl<'b, 'c, 'e> InterfaceInner<'b, 'c, 'e> { - fn check_ethernet_addr(addr: &EthernetAddress) { - if addr.is_multicast() { - panic!("Ethernet address {} is not unicast", addr) - } - } - - fn check_ip_addrs(addrs: &[IpCidr]) { - for cidr in addrs { - if !cidr.address().is_unicast() && !cidr.address().is_unspecified() { - panic!("IP address {} is not unicast", cidr.address()) - } - } - } - - /// Determine if the given `Ipv6Address` is the solicited node - /// multicast address for a IPv6 addresses assigned to the interface. - /// See [RFC 4291 § 2.7.1] for more details. - /// - /// [RFC 4291 § 2.7.1]: https://tools.ietf.org/html/rfc4291#section-2.7.1 - #[cfg(feature = "proto-ipv6")] - pub fn has_solicited_node(&self, addr: Ipv6Address) -> bool { - self.ip_addrs.iter().find(|cidr| { - match *cidr { - &IpCidr::Ipv6(cidr) if cidr.address() != Ipv6Address::LOOPBACK=> { - // Take the lower order 24 bits of the IPv6 address and - // append those bits to FF02:0:0:0:0:1:FF00::/104. - addr.as_bytes()[14..] == cidr.address().as_bytes()[14..] - } - _ => false, - } - }).is_some() - } - - /// Check whether the interface has the given IP address assigned. - fn has_ip_addr>(&self, addr: T) -> bool { - let addr = addr.into(); - self.ip_addrs.iter().any(|probe| probe.address() == addr) - } - - /// Get the first IPv4 address of the interface. - #[cfg(feature = "proto-ipv4")] - pub fn ipv4_address(&self) -> Option { - self.ip_addrs.iter() - .filter_map( - |addr| match addr { - &IpCidr::Ipv4(cidr) => Some(cidr.address()), - _ => None, - }) - .next() - } - - /// Check whether the interface listens to given destination multicast IP address. - /// - /// If built without feature `proto-igmp` this function will - /// always return `false`. - pub fn has_multicast_group>(&self, addr: T) -> bool { - match addr.into() { - #[cfg(feature = "proto-igmp")] - IpAddress::Ipv4(key) => - key == Ipv4Address::MULTICAST_ALL_SYSTEMS || - self.ipv4_multicast_groups.get(&key).is_some(), - _ => - false, - } - } - - fn process_ethernet<'frame, T: AsRef<[u8]>> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, frame: &'frame T) -> - Result> - { - let eth_frame = EthernetFrame::new_checked(frame)?; - - // Ignore any packets not directed to our hardware address or any of the multicast groups. - if !eth_frame.dst_addr().is_broadcast() && - !eth_frame.dst_addr().is_multicast() && - eth_frame.dst_addr() != self.ethernet_addr - { - return Ok(Packet::None) - } - - match eth_frame.ethertype() { - #[cfg(feature = "proto-ipv4")] - EthernetProtocol::Arp => - self.process_arp(timestamp, ð_frame), - #[cfg(feature = "proto-ipv4")] - EthernetProtocol::Ipv4 => - self.process_ipv4(sockets, timestamp, ð_frame), - #[cfg(feature = "proto-ipv6")] - EthernetProtocol::Ipv6 => - self.process_ipv6(sockets, timestamp, ð_frame), - // Drop all other traffic. - _ => Err(Error::Unrecognized), - } - } - - #[cfg(feature = "proto-ipv4")] - fn process_arp<'frame, T: AsRef<[u8]>> - (&mut self, timestamp: Instant, eth_frame: &EthernetFrame<&'frame T>) -> - Result> - { - let arp_packet = ArpPacket::new_checked(eth_frame.payload())?; - let arp_repr = ArpRepr::parse(&arp_packet)?; - - match arp_repr { - // Respond to ARP requests aimed at us, and fill the ARP cache from all ARP - // requests and replies, to minimize the chance that we have to perform - // an explicit ARP request. - ArpRepr::EthernetIpv4 { - operation, source_hardware_addr, source_protocol_addr, target_protocol_addr, .. - } => { - if source_protocol_addr.is_unicast() && source_hardware_addr.is_unicast() { - self.neighbor_cache.fill(source_protocol_addr.into(), - source_hardware_addr, - timestamp); - } else { - // Discard packets with non-unicast source addresses. - net_debug!("non-unicast source address"); - return Err(Error::Malformed) - } - - if operation == ArpOperation::Request && self.has_ip_addr(target_protocol_addr) { - Ok(Packet::Arp(ArpRepr::EthernetIpv4 { - operation: ArpOperation::Reply, - source_hardware_addr: self.ethernet_addr, - source_protocol_addr: target_protocol_addr, - target_hardware_addr: source_hardware_addr, - target_protocol_addr: source_protocol_addr - })) - } else { - Ok(Packet::None) - } - } - - _ => Err(Error::Unrecognized) - } - } - - #[cfg(all(any(feature = "proto-ipv4", feature = "proto-ipv6"), feature = "socket-raw"))] - fn raw_socket_filter<'frame>(&mut self, sockets: &mut SocketSet, ip_repr: &IpRepr, - ip_payload: &'frame [u8]) -> bool { - let checksum_caps = self.device_capabilities.checksum.clone(); - let mut handled_by_raw_socket = false; - - // Pass every IP packet to all raw sockets we have registered. - for mut raw_socket in sockets.iter_mut().filter_map(RawSocket::downcast) { - if !raw_socket.accepts(&ip_repr) { continue } - - match raw_socket.process(&ip_repr, ip_payload, &checksum_caps) { - // The packet is valid and handled by socket. - Ok(()) => handled_by_raw_socket = true, - // The socket buffer is full or the packet was truncated - Err(Error::Exhausted) | Err(Error::Truncated) => (), - // Raw sockets don't validate the packets in any way. - Err(_) => unreachable!(), - } - } - handled_by_raw_socket - } - - #[cfg(feature = "proto-ipv6")] - fn process_ipv6<'frame, T: AsRef<[u8]>> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, - eth_frame: &EthernetFrame<&'frame T>) -> - Result> - { - let ipv6_packet = Ipv6Packet::new_checked(eth_frame.payload())?; - let ipv6_repr = Ipv6Repr::parse(&ipv6_packet)?; - - if !ipv6_repr.src_addr.is_unicast() { - // Discard packets with non-unicast source addresses. - net_debug!("non-unicast source address"); - return Err(Error::Malformed) - } - - if eth_frame.src_addr().is_unicast() { - // Fill the neighbor cache from IP header of unicast frames. - let ip_addr = IpAddress::Ipv6(ipv6_repr.src_addr); - if self.in_same_network(&ip_addr) && - self.neighbor_cache.lookup_pure(&ip_addr, timestamp).is_none() { - self.neighbor_cache.fill(ip_addr, eth_frame.src_addr(), timestamp); - } - } - - let ip_payload = ipv6_packet.payload(); - - #[cfg(feature = "socket-raw")] - let handled_by_raw_socket = self.raw_socket_filter(sockets, &ipv6_repr.into(), ip_payload); - #[cfg(not(feature = "socket-raw"))] - let handled_by_raw_socket = false; - - self.process_nxt_hdr(sockets, timestamp, ipv6_repr, ipv6_repr.next_header, - handled_by_raw_socket, ip_payload) - } - - /// Given the next header value forward the payload onto the correct process - /// function. - #[cfg(feature = "proto-ipv6")] - fn process_nxt_hdr<'frame> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, ipv6_repr: Ipv6Repr, - nxt_hdr: IpProtocol, handled_by_raw_socket: bool, ip_payload: &'frame [u8]) - -> Result> - { - match nxt_hdr { - IpProtocol::Icmpv6 => - self.process_icmpv6(sockets, timestamp, ipv6_repr.into(), ip_payload), - - #[cfg(feature = "socket-udp")] - IpProtocol::Udp => - self.process_udp(sockets, ipv6_repr.into(), handled_by_raw_socket, ip_payload), - - #[cfg(feature = "socket-tcp")] - IpProtocol::Tcp => - self.process_tcp(sockets, timestamp, ipv6_repr.into(), ip_payload), - - IpProtocol::HopByHop => - self.process_hopbyhop(sockets, timestamp, ipv6_repr, handled_by_raw_socket, ip_payload), - - #[cfg(feature = "socket-raw")] - _ if handled_by_raw_socket => - Ok(Packet::None), - - _ => { - // Send back as much of the original payload as we can. - let payload_len = icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, - ipv6_repr.buffer_len()); - let icmp_reply_repr = Icmpv6Repr::ParamProblem { - reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, - // The offending packet is after the IPv6 header. - pointer: ipv6_repr.buffer_len() as u32, - header: ipv6_repr, - data: &ip_payload[0..payload_len] - }; - Ok(self.icmpv6_reply(ipv6_repr, icmp_reply_repr)) - }, - } - } - - #[cfg(feature = "proto-ipv4")] - fn process_ipv4<'frame, T: AsRef<[u8]>> - (&mut self, sockets: &mut SocketSet, timestamp: Instant, - eth_frame: &EthernetFrame<&'frame T>) -> - Result> - { - let ipv4_packet = Ipv4Packet::new_checked(eth_frame.payload())?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &checksum_caps)?; - - if !ipv4_repr.src_addr.is_unicast() { - // Discard packets with non-unicast source addresses. - net_debug!("non-unicast source address"); - return Err(Error::Malformed) - } - - if eth_frame.src_addr().is_unicast() { - // Fill the neighbor cache from IP header of unicast frames. - let ip_addr = IpAddress::Ipv4(ipv4_repr.src_addr); - if self.in_same_network(&ip_addr) { - self.neighbor_cache.fill(ip_addr, eth_frame.src_addr(), timestamp); - } - } - - let ip_repr = IpRepr::Ipv4(ipv4_repr); - let ip_payload = ipv4_packet.payload(); - - #[cfg(feature = "socket-raw")] - let handled_by_raw_socket = self.raw_socket_filter(sockets, &ip_repr, ip_payload); - #[cfg(not(feature = "socket-raw"))] - let handled_by_raw_socket = false; - - if !self.has_ip_addr(ipv4_repr.dst_addr) && - !ipv4_repr.dst_addr.is_broadcast() && - !self.has_multicast_group(ipv4_repr.dst_addr) { - // Ignore IP packets not directed at us, or broadcast, or any of the multicast groups. - // If AnyIP is enabled, also check if the packet is routed locally. - if !self.any_ip { - return Ok(Packet::None); - } else if match self.routes.lookup(&IpAddress::Ipv4(ipv4_repr.dst_addr), timestamp) { - Some(router_addr) => !self.has_ip_addr(router_addr), - None => true, - } { - return Ok(Packet::None); - } - } - - match ipv4_repr.protocol { - IpProtocol::Icmp => - self.process_icmpv4(sockets, ip_repr, ip_payload), - - #[cfg(feature = "proto-igmp")] - IpProtocol::Igmp => - self.process_igmp(timestamp, ipv4_repr, ip_payload), - - #[cfg(feature = "socket-udp")] - IpProtocol::Udp => - self.process_udp(sockets, ip_repr, handled_by_raw_socket, ip_payload), - - #[cfg(feature = "socket-tcp")] - IpProtocol::Tcp => - self.process_tcp(sockets, timestamp, ip_repr, ip_payload), - - _ if handled_by_raw_socket => - Ok(Packet::None), - - _ => { - // Send back as much of the original payload as we can. - let payload_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, - ipv4_repr.buffer_len()); - let icmp_reply_repr = Icmpv4Repr::DstUnreachable { - reason: Icmpv4DstUnreachable::ProtoUnreachable, - header: ipv4_repr, - data: &ip_payload[0..payload_len] - }; - Ok(self.icmpv4_reply(ipv4_repr, icmp_reply_repr)) - } - } - } - - /// Host duties of the **IGMPv2** protocol. - /// - /// Sets up `igmp_report_state` for responding to IGMP general/specific membership queries. - /// Membership must not be reported immediately in order to avoid flooding the network - /// after a query is broadcasted by a router; this is not currently done. - #[cfg(feature = "proto-igmp")] - fn process_igmp<'frame>(&mut self, timestamp: Instant, ipv4_repr: Ipv4Repr, - ip_payload: &'frame [u8]) -> Result> { - let igmp_packet = IgmpPacket::new_checked(ip_payload)?; - let igmp_repr = IgmpRepr::parse(&igmp_packet)?; - - // FIXME: report membership after a delay - match igmp_repr { - IgmpRepr::MembershipQuery { group_addr, version, max_resp_time } => { - // General query - if group_addr.is_unspecified() && - ipv4_repr.dst_addr == Ipv4Address::MULTICAST_ALL_SYSTEMS { - // Are we member in any groups? - if self.ipv4_multicast_groups.iter().next().is_some() { - let interval = match version { - IgmpVersion::Version1 => - Duration::from_millis(100), - IgmpVersion::Version2 => { - // No dependence on a random generator - // (see [#24](https://github.com/m-labs/smoltcp/issues/24)) - // but at least spread reports evenly across max_resp_time. - let intervals = self.ipv4_multicast_groups.len() as u32 + 1; - max_resp_time / intervals - } - }; - self.igmp_report_state = IgmpReportState::ToGeneralQuery { - version, timeout: timestamp + interval, interval, next_index: 0 - }; - } - } else { - // Group-specific query - if self.has_multicast_group(group_addr) && ipv4_repr.dst_addr == group_addr { - // Don't respond immediately - let timeout = max_resp_time / 4; - self.igmp_report_state = IgmpReportState::ToSpecificQuery { - version, timeout: timestamp + timeout, group: group_addr - }; - } - } - }, - // Ignore membership reports - IgmpRepr::MembershipReport { .. } => (), - // Ignore hosts leaving groups - IgmpRepr::LeaveGroup{ .. } => (), - } - - Ok(Packet::None) - } - - #[cfg(feature = "proto-ipv6")] - fn process_icmpv6<'frame>(&mut self, _sockets: &mut SocketSet, timestamp: Instant, - ip_repr: IpRepr, ip_payload: &'frame [u8]) -> Result> - { - let icmp_packet = Icmpv6Packet::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let icmp_repr = Icmpv6Repr::parse(&ip_repr.src_addr(), &ip_repr.dst_addr(), - &icmp_packet, &checksum_caps)?; - - #[cfg(feature = "socket-icmp")] - let mut handled_by_icmp_socket = false; - - #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))] - for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { - if !icmp_socket.accepts(&ip_repr, &icmp_repr.into(), &checksum_caps) { continue } - - match icmp_socket.process(&ip_repr, &icmp_repr.into(), &checksum_caps) { - // The packet is valid and handled by socket. - Ok(()) => handled_by_icmp_socket = true, - // The socket buffer is full. - Err(Error::Exhausted) => (), - // ICMP sockets don't validate the packets in any way. - Err(_) => unreachable!(), - } - } - - match icmp_repr { - // Respond to echo requests. - Icmpv6Repr::EchoRequest { ident, seq_no, data } => { - match ip_repr { - IpRepr::Ipv6(ipv6_repr) => { - let icmp_reply_repr = Icmpv6Repr::EchoReply { - ident: ident, - seq_no: seq_no, - data: data - }; - Ok(self.icmpv6_reply(ipv6_repr, icmp_reply_repr)) - }, - _ => Err(Error::Unrecognized), - } - } - - // Ignore any echo replies. - Icmpv6Repr::EchoReply { .. } => Ok(Packet::None), - - // Forward any NDISC packets to the ndisc packet handler - Icmpv6Repr::Ndisc(repr) if ip_repr.hop_limit() == 0xff => match ip_repr { - IpRepr::Ipv6(ipv6_repr) => self.process_ndisc(timestamp, ipv6_repr, repr), - _ => Ok(Packet::None) - }, - - // Don't report an error if a packet with unknown type - // has been handled by an ICMP socket - #[cfg(feature = "socket-icmp")] - _ if handled_by_icmp_socket => Ok(Packet::None), - - // FIXME: do something correct here? - _ => Err(Error::Unrecognized), - } - } - - #[cfg(feature = "proto-ipv6")] - fn process_ndisc<'frame>(&mut self, timestamp: Instant, ip_repr: Ipv6Repr, - repr: NdiscRepr<'frame>) -> Result> { - let packet = match repr { - NdiscRepr::NeighborAdvert { lladdr, target_addr, flags } => { - let ip_addr = ip_repr.src_addr.into(); - match lladdr { - Some(lladdr) if lladdr.is_unicast() && target_addr.is_unicast() => { - if flags.contains(NdiscNeighborFlags::OVERRIDE) { - self.neighbor_cache.fill(ip_addr, lladdr, timestamp) - } else { - if self.neighbor_cache.lookup_pure(&ip_addr, timestamp).is_none() { - self.neighbor_cache.fill(ip_addr, lladdr, timestamp) - } - } - }, - _ => (), - } - Ok(Packet::None) - } - NdiscRepr::NeighborSolicit { target_addr, lladdr, .. } => { - match lladdr { - Some(lladdr) if lladdr.is_unicast() && target_addr.is_unicast() => { - self.neighbor_cache.fill(ip_repr.src_addr.into(), lladdr, timestamp) - }, - _ => (), - } - if self.has_solicited_node(ip_repr.dst_addr) && self.has_ip_addr(target_addr) { - let advert = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { - flags: NdiscNeighborFlags::SOLICITED, - target_addr: target_addr, - lladdr: Some(self.ethernet_addr) - }); - let ip_repr = Ipv6Repr { - src_addr: target_addr, - dst_addr: ip_repr.src_addr, - next_header: IpProtocol::Icmpv6, - hop_limit: 0xff, - payload_len: advert.buffer_len() - }; - Ok(Packet::Icmpv6((ip_repr, advert))) - } else { - Ok(Packet::None) - } - } - _ => Ok(Packet::None) - }; - packet - } - - #[cfg(feature = "proto-ipv6")] - fn process_hopbyhop<'frame>(&mut self, sockets: &mut SocketSet, timestamp: Instant, - ipv6_repr: Ipv6Repr, handled_by_raw_socket: bool, - ip_payload: &'frame [u8]) -> Result> - { - let hbh_pkt = Ipv6HopByHopHeader::new_checked(ip_payload)?; - let hbh_repr = Ipv6HopByHopRepr::parse(&hbh_pkt)?; - for result in hbh_repr.options() { - let opt_repr = result?; - match opt_repr { - Ipv6OptionRepr::Pad1 | Ipv6OptionRepr::PadN(_) => (), - Ipv6OptionRepr::Unknown { type_, .. } => { - match Ipv6OptionFailureType::from(type_) { - Ipv6OptionFailureType::Skip => (), - Ipv6OptionFailureType::Discard => { - return Ok(Packet::None); - }, - _ => { - // FIXME(dlrobertson): Send an ICMPv6 parameter problem message - // here. - return Err(Error::Unrecognized); - } - } - } - _ => return Err(Error::Unrecognized), - } - } - self.process_nxt_hdr(sockets, timestamp, ipv6_repr, hbh_repr.next_header, - handled_by_raw_socket, &ip_payload[hbh_repr.buffer_len()..]) - } - - #[cfg(feature = "proto-ipv4")] - fn process_icmpv4<'frame>(&self, _sockets: &mut SocketSet, ip_repr: IpRepr, - ip_payload: &'frame [u8]) -> Result> - { - let icmp_packet = Icmpv4Packet::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let icmp_repr = Icmpv4Repr::parse(&icmp_packet, &checksum_caps)?; - - #[cfg(feature = "socket-icmp")] - let mut handled_by_icmp_socket = false; - - #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] - for mut icmp_socket in _sockets.iter_mut().filter_map(IcmpSocket::downcast) { - if !icmp_socket.accepts(&ip_repr, &icmp_repr.into(), &checksum_caps) { continue } - - match icmp_socket.process(&ip_repr, &icmp_repr.into(), &checksum_caps) { - // The packet is valid and handled by socket. - Ok(()) => handled_by_icmp_socket = true, - // The socket buffer is full. - Err(Error::Exhausted) => (), - // ICMP sockets don't validate the packets in any way. - Err(_) => unreachable!(), - } - } - - match icmp_repr { - // Respond to echo requests. - #[cfg(feature = "proto-ipv4")] - Icmpv4Repr::EchoRequest { ident, seq_no, data } => { - let icmp_reply_repr = Icmpv4Repr::EchoReply { - ident: ident, - seq_no: seq_no, - data: data - }; - match ip_repr { - IpRepr::Ipv4(ipv4_repr) => Ok(self.icmpv4_reply(ipv4_repr, icmp_reply_repr)), - _ => Err(Error::Unrecognized), - } - }, - - // Ignore any echo replies. - Icmpv4Repr::EchoReply { .. } => Ok(Packet::None), - - // Don't report an error if a packet with unknown type - // has been handled by an ICMP socket - #[cfg(feature = "socket-icmp")] - _ if handled_by_icmp_socket => Ok(Packet::None), - - // FIXME: do something correct here? - _ => Err(Error::Unrecognized), - } - } - - #[cfg(feature = "proto-ipv4")] - fn icmpv4_reply<'frame, 'icmp: 'frame> - (&self, ipv4_repr: Ipv4Repr, icmp_repr: Icmpv4Repr<'icmp>) -> - Packet<'frame> - { - if !ipv4_repr.src_addr.is_unicast() { - // Do not send ICMP replies to non-unicast sources - Packet::None - } else if ipv4_repr.dst_addr.is_unicast() { - // Reply as normal when src_addr and dst_addr are both unicast - let ipv4_reply_repr = Ipv4Repr { - src_addr: ipv4_repr.dst_addr, - dst_addr: ipv4_repr.src_addr, - protocol: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 64 - }; - Packet::Icmpv4((ipv4_reply_repr, icmp_repr)) - } else if ipv4_repr.dst_addr.is_broadcast() { - // Only reply to broadcasts for echo replies and not other ICMP messages - match icmp_repr { - Icmpv4Repr::EchoReply {..} => match self.ipv4_address() { - Some(src_addr) => { - let ipv4_reply_repr = Ipv4Repr { - src_addr: src_addr, - dst_addr: ipv4_repr.src_addr, - protocol: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 64 - }; - Packet::Icmpv4((ipv4_reply_repr, icmp_repr)) - }, - None => Packet::None, - }, - _ => Packet::None, - } - } else { - Packet::None - } - } - - #[cfg(feature = "proto-ipv6")] - fn icmpv6_reply<'frame, 'icmp: 'frame> - (&self, ipv6_repr: Ipv6Repr, icmp_repr: Icmpv6Repr<'icmp>) -> - Packet<'frame> - { - if ipv6_repr.dst_addr.is_unicast() { - let ipv6_reply_repr = Ipv6Repr { - src_addr: ipv6_repr.dst_addr, - dst_addr: ipv6_repr.src_addr, - next_header: IpProtocol::Icmpv6, - payload_len: icmp_repr.buffer_len(), - hop_limit: 64 - }; - Packet::Icmpv6((ipv6_reply_repr, icmp_repr)) - } else { - // Do not send any ICMP replies to a broadcast destination address. - Packet::None - } - } - - #[cfg(feature = "socket-udp")] - fn process_udp<'frame>(&self, sockets: &mut SocketSet, - ip_repr: IpRepr, handled_by_raw_socket: bool, ip_payload: &'frame [u8]) -> - Result> - { - let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); - let udp_packet = UdpPacket::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let udp_repr = UdpRepr::parse(&udp_packet, &src_addr, &dst_addr, &checksum_caps)?; - - for mut udp_socket in sockets.iter_mut().filter_map(UdpSocket::downcast) { - if !udp_socket.accepts(&ip_repr, &udp_repr) { continue } - - match udp_socket.process(&ip_repr, &udp_repr) { - // The packet is valid and handled by socket. - Ok(()) => return Ok(Packet::None), - // The packet is malformed, or the socket buffer is full. - Err(e) => return Err(e) - } - } - - // The packet wasn't handled by a socket, send an ICMP port unreachable packet. - match ip_repr { - #[cfg(feature = "proto-ipv4")] - IpRepr::Ipv4(_) if handled_by_raw_socket => - Ok(Packet::None), - #[cfg(feature = "proto-ipv6")] - IpRepr::Ipv6(_) if handled_by_raw_socket => - Ok(Packet::None), - #[cfg(feature = "proto-ipv4")] - IpRepr::Ipv4(ipv4_repr) => { - let payload_len = icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, - ipv4_repr.buffer_len()); - let icmpv4_reply_repr = Icmpv4Repr::DstUnreachable { - reason: Icmpv4DstUnreachable::PortUnreachable, - header: ipv4_repr, - data: &ip_payload[0..payload_len] - }; - Ok(self.icmpv4_reply(ipv4_repr, icmpv4_reply_repr)) - }, - #[cfg(feature = "proto-ipv6")] - IpRepr::Ipv6(ipv6_repr) => { - let payload_len = icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, - ipv6_repr.buffer_len()); - let icmpv6_reply_repr = Icmpv6Repr::DstUnreachable { - reason: Icmpv6DstUnreachable::PortUnreachable, - header: ipv6_repr, - data: &ip_payload[0..payload_len] - }; - Ok(self.icmpv6_reply(ipv6_repr, icmpv6_reply_repr)) - }, - IpRepr::Unspecified { .. } | - IpRepr::__Nonexhaustive => Err(Error::Unaddressable), - } - } - - #[cfg(feature = "socket-tcp")] - fn process_tcp<'frame>(&self, sockets: &mut SocketSet, timestamp: Instant, - ip_repr: IpRepr, ip_payload: &'frame [u8]) -> - Result> - { - let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); - let tcp_packet = TcpPacket::new_checked(ip_payload)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - let tcp_repr = TcpRepr::parse(&tcp_packet, &src_addr, &dst_addr, &checksum_caps)?; - - for mut tcp_socket in sockets.iter_mut().filter_map(TcpSocket::downcast) { - if !tcp_socket.accepts(&ip_repr, &tcp_repr) { continue } - - match tcp_socket.process(timestamp, &ip_repr, &tcp_repr) { - // The packet is valid and handled by socket. - Ok(reply) => return Ok(reply.map_or(Packet::None, Packet::Tcp)), - // The packet is malformed, or doesn't match the socket state, - // or the socket buffer is full. - Err(e) => return Err(e) - } - } - - if tcp_repr.control == TcpControl::Rst { - // Never reply to a TCP RST packet with another TCP RST packet. - Ok(Packet::None) - } else { - // The packet wasn't handled by a socket, send a TCP RST packet. - Ok(Packet::Tcp(TcpSocket::rst_reply(&ip_repr, &tcp_repr))) - } - } - - fn dispatch(&mut self, tx_token: Tx, timestamp: Instant, - packet: Packet) -> Result<()> - where Tx: TxToken - { - let checksum_caps = self.device_capabilities.checksum.clone(); - match packet { - #[cfg(feature = "proto-ipv4")] - Packet::Arp(arp_repr) => { - let dst_hardware_addr = - match arp_repr { - ArpRepr::EthernetIpv4 { target_hardware_addr, .. } => target_hardware_addr, - _ => unreachable!() - }; - - self.dispatch_ethernet(tx_token, timestamp, arp_repr.buffer_len(), |mut frame| { - frame.set_dst_addr(dst_hardware_addr); - frame.set_ethertype(EthernetProtocol::Arp); - - let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); - arp_repr.emit(&mut packet); - }) - }, - #[cfg(feature = "proto-ipv4")] - Packet::Icmpv4((ipv4_repr, icmpv4_repr)) => { - self.dispatch_ip(tx_token, timestamp, IpRepr::Ipv4(ipv4_repr), - |_ip_repr, payload| { - icmpv4_repr.emit(&mut Icmpv4Packet::new_unchecked(payload), &checksum_caps); - }) - } - #[cfg(feature = "proto-igmp")] - Packet::Igmp((ipv4_repr, igmp_repr)) => { - self.dispatch_ip(tx_token, timestamp, IpRepr::Ipv4(ipv4_repr), |_ip_repr, payload| { - igmp_repr.emit(&mut IgmpPacket::new_unchecked(payload)); - }) - } - #[cfg(feature = "proto-ipv6")] - Packet::Icmpv6((ipv6_repr, icmpv6_repr)) => { - self.dispatch_ip(tx_token, timestamp, IpRepr::Ipv6(ipv6_repr), - |ip_repr, payload| { - icmpv6_repr.emit(&ip_repr.src_addr(), &ip_repr.dst_addr(), - &mut Icmpv6Packet::new_unchecked(payload), &checksum_caps); - }) - } - #[cfg(feature = "socket-raw")] - Packet::Raw((ip_repr, raw_packet)) => { - self.dispatch_ip(tx_token, timestamp, ip_repr, |_ip_repr, payload| { - payload.copy_from_slice(raw_packet); - }) - } - #[cfg(feature = "socket-udp")] - Packet::Udp((ip_repr, udp_repr)) => { - self.dispatch_ip(tx_token, timestamp, ip_repr, |ip_repr, payload| { - udp_repr.emit(&mut UdpPacket::new_unchecked(payload), - &ip_repr.src_addr(), &ip_repr.dst_addr(), - &checksum_caps); - }) - } - #[cfg(feature = "socket-tcp")] - Packet::Tcp((ip_repr, mut tcp_repr)) => { - let caps = self.device_capabilities.clone(); - self.dispatch_ip(tx_token, timestamp, ip_repr, |ip_repr, payload| { - // This is a terrible hack to make TCP performance more acceptable on systems - // where the TCP buffers are significantly larger than network buffers, - // e.g. a 64 kB TCP receive buffer (and so, when empty, a 64k window) - // together with four 1500 B Ethernet receive buffers. If left untreated, - // this would result in our peer pushing our window and sever packet loss. - // - // I'm really not happy about this "solution" but I don't know what else to do. - if let Some(max_burst_size) = caps.max_burst_size { - let mut max_segment_size = caps.max_transmission_unit; - max_segment_size -= EthernetFrame::<&[u8]>::header_len(); - max_segment_size -= ip_repr.buffer_len(); - max_segment_size -= tcp_repr.header_len(); - - let max_window_size = max_burst_size * max_segment_size; - if tcp_repr.window_len as usize > max_window_size { - tcp_repr.window_len = max_window_size as u16; - } - } - - tcp_repr.emit(&mut TcpPacket::new_unchecked(payload), - &ip_repr.src_addr(), &ip_repr.dst_addr(), - &checksum_caps); - }) - } - Packet::None => Ok(()) - } - } - - fn dispatch_ethernet(&mut self, tx_token: Tx, timestamp: Instant, - buffer_len: usize, f: F) -> Result<()> - where Tx: TxToken, F: FnOnce(EthernetFrame<&mut [u8]>) - { - let tx_len = EthernetFrame::<&[u8]>::buffer_len(buffer_len); - tx_token.consume(timestamp, tx_len, |tx_buffer| { - debug_assert!(tx_buffer.as_ref().len() == tx_len); - let mut frame = EthernetFrame::new_unchecked(tx_buffer.as_mut()); - frame.set_src_addr(self.ethernet_addr); - - f(frame); - - Ok(()) - }) - } - - fn in_same_network(&self, addr: &IpAddress) -> bool { - self.ip_addrs - .iter() - .find(|cidr| cidr.contains_addr(addr)) - .is_some() - } - - fn route(&self, addr: &IpAddress, timestamp: Instant) -> Result { - // Send directly. - if self.in_same_network(addr) || addr.is_broadcast() { - return Ok(*addr) - } - - // Route via a router. - match self.routes.lookup(addr, timestamp) { - Some(router_addr) => Ok(router_addr), - None => Err(Error::Unaddressable), - } - } - - fn has_neighbor<'a>(&self, addr: &'a IpAddress, timestamp: Instant) -> bool { - match self.route(addr, timestamp) { - Ok(routed_addr) => { - self.neighbor_cache - .lookup_pure(&routed_addr, timestamp) - .is_some() - } - Err(_) => false - } - } - - fn lookup_hardware_addr(&mut self, tx_token: Tx, timestamp: Instant, - src_addr: &IpAddress, dst_addr: &IpAddress) -> - Result<(EthernetAddress, Tx)> - where Tx: TxToken - { - if dst_addr.is_multicast() { - let b = dst_addr.as_bytes(); - let hardware_addr = - match dst_addr { - &IpAddress::Unspecified => - None, - #[cfg(feature = "proto-ipv4")] - &IpAddress::Ipv4(_addr) => - Some(EthernetAddress::from_bytes(&[ - 0x01, 0x00, - 0x5e, b[1] & 0x7F, - b[2], b[3], - ])), - #[cfg(feature = "proto-ipv6")] - &IpAddress::Ipv6(_addr) => - Some(EthernetAddress::from_bytes(&[ - 0x33, 0x33, - b[12], b[13], - b[14], b[15], - ])), - &IpAddress::__Nonexhaustive => - unreachable!() - }; - match hardware_addr { - Some(hardware_addr) => - // Destination is multicast - return Ok((hardware_addr, tx_token)), - None => - // Continue - (), - } - } - - let dst_addr = self.route(dst_addr, timestamp)?; - - match self.neighbor_cache.lookup(&dst_addr, timestamp) { - NeighborAnswer::Found(hardware_addr) => - return Ok((hardware_addr, tx_token)), - NeighborAnswer::RateLimited => - return Err(Error::Unaddressable), - NeighborAnswer::NotFound => (), - } - - match (src_addr, dst_addr) { - #[cfg(feature = "proto-ipv4")] - (&IpAddress::Ipv4(src_addr), IpAddress::Ipv4(dst_addr)) => { - net_debug!("address {} not in neighbor cache, sending ARP request", - dst_addr); - - let arp_repr = ArpRepr::EthernetIpv4 { - operation: ArpOperation::Request, - source_hardware_addr: self.ethernet_addr, - source_protocol_addr: src_addr, - target_hardware_addr: EthernetAddress::BROADCAST, - target_protocol_addr: dst_addr, - }; - - self.dispatch_ethernet(tx_token, timestamp, arp_repr.buffer_len(), |mut frame| { - frame.set_dst_addr(EthernetAddress::BROADCAST); - frame.set_ethertype(EthernetProtocol::Arp); - - arp_repr.emit(&mut ArpPacket::new_unchecked(frame.payload_mut())) - })?; - - Err(Error::Unaddressable) - } - - #[cfg(feature = "proto-ipv6")] - (&IpAddress::Ipv6(src_addr), IpAddress::Ipv6(dst_addr)) => { - net_debug!("address {} not in neighbor cache, sending Neighbor Solicitation", - dst_addr); - - let checksum_caps = self.device_capabilities.checksum.clone(); - - let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { - target_addr: src_addr, - lladdr: Some(self.ethernet_addr), - }); - - let ip_repr = IpRepr::Ipv6(Ipv6Repr { - src_addr: src_addr, - dst_addr: dst_addr.solicited_node(), - next_header: IpProtocol::Icmpv6, - payload_len: solicit.buffer_len(), - hop_limit: 0xff - }); - - self.dispatch_ip(tx_token, timestamp, ip_repr, |ip_repr, payload| { - solicit.emit(&ip_repr.src_addr(), &ip_repr.dst_addr(), - &mut Icmpv6Packet::new_unchecked(payload), &checksum_caps); - })?; - - Err(Error::Unaddressable) - } - - _ => Err(Error::Unaddressable) - } - } - - fn dispatch_ip(&mut self, tx_token: Tx, timestamp: Instant, - ip_repr: IpRepr, f: F) -> Result<()> - where Tx: TxToken, F: FnOnce(IpRepr, &mut [u8]) - { - let ip_repr = ip_repr.lower(&self.ip_addrs)?; - let checksum_caps = self.device_capabilities.checksum.clone(); - - let (dst_hardware_addr, tx_token) = - self.lookup_hardware_addr(tx_token, timestamp, - &ip_repr.src_addr(), &ip_repr.dst_addr())?; - - self.dispatch_ethernet(tx_token, timestamp, ip_repr.total_len(), |mut frame| { - frame.set_dst_addr(dst_hardware_addr); - match ip_repr { - #[cfg(feature = "proto-ipv4")] - IpRepr::Ipv4(_) => frame.set_ethertype(EthernetProtocol::Ipv4), - #[cfg(feature = "proto-ipv6")] - IpRepr::Ipv6(_) => frame.set_ethertype(EthernetProtocol::Ipv6), - _ => return - } - - ip_repr.emit(frame.payload_mut(), &checksum_caps); - - let payload = &mut frame.payload_mut()[ip_repr.buffer_len()..]; - f(ip_repr, payload) - }) - } - - #[cfg(feature = "proto-igmp")] - fn igmp_report_packet<'any>(&self, version: IgmpVersion, group_addr: Ipv4Address) -> Option> { - let iface_addr = self.ipv4_address()?; - let igmp_repr = IgmpRepr::MembershipReport { - group_addr, - version, - }; - let pkt = Packet::Igmp((Ipv4Repr { - src_addr: iface_addr, - // Send to the group being reported - dst_addr: group_addr, - protocol: IpProtocol::Igmp, - payload_len: igmp_repr.buffer_len(), - hop_limit: 1, - // TODO: add Router Alert IPv4 header option. See - // [#183](https://github.com/m-labs/smoltcp/issues/183). - }, igmp_repr)); - Some(pkt) - } - - #[cfg(feature = "proto-igmp")] - fn igmp_leave_packet<'any>(&self, group_addr: Ipv4Address) -> Option> { - self.ipv4_address().map(|iface_addr| { - let igmp_repr = IgmpRepr::LeaveGroup { group_addr }; - let pkt = Packet::Igmp((Ipv4Repr { - src_addr: iface_addr, - dst_addr: Ipv4Address::MULTICAST_ALL_ROUTERS, - protocol: IpProtocol::Igmp, - payload_len: igmp_repr.buffer_len(), - hop_limit: 1, - }, igmp_repr)); - pkt - }) - } -} - -#[cfg(test)] -mod test { - #[cfg(feature = "proto-igmp")] - use std::vec::Vec; - use std::collections::BTreeMap; - use {Result, Error}; - - use super::InterfaceBuilder; - use iface::{NeighborCache, EthernetInterface}; - use phy::{self, Loopback, ChecksumCapabilities}; - #[cfg(feature = "proto-igmp")] - use phy::{Device, RxToken, TxToken}; - use time::Instant; - use socket::SocketSet; - #[cfg(feature = "proto-ipv4")] - use wire::{ArpOperation, ArpPacket, ArpRepr}; - use wire::{EthernetAddress, EthernetFrame, EthernetProtocol}; - use wire::{IpAddress, IpCidr, IpProtocol, IpRepr}; - #[cfg(feature = "proto-ipv4")] - use wire::{Ipv4Address, Ipv4Repr}; - #[cfg(feature = "proto-igmp")] - use wire::Ipv4Packet; - #[cfg(feature = "proto-ipv4")] - use wire::{Icmpv4Repr, Icmpv4DstUnreachable}; - #[cfg(feature = "proto-igmp")] - use wire::{IgmpPacket, IgmpRepr, IgmpVersion}; - #[cfg(all(feature = "socket-udp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] - use wire::{UdpPacket, UdpRepr}; - #[cfg(feature = "proto-ipv6")] - use wire::{Ipv6Address, Ipv6Repr}; - #[cfg(feature = "proto-ipv6")] - use wire::{Icmpv6Packet, Icmpv6Repr, Icmpv6ParamProblem}; - #[cfg(feature = "proto-ipv6")] - use wire::{NdiscNeighborFlags, NdiscRepr}; - #[cfg(feature = "proto-ipv6")] - use wire::{Ipv6HopByHopHeader, Ipv6Option, Ipv6OptionRepr}; - - use super::Packet; - - fn create_loopback<'a, 'b, 'c>() -> (EthernetInterface<'static, 'b, 'c, Loopback>, - SocketSet<'static, 'a, 'b>) { - // Create a basic device - let device = Loopback::new(); - let ip_addrs = [ - #[cfg(feature = "proto-ipv4")] - IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8), - #[cfg(feature = "proto-ipv6")] - IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 128), - #[cfg(feature = "proto-ipv6")] - IpCidr::new(IpAddress::v6(0xfdbe, 0, 0, 0, 0, 0, 0, 1), 64), - ]; - - let iface_builder = InterfaceBuilder::new(device) - .ethernet_addr(EthernetAddress::default()) - .neighbor_cache(NeighborCache::new(BTreeMap::new())) - .ip_addrs(ip_addrs); - #[cfg(feature = "proto-igmp")] - let iface_builder = iface_builder - .ipv4_multicast_groups(BTreeMap::new()); - let iface = iface_builder - .finalize(); - - (iface, SocketSet::new(vec![])) - } - - #[cfg(feature = "proto-igmp")] - fn recv_all<'b>(iface: &mut EthernetInterface<'static, 'b, 'static, Loopback>, timestamp: Instant) -> Vec> { - let mut pkts = Vec::new(); - while let Some((rx, _tx)) = iface.device.receive() { - rx.consume(timestamp, |pkt| { - pkts.push(pkt.iter().cloned().collect()); - Ok(()) - }).unwrap(); - } - pkts - } - - #[derive(Debug, PartialEq)] - struct MockTxToken; - - impl phy::TxToken for MockTxToken { - fn consume(self, _: Instant, _: usize, _: F) -> Result - where F: FnOnce(&mut [u8]) -> Result { - Err(Error::__Nonexhaustive) - } - } - - #[test] - #[should_panic(expected = "a required option was not set")] - fn test_builder_initialization_panic() { - InterfaceBuilder::new(Loopback::new()).finalize(); - } - - #[test] - fn test_no_icmp_no_unicast() { - let (mut iface, mut socket_set) = create_loopback(); - - let mut eth_bytes = vec![0u8; 54]; - - // Unknown Ipv4 Protocol - // - // Because the destination is the broadcast address - // this should not trigger and Destination Unreachable - // response. See RFC 1122 § 3.2.2. - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - dst_addr: Ipv4Address::BROADCAST, - protocol: IpProtocol::Unknown(0x0c), - payload_len: 0, - hop_limit: 0x40 - }); - #[cfg(feature = "proto-ipv6")] - let repr = IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), - dst_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, - next_header: IpProtocol::Unknown(0x0c), - payload_len: 0, - hop_limit: 0x40 - }); - - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - frame.set_dst_addr(EthernetAddress::BROADCAST); - frame.set_src_addr(EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00])); - frame.set_ethertype(EthernetProtocol::Ipv4); - repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - // Ensure that the unknown protocol frame does not trigger an - // ICMP error response when the destination address is a - // broadcast address - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), - Ok(Packet::None)); - #[cfg(feature = "proto-ipv6")] - assert_eq!(iface.inner.process_ipv6(&mut socket_set, Instant::from_millis(0), &frame), - Ok(Packet::None)); - } - - #[test] - #[cfg(feature = "proto-ipv4")] - fn test_icmp_error_no_payload() { - static NO_BYTES: [u8; 0] = []; - let (mut iface, mut socket_set) = create_loopback(); - - let mut eth_bytes = vec![0u8; 34]; - - // Unknown Ipv4 Protocol with no payload - let repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - protocol: IpProtocol::Unknown(0x0c), - payload_len: 0, - hop_limit: 0x40 - }); - - // emit the above repr to a frame - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - frame.set_dst_addr(EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00])); - frame.set_src_addr(EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00])); - frame.set_ethertype(EthernetProtocol::Ipv4); - repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - // The expected Destination Unreachable response due to the - // unknown protocol - let icmp_repr = Icmpv4Repr::DstUnreachable { - reason: Icmpv4DstUnreachable::ProtoUnreachable, - header: Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - protocol: IpProtocol::Unknown(12), - payload_len: 0, - hop_limit: 64 - }, - data: &NO_BYTES - }; - - let expected_repr = Packet::Icmpv4(( - Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - protocol: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 64 - }, - icmp_repr - )); - - // Ensure that the unknown protocol triggers an error response. - // And we correctly handle no payload. - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), - Ok(expected_repr)); - } - - #[test] - #[cfg(all(feature = "socket-udp", feature = "proto-ipv4"))] - fn test_icmp_error_port_unreachable() { - static UDP_PAYLOAD: [u8; 12] = [ - 0x48, 0x65, 0x6c, 0x6c, - 0x6f, 0x2c, 0x20, 0x57, - 0x6f, 0x6c, 0x64, 0x21 - ]; - let (iface, mut socket_set) = create_loopback(); - - let mut udp_bytes_unicast = vec![0u8; 20]; - let mut udp_bytes_broadcast = vec![0u8; 20]; - let mut packet_unicast = UdpPacket::new_unchecked(&mut udp_bytes_unicast); - let mut packet_broadcast = UdpPacket::new_unchecked(&mut udp_bytes_broadcast); - - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &UDP_PAYLOAD - }; - - let ip_repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - protocol: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 64 - }); - - // Emit the representations to a packet - udp_repr.emit(&mut packet_unicast, &ip_repr.src_addr(), - &ip_repr.dst_addr(), &ChecksumCapabilities::default()); - - let data = packet_unicast.into_inner(); - - // The expected Destination Unreachable ICMPv4 error response due - // to no sockets listening on the destination port. - let icmp_repr = Icmpv4Repr::DstUnreachable { - reason: Icmpv4DstUnreachable::PortUnreachable, - header: Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - protocol: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 64 - }, - data: &data - }; - let expected_repr = Packet::Icmpv4(( - Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), - dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - protocol: IpProtocol::Icmp, - payload_len: icmp_repr.buffer_len(), - hop_limit: 64 - }, - icmp_repr - )); - - // Ensure that the unknown protocol triggers an error response. - // And we correctly handle no payload. - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, false, data), - Ok(expected_repr)); - - let ip_repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), - dst_addr: Ipv4Address::BROADCAST, - protocol: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 64 - }); - - // Emit the representations to a packet - udp_repr.emit(&mut packet_broadcast, &ip_repr.src_addr(), - &IpAddress::Ipv4(Ipv4Address::BROADCAST), - &ChecksumCapabilities::default()); - - // Ensure that the port unreachable error does not trigger an - // ICMP error response when the destination address is a - // broadcast address and no socket is bound to the port. - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, - false, packet_broadcast.into_inner()), Ok(Packet::None)); - } - - #[test] - #[cfg(feature = "socket-udp")] - fn test_handle_udp_broadcast() { - use socket::{UdpSocket, UdpSocketBuffer, UdpPacketMetadata}; - use wire::IpEndpoint; - - static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; - - let (iface, mut socket_set) = create_loopback(); - - let rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]); - let tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]); - - let udp_socket = UdpSocket::new(rx_buffer, tx_buffer); - - let mut udp_bytes = vec![0u8; 13]; - let mut packet = UdpPacket::new_unchecked(&mut udp_bytes); - - let socket_handle = socket_set.add(udp_socket); - - #[cfg(feature = "proto-ipv6")] - let src_ip = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); - #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] - let src_ip = Ipv4Address::new(0x7f, 0x00, 0x00, 0x02); - - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &UDP_PAYLOAD - }; - - #[cfg(feature = "proto-ipv6")] - let ip_repr = IpRepr::Ipv6(Ipv6Repr { - src_addr: src_ip, - dst_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, - next_header: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 0x40 - }); - #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] - let ip_repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: src_ip, - dst_addr: Ipv4Address::BROADCAST, - protocol: IpProtocol::Udp, - payload_len: udp_repr.buffer_len(), - hop_limit: 0x40 - }); - - { - // Bind the socket to port 68 - let mut socket = socket_set.get::(socket_handle); - assert_eq!(socket.bind(68), Ok(())); - assert!(!socket.can_recv()); - assert!(socket.can_send()); - } - - udp_repr.emit(&mut packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), - &ChecksumCapabilities::default()); - - // Packet should be handled by bound UDP socket - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr, false, packet.into_inner()), - Ok(Packet::None)); - - { - // Make sure the payload to the UDP packet processed by process_udp is - // appended to the bound sockets rx_buffer - let mut socket = socket_set.get::(socket_handle); - assert!(socket.can_recv()); - assert_eq!(socket.recv(), Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_ip.into(), 67)))); - } - } - - #[test] - #[cfg(feature = "proto-ipv4")] - fn test_handle_ipv4_broadcast() { - use wire::{Ipv4Packet, Icmpv4Repr, Icmpv4Packet}; - - let (mut iface, mut socket_set) = create_loopback(); - - let our_ipv4_addr = iface.ipv4_address().unwrap(); - let src_ipv4_addr = Ipv4Address([127, 0, 0, 2]); - - // ICMPv4 echo request - let icmpv4_data: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; - let icmpv4_repr = Icmpv4Repr::EchoRequest { - ident: 0x1234, seq_no: 0xabcd, data: &icmpv4_data - }; - - // Send to IPv4 broadcast address - let ipv4_repr = Ipv4Repr { - src_addr: src_ipv4_addr, - dst_addr: Ipv4Address::BROADCAST, - protocol: IpProtocol::Icmp, - hop_limit: 64, - payload_len: icmpv4_repr.buffer_len(), - }; - - // Emit to ethernet frame - let mut eth_bytes = vec![0u8; - EthernetFrame::<&[u8]>::header_len() + - ipv4_repr.buffer_len() + icmpv4_repr.buffer_len() - ]; - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - ipv4_repr.emit( - &mut Ipv4Packet::new_unchecked(frame.payload_mut()), - &ChecksumCapabilities::default()); - icmpv4_repr.emit( - &mut Icmpv4Packet::new_unchecked( - &mut frame.payload_mut()[ipv4_repr.buffer_len()..]), - &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - // Expected ICMPv4 echo reply - let expected_icmpv4_repr = Icmpv4Repr::EchoReply { - ident: 0x1234, seq_no: 0xabcd, data: &icmpv4_data }; - let expected_ipv4_repr = Ipv4Repr { - src_addr: our_ipv4_addr, - dst_addr: src_ipv4_addr, - protocol: IpProtocol::Icmp, - hop_limit: 64, - payload_len: expected_icmpv4_repr.buffer_len(), - }; - let expected_packet = Packet::Icmpv4((expected_ipv4_repr, expected_icmpv4_repr)); - - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), - Ok(expected_packet)); - } - - #[test] - #[cfg(feature = "socket-udp")] - fn test_icmp_reply_size() { - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - use wire::IPV4_MIN_MTU as MIN_MTU; - #[cfg(feature = "proto-ipv6")] - use wire::Icmpv6DstUnreachable; - #[cfg(feature = "proto-ipv6")] - use wire::IPV6_MIN_MTU as MIN_MTU; - - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - const MAX_PAYLOAD_LEN: usize = 528; - #[cfg(feature = "proto-ipv6")] - const MAX_PAYLOAD_LEN: usize = 1192; - - let (iface, mut socket_set) = create_loopback(); - - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let src_addr = Ipv4Address([192, 168, 1, 1]); - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let dst_addr = Ipv4Address([192, 168, 1, 2]); - #[cfg(feature = "proto-ipv6")] - let src_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); - #[cfg(feature = "proto-ipv6")] - let dst_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2); - - // UDP packet that if not tructated will cause a icmp port unreachable reply - // to exeed the minimum mtu bytes in length. - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &[0x2a; MAX_PAYLOAD_LEN] - }; - let mut bytes = vec![0xff; udp_repr.buffer_len()]; - let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); - udp_repr.emit(&mut packet, &src_addr.into(), &dst_addr.into(), &ChecksumCapabilities::default()); - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let ip_repr = Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: IpProtocol::Udp, - hop_limit: 64, - payload_len: udp_repr.buffer_len() - }; - #[cfg(feature = "proto-ipv6")] - let ip_repr = Ipv6Repr { - src_addr: src_addr, - dst_addr: dst_addr, - next_header: IpProtocol::Udp, - hop_limit: 64, - payload_len: udp_repr.buffer_len() - }; - let payload = packet.into_inner(); - - // Expected packets - #[cfg(feature = "proto-ipv6")] - let expected_icmp_repr = Icmpv6Repr::DstUnreachable { - reason: Icmpv6DstUnreachable::PortUnreachable, - header: ip_repr, - data: &payload[..MAX_PAYLOAD_LEN] - }; - #[cfg(feature = "proto-ipv6")] - let expected_ip_repr = Ipv6Repr { - src_addr: dst_addr, - dst_addr: src_addr, - next_header: IpProtocol::Icmpv6, - hop_limit: 64, - payload_len: expected_icmp_repr.buffer_len() - }; - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let expected_icmp_repr = Icmpv4Repr::DstUnreachable { - reason: Icmpv4DstUnreachable::PortUnreachable, - header: ip_repr, - data: &payload[..MAX_PAYLOAD_LEN] - }; - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - let expected_ip_repr = Ipv4Repr { - src_addr: dst_addr, - dst_addr: src_addr, - protocol: IpProtocol::Icmp, - hop_limit: 64, - payload_len: expected_icmp_repr.buffer_len() - }; - - // The expected packet does not exceed the IPV4_MIN_MTU - assert_eq!(expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), MIN_MTU); - // The expected packet and the generated packet are equal - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr.into(), false, payload), - Ok(Packet::Icmpv4((expected_ip_repr, expected_icmp_repr)))); - #[cfg(feature = "proto-ipv6")] - assert_eq!(iface.inner.process_udp(&mut socket_set, ip_repr.into(), false, payload), - Ok(Packet::Icmpv6((expected_ip_repr, expected_icmp_repr)))); - } - - #[test] - #[cfg(feature = "proto-ipv4")] - fn test_handle_valid_arp_request() { - let (mut iface, mut socket_set) = create_loopback(); - - let mut eth_bytes = vec![0u8; 42]; - - let local_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); - let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); - let local_hw_addr = EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); - let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); - - let repr = ArpRepr::EthernetIpv4 { - operation: ArpOperation::Request, - source_hardware_addr: remote_hw_addr, - source_protocol_addr: remote_ip_addr, - target_hardware_addr: EthernetAddress::default(), - target_protocol_addr: local_ip_addr, - }; - - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - frame.set_dst_addr(EthernetAddress::BROADCAST); - frame.set_src_addr(remote_hw_addr); - frame.set_ethertype(EthernetProtocol::Arp); - { - let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); - repr.emit(&mut packet); - } - - // Ensure an ARP Request for us triggers an ARP Reply - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), - Ok(Packet::Arp(ArpRepr::EthernetIpv4 { - operation: ArpOperation::Reply, - source_hardware_addr: local_hw_addr, - source_protocol_addr: local_ip_addr, - target_hardware_addr: remote_hw_addr, - target_protocol_addr: remote_ip_addr - }))); - - // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), - &IpAddress::Ipv4(local_ip_addr), &IpAddress::Ipv4(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); - } - - #[test] - #[cfg(feature = "proto-ipv6")] - fn test_handle_valid_ndisc_request() { - let (mut iface, mut socket_set) = create_loopback(); - - let mut eth_bytes = vec![0u8; 86]; - - let local_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1); - let remote_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2); - let local_hw_addr = EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); - let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); - - let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { - target_addr: local_ip_addr, - lladdr: Some(remote_hw_addr), - }); - let ip_repr = IpRepr::Ipv6(Ipv6Repr { - src_addr: remote_ip_addr, - dst_addr: local_ip_addr.solicited_node(), - next_header: IpProtocol::Icmpv6, - hop_limit: 0xff, - payload_len: solicit.buffer_len() - }); - - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00])); - frame.set_src_addr(remote_hw_addr); - frame.set_ethertype(EthernetProtocol::Ipv6); - { - ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); - solicit.emit(&remote_ip_addr.into(), &local_ip_addr.solicited_node().into(), - &mut Icmpv6Packet::new_unchecked( - &mut frame.payload_mut()[ip_repr.buffer_len()..]), - &ChecksumCapabilities::default()); - } - - let icmpv6_expected = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { - flags: NdiscNeighborFlags::SOLICITED, - target_addr: local_ip_addr, - lladdr: Some(local_hw_addr) - }); - - let ipv6_expected = Ipv6Repr { - src_addr: local_ip_addr, - dst_addr: remote_ip_addr, - next_header: IpProtocol::Icmpv6, - hop_limit: 0xff, - payload_len: icmpv6_expected.buffer_len() - }; - - // Ensure an Neighbor Solicitation triggers a Neighbor Advertisement - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), - Ok(Packet::Icmpv6((ipv6_expected, icmpv6_expected)))); - - // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), - &IpAddress::Ipv6(local_ip_addr), &IpAddress::Ipv6(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); - } - - #[test] - #[cfg(feature = "proto-ipv4")] - fn test_handle_other_arp_request() { - let (mut iface, mut socket_set) = create_loopback(); - - let mut eth_bytes = vec![0u8; 42]; - - let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); - let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); - - let repr = ArpRepr::EthernetIpv4 { - operation: ArpOperation::Request, - source_hardware_addr: remote_hw_addr, - source_protocol_addr: remote_ip_addr, - target_hardware_addr: EthernetAddress::default(), - target_protocol_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x03]), - }; - - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - frame.set_dst_addr(EthernetAddress::BROADCAST); - frame.set_src_addr(remote_hw_addr); - frame.set_ethertype(EthernetProtocol::Arp); - { - let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); - repr.emit(&mut packet); - } - - // Ensure an ARP Request for someone else does not trigger an ARP Reply - assert_eq!(iface.inner.process_ethernet(&mut socket_set, Instant::from_millis(0), frame.into_inner()), - Ok(Packet::None)); - - // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), - &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), - &IpAddress::Ipv4(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); - } - - #[test] - #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] - fn test_icmpv4_socket() { - use socket::{IcmpSocket, IcmpEndpoint, IcmpSocketBuffer, IcmpPacketMetadata}; - use wire::Icmpv4Packet; - - let (iface, mut socket_set) = create_loopback(); - - let rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 24]); - let tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 24]); - - let icmpv4_socket = IcmpSocket::new(rx_buffer, tx_buffer); - - let socket_handle = socket_set.add(icmpv4_socket); - - let ident = 0x1234; - let seq_no = 0x5432; - let echo_data = &[0xff; 16]; - - { - let mut socket = socket_set.get::(socket_handle); - // Bind to the ID 0x1234 - assert_eq!(socket.bind(IcmpEndpoint::Ident(ident)), Ok(())); - } - - // Ensure the ident we bound to and the ident of the packet are the same. - let mut bytes = [0xff; 24]; - let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); - let echo_repr = Icmpv4Repr::EchoRequest{ ident, seq_no, data: echo_data }; - echo_repr.emit(&mut packet, &ChecksumCapabilities::default()); - let icmp_data = &packet.into_inner()[..]; - - let ipv4_repr = Ipv4Repr { - src_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x02), - dst_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x01), - protocol: IpProtocol::Icmp, - payload_len: 24, - hop_limit: 64 - }; - let ip_repr = IpRepr::Ipv4(ipv4_repr); - - // Open a socket and ensure the packet is handled due to the listening - // socket. - { - assert!(!socket_set.get::(socket_handle).can_recv()); - } - - // Confirm we still get EchoReply from `smoltcp` even with the ICMP socket listening - let echo_reply = Icmpv4Repr::EchoReply{ ident, seq_no, data: echo_data }; - let ipv4_reply = Ipv4Repr { - src_addr: ipv4_repr.dst_addr, - dst_addr: ipv4_repr.src_addr, - ..ipv4_repr - }; - assert_eq!(iface.inner.process_icmpv4(&mut socket_set, ip_repr, icmp_data), - Ok(Packet::Icmpv4((ipv4_reply, echo_reply)))); - - { - let mut socket = socket_set.get::(socket_handle); - assert!(socket.can_recv()); - assert_eq!(socket.recv(), - Ok((&icmp_data[..], - IpAddress::Ipv4(Ipv4Address::new(0x7f, 0x00, 0x00, 0x02))))); - } - } - - #[test] - #[cfg(feature = "proto-ipv6")] - fn test_solicited_node_addrs() { - let (mut iface, _) = create_loopback(); - let mut new_addrs = vec![IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 1, 2, 0, 2), 64), - IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 3, 4, 0, 0xffff), 64)]; - iface.update_ip_addrs(|addrs| { - new_addrs.extend(addrs.to_vec()); - *addrs = From::from(new_addrs); - }); - assert!(iface.inner.has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0002))); - assert!(iface.inner.has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0xffff))); - assert!(!iface.inner.has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0003))); - } - - #[test] - #[cfg(feature = "proto-ipv6")] - fn test_icmpv6_nxthdr_unknown() { - let (mut iface, mut socket_set) = create_loopback(); - - let remote_ip_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); - let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x01]); - - let mut eth_bytes = vec![0; 66]; - let payload = [0x12, 0x34, 0x56, 0x78]; - - let ipv6_repr = Ipv6Repr { - src_addr: remote_ip_addr, - dst_addr: Ipv6Address::LOOPBACK, - next_header: IpProtocol::HopByHop, - payload_len: 12, - hop_limit: 0x40, - }; - - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - let ip_repr = IpRepr::Ipv6(ipv6_repr); - frame.set_dst_addr(EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00])); - frame.set_src_addr(remote_hw_addr); - frame.set_ethertype(EthernetProtocol::Ipv6); - ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); - let mut offset = ipv6_repr.buffer_len(); - { - let mut hbh_pkt = - Ipv6HopByHopHeader::new_unchecked(&mut frame.payload_mut()[offset..]); - hbh_pkt.set_next_header(IpProtocol::Unknown(0x0c)); - hbh_pkt.set_header_len(0); - offset += 8; - { - let mut pad_pkt = Ipv6Option::new_unchecked(&mut hbh_pkt.options_mut()[..]); - Ipv6OptionRepr::PadN(3).emit(&mut pad_pkt); - } - { - let mut pad_pkt = Ipv6Option::new_unchecked(&mut hbh_pkt.options_mut()[5..]); - Ipv6OptionRepr::Pad1.emit(&mut pad_pkt); - } - } - frame.payload_mut()[offset..].copy_from_slice(&payload); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - let reply_icmp_repr = Icmpv6Repr::ParamProblem { - reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, - pointer: 40, - header: ipv6_repr, - data: &payload[..] - }; - - let reply_ipv6_repr = Ipv6Repr { - src_addr: Ipv6Address::LOOPBACK, - dst_addr: remote_ip_addr, - next_header: IpProtocol::Icmpv6, - payload_len: reply_icmp_repr.buffer_len(), - hop_limit: 0x40, - }; - - // Ensure the unknown next header causes a ICMPv6 Parameter Problem - // error message to be sent to the sender. - assert_eq!(iface.inner.process_ipv6(&mut socket_set, Instant::from_millis(0), &frame), - Ok(Packet::Icmpv6((reply_ipv6_repr, reply_icmp_repr)))); - - // Ensure the address of the requestor was entered in the cache - assert_eq!(iface.inner.lookup_hardware_addr(MockTxToken, Instant::from_secs(0), - &IpAddress::Ipv6(Ipv6Address::LOOPBACK), - &IpAddress::Ipv6(remote_ip_addr)), - Ok((remote_hw_addr, MockTxToken))); - } - - #[test] - #[cfg(feature = "proto-igmp")] - fn test_handle_igmp() { - fn recv_igmp<'b>(mut iface: &mut EthernetInterface<'static, 'b, 'static, Loopback>, timestamp: Instant) -> Vec<(Ipv4Repr, IgmpRepr)> { - let checksum_caps = &iface.device.capabilities().checksum; - recv_all(&mut iface, timestamp) - .iter() - .filter_map(|frame| { - let eth_frame = EthernetFrame::new_checked(frame).ok()?; - let ipv4_packet = Ipv4Packet::new_checked(eth_frame.payload()).ok()?; - let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, &checksum_caps).ok()?; - let ip_payload = ipv4_packet.payload(); - let igmp_packet = IgmpPacket::new_checked(ip_payload).ok()?; - let igmp_repr = IgmpRepr::parse(&igmp_packet).ok()?; - Some((ipv4_repr, igmp_repr)) - }) - .collect::>() - } - - let groups = [ - Ipv4Address::new(224, 0, 0, 22), - Ipv4Address::new(224, 0, 0, 56), - ]; - - let (mut iface, mut socket_set) = create_loopback(); - - // Join multicast groups - let timestamp = Instant::now(); - for group in &groups { - iface.join_multicast_group(*group, timestamp) - .unwrap(); - } - - let reports = recv_igmp(&mut iface, timestamp); - assert_eq!(reports.len(), 2); - for (i, group_addr) in groups.iter().enumerate() { - assert_eq!(reports[i].0.protocol, IpProtocol::Igmp); - assert_eq!(reports[i].0.dst_addr, *group_addr); - assert_eq!(reports[i].1, IgmpRepr::MembershipReport { - group_addr: *group_addr, - version: IgmpVersion::Version2, - }); - } - - // General query - let timestamp = Instant::now(); - const GENERAL_QUERY_BYTES: &[u8] = &[ - 0x01, 0x00, 0x5e, 0x00, 0x00, 0x01, 0x0a, 0x14, - 0x48, 0x01, 0x21, 0x01, 0x08, 0x00, 0x46, 0xc0, - 0x00, 0x24, 0xed, 0xb4, 0x00, 0x00, 0x01, 0x02, - 0x47, 0x43, 0xac, 0x16, 0x63, 0x04, 0xe0, 0x00, - 0x00, 0x01, 0x94, 0x04, 0x00, 0x00, 0x11, 0x64, - 0xec, 0x8f, 0x00, 0x00, 0x00, 0x00, 0x02, 0x0c, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 - ]; - { - // Transmit GENERAL_QUERY_BYTES into loopback - let tx_token = iface.device.transmit().unwrap(); - tx_token.consume( - timestamp, GENERAL_QUERY_BYTES.len(), - |buffer| { - buffer.copy_from_slice(GENERAL_QUERY_BYTES); - Ok(()) - }).unwrap(); - } - // Trigger processing until all packets received through the - // loopback have been processed, including responses to - // GENERAL_QUERY_BYTES. Therefore `recv_all()` would return 0 - // pkts that could be checked. - iface.socket_ingress(&mut socket_set, timestamp).unwrap(); - - // Leave multicast groups - let timestamp = Instant::now(); - for group in &groups { - iface.leave_multicast_group(group.clone(), timestamp) - .unwrap(); - } - - let leaves = recv_igmp(&mut iface, timestamp); - assert_eq!(leaves.len(), 2); - for (i, group_addr) in groups.iter().cloned().enumerate() { - assert_eq!(leaves[i].0.protocol, IpProtocol::Igmp); - assert_eq!(leaves[i].0.dst_addr, Ipv4Address::MULTICAST_ALL_ROUTERS); - assert_eq!(leaves[i].1, IgmpRepr::LeaveGroup { group_addr }); - } - } - - #[test] - #[cfg(all(feature = "proto-ipv4", feature = "socket-raw"))] - fn test_raw_socket_no_reply() { - use socket::{RawSocket, RawSocketBuffer, RawPacketMetadata}; - use wire::{IpVersion, Ipv4Packet, UdpPacket, UdpRepr}; - - let (mut iface, mut socket_set) = create_loopback(); - - let packets = 1; - let rx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * 1]); - let tx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * packets]); - let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); - socket_set.add(raw_socket); - - let src_addr = Ipv4Address([127, 0, 0, 2]); - let dst_addr = Ipv4Address([127, 0, 0, 1]); - - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &[0x2a; 10] - }; - let mut bytes = vec![0xff; udp_repr.buffer_len()]; - let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); - udp_repr.emit(&mut packet, &src_addr.into(), &dst_addr.into(), &ChecksumCapabilities::default()); - let ipv4_repr = Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: IpProtocol::Udp, - hop_limit: 64, - payload_len: udp_repr.buffer_len() - }; - - // Emit to ethernet frame - let mut eth_bytes = vec![0u8; - EthernetFrame::<&[u8]>::header_len() + - ipv4_repr.buffer_len() + udp_repr.buffer_len() - ]; - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - ipv4_repr.emit( - &mut Ipv4Packet::new_unchecked(frame.payload_mut()), - &ChecksumCapabilities::default()); - udp_repr.emit( - &mut UdpPacket::new_unchecked( - &mut frame.payload_mut()[ipv4_repr.buffer_len()..]), - &src_addr.into(), - &dst_addr.into(), - &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), - Ok(Packet::None)); - } - - #[test] - #[cfg(all(feature = "proto-ipv4", feature = "socket-raw"))] - fn test_raw_socket_truncated_packet() { - use socket::{RawSocket, RawSocketBuffer, RawPacketMetadata}; - use wire::{IpVersion, Ipv4Packet, UdpPacket, UdpRepr}; - - let (mut iface, mut socket_set) = create_loopback(); - - let packets = 1; - let rx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * 1]); - let tx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * packets]); - let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); - socket_set.add(raw_socket); - - let src_addr = Ipv4Address([127, 0, 0, 2]); - let dst_addr = Ipv4Address([127, 0, 0, 1]); - - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &[0x2a; 49] // 49 > 48, hence packet will be truncated - }; - let mut bytes = vec![0xff; udp_repr.buffer_len()]; - let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); - udp_repr.emit(&mut packet, &src_addr.into(), &dst_addr.into(), &ChecksumCapabilities::default()); - let ipv4_repr = Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: IpProtocol::Udp, - hop_limit: 64, - payload_len: udp_repr.buffer_len() - }; - - // Emit to ethernet frame - let mut eth_bytes = vec![0u8; - EthernetFrame::<&[u8]>::header_len() + - ipv4_repr.buffer_len() + udp_repr.buffer_len() - ]; - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - ipv4_repr.emit( - &mut Ipv4Packet::new_unchecked(frame.payload_mut()), - &ChecksumCapabilities::default()); - udp_repr.emit( - &mut UdpPacket::new_unchecked( - &mut frame.payload_mut()[ipv4_repr.buffer_len()..]), - &src_addr.into(), - &dst_addr.into(), - &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - let frame = iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame); - - // because the packet could not be handled we should send an Icmp message - assert!(match frame { - Ok(Packet::Icmpv4(_)) => true, - _ => false, - }); - } - - #[test] - #[cfg(all(feature = "proto-ipv4", feature = "socket-raw", feature = "socket-udp"))] - fn test_raw_socket_with_udp_socket() { - use socket::{UdpSocket, UdpSocketBuffer, UdpPacketMetadata, - RawSocket, RawSocketBuffer, RawPacketMetadata}; - use wire::{IpVersion, IpEndpoint, Ipv4Packet, UdpPacket, UdpRepr}; - - static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; - - let (mut iface, mut socket_set) = create_loopback(); - - let udp_rx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]); - let udp_tx_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY], vec![0; 15]); - let udp_socket = UdpSocket::new(udp_rx_buffer, udp_tx_buffer); - let udp_socket_handle = socket_set.add(udp_socket); - { - // Bind the socket to port 68 - let mut socket = socket_set.get::(udp_socket_handle); - assert_eq!(socket.bind(68), Ok(())); - assert!(!socket.can_recv()); - assert!(socket.can_send()); - } - - let packets = 1; - let raw_rx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * 1]); - let raw_tx_buffer = RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * packets]); - let raw_socket = RawSocket::new(IpVersion::Ipv4, IpProtocol::Udp, raw_rx_buffer, raw_tx_buffer); - socket_set.add(raw_socket); - - let src_addr = Ipv4Address([127, 0, 0, 2]); - let dst_addr = Ipv4Address([127, 0, 0, 1]); - - let udp_repr = UdpRepr { - src_port: 67, - dst_port: 68, - payload: &UDP_PAYLOAD - }; - let mut bytes = vec![0xff; udp_repr.buffer_len()]; - let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); - udp_repr.emit(&mut packet, &src_addr.into(), &dst_addr.into(), &ChecksumCapabilities::default()); - let ipv4_repr = Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: IpProtocol::Udp, - hop_limit: 64, - payload_len: udp_repr.buffer_len() - }; - - // Emit to ethernet frame - let mut eth_bytes = vec![0u8; - EthernetFrame::<&[u8]>::header_len() + - ipv4_repr.buffer_len() + udp_repr.buffer_len() - ]; - let frame = { - let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); - ipv4_repr.emit( - &mut Ipv4Packet::new_unchecked(frame.payload_mut()), - &ChecksumCapabilities::default()); - udp_repr.emit( - &mut UdpPacket::new_unchecked( - &mut frame.payload_mut()[ipv4_repr.buffer_len()..]), - &src_addr.into(), - &dst_addr.into(), - &ChecksumCapabilities::default()); - EthernetFrame::new_unchecked(&*frame.into_inner()) - }; - - assert_eq!(iface.inner.process_ipv4(&mut socket_set, Instant::from_millis(0), &frame), - Ok(Packet::None)); - - { - // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP - let mut socket = socket_set.get::(udp_socket_handle); - assert!(socket.can_recv()); - assert_eq!(socket.recv(), Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_addr.into(), 67)))); - } - } -} diff --git a/src/iface/fragmentation.rs b/src/iface/fragmentation.rs new file mode 100644 index 000000000..9870ab07e --- /dev/null +++ b/src/iface/fragmentation.rs @@ -0,0 +1,345 @@ +#![allow(unused)] + +use core::fmt; + +use managed::{ManagedMap, ManagedSlice}; + +use crate::config::{REASSEMBLY_BUFFER_COUNT, REASSEMBLY_BUFFER_SIZE}; +use crate::storage::Assembler; +use crate::time::{Duration, Instant}; + +#[cfg(feature = "alloc")] +type Buffer = alloc::vec::Vec; +#[cfg(not(feature = "alloc"))] +type Buffer = [u8; REASSEMBLY_BUFFER_SIZE]; + +/// Problem when assembling: something was out of bounds. +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AssemblerError; + +impl fmt::Display for AssemblerError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AssemblerError") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for AssemblerError {} + +/// Packet assembler is full +#[derive(Copy, Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AssemblerFullError; + +impl fmt::Display for AssemblerFullError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "AssemblerFullError") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for AssemblerFullError {} + +/// Holds different fragments of one packet, used for assembling fragmented packets. +/// +/// The buffer used for the `PacketAssembler` should either be dynamically sized (ex: Vec) +/// or should be statically allocated based upon the MTU of the type of packet being +/// assembled (ex: 1280 for a IPv6 frame). +#[derive(Debug)] +pub struct PacketAssembler { + key: Option, + buffer: Buffer, + + assembler: Assembler, + total_size: Option, + expires_at: Instant, +} + +impl PacketAssembler { + /// Create a new empty buffer for fragments. + pub const fn new() -> Self { + Self { + key: None, + + #[cfg(feature = "alloc")] + buffer: Buffer::new(), + #[cfg(not(feature = "alloc"))] + buffer: [0u8; REASSEMBLY_BUFFER_SIZE], + + assembler: Assembler::new(), + total_size: None, + expires_at: Instant::ZERO, + } + } + + pub(crate) fn reset(&mut self) { + self.key = None; + self.assembler.clear(); + self.total_size = None; + self.expires_at = Instant::ZERO; + } + + /// Set the total size of the packet assembler. + pub(crate) fn set_total_size(&mut self, size: usize) -> Result<(), AssemblerError> { + if let Some(old_size) = self.total_size { + if old_size != size { + return Err(AssemblerError); + } + } + + #[cfg(not(feature = "alloc"))] + if self.buffer.len() < size { + return Err(AssemblerError); + } + + #[cfg(feature = "alloc")] + if self.buffer.len() < size { + self.buffer.resize(size, 0); + } + + self.total_size = Some(size); + Ok(()) + } + + /// Return the instant when the assembler expires. + pub(crate) fn expires_at(&self) -> Instant { + self.expires_at + } + + pub(crate) fn add_with( + &mut self, + offset: usize, + f: impl Fn(&mut [u8]) -> Result, + ) -> Result<(), AssemblerError> { + if self.buffer.len() < offset { + return Err(AssemblerError); + } + + let len = f(&mut self.buffer[offset..])?; + assert!(offset + len <= self.buffer.len()); + + net_debug!( + "frag assembler: receiving {} octets at offset {}", + len, + offset + ); + + self.assembler.add(offset, len); + Ok(()) + } + + /// Add a fragment into the packet that is being reassembled. + /// + /// # Errors + /// + /// - Returns [`Error::PacketAssemblerBufferTooSmall`] when trying to add data into the buffer at a non-existing + /// place. + pub(crate) fn add(&mut self, data: &[u8], offset: usize) -> Result<(), AssemblerError> { + #[cfg(not(feature = "alloc"))] + if self.buffer.len() < offset + data.len() { + return Err(AssemblerError); + } + + #[cfg(feature = "alloc")] + if self.buffer.len() < offset + data.len() { + self.buffer.resize(offset + data.len(), 0); + } + + let len = data.len(); + self.buffer[offset..][..len].copy_from_slice(data); + + net_debug!( + "frag assembler: receiving {} octets at offset {}", + len, + offset + ); + + self.assembler.add(offset, data.len()); + Ok(()) + } + + /// Get an immutable slice of the underlying packet data, if reassembly complete. + /// This will mark the assembler as empty, so that it can be reused. + pub(crate) fn assemble(&mut self) -> Option<&'_ [u8]> { + if !self.is_complete() { + return None; + } + + // NOTE: we can unwrap because `is_complete` already checks this. + let total_size = self.total_size.unwrap(); + self.reset(); + Some(&self.buffer[..total_size]) + } + + /// Returns `true` when all fragments have been received, otherwise `false`. + pub(crate) fn is_complete(&self) -> bool { + self.total_size == Some(self.assembler.peek_front()) + } + + /// Returns `true` when the packet assembler is free to use. + fn is_free(&self) -> bool { + self.key.is_none() + } +} + +/// Set holding multiple [`PacketAssembler`]. +#[derive(Debug)] +pub struct PacketAssemblerSet { + assemblers: [PacketAssembler; REASSEMBLY_BUFFER_COUNT], +} + +impl PacketAssemblerSet { + const NEW_PA: PacketAssembler = PacketAssembler::new(); + + /// Create a new set of packet assemblers. + pub fn new() -> Self { + Self { + assemblers: [Self::NEW_PA; REASSEMBLY_BUFFER_COUNT], + } + } + + /// Get a [`PacketAssembler`] for a specific key. + /// + /// If it doesn't exist, it is created, with the `expires_at` timestamp. + /// + /// If the assembler set is full, in which case an error is returned. + pub(crate) fn get( + &mut self, + key: &K, + expires_at: Instant, + ) -> Result<&mut PacketAssembler, AssemblerFullError> { + let mut empty_slot = None; + for slot in &mut self.assemblers { + if slot.key.as_ref() == Some(key) { + return Ok(slot); + } + if slot.is_free() { + empty_slot = Some(slot) + } + } + + let slot = empty_slot.ok_or(AssemblerFullError)?; + slot.key = Some(*key); + slot.expires_at = expires_at; + Ok(slot) + } + + /// Remove all [`PacketAssembler`]s that are expired. + pub fn remove_expired(&mut self, timestamp: Instant) { + for frag in &mut self.assemblers { + if !frag.is_free() && frag.expires_at < timestamp { + frag.reset(); + } + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[derive(PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] + struct Key { + id: usize, + } + + #[test] + fn packet_assembler_overlap() { + let mut p_assembler = PacketAssembler::::new(); + + p_assembler.set_total_size(5).unwrap(); + + let data = b"Rust"; + p_assembler.add(&data[..], 0); + p_assembler.add(&data[..], 1); + + assert_eq!(p_assembler.assemble(), Some(&b"RRust"[..])) + } + + #[test] + fn packet_assembler_assemble() { + let mut p_assembler = PacketAssembler::::new(); + + let data = b"Hello World!"; + + p_assembler.set_total_size(data.len()).unwrap(); + + p_assembler.add(b"Hello ", 0).unwrap(); + assert_eq!(p_assembler.assemble(), None); + + p_assembler.add(b"World!", b"Hello ".len()).unwrap(); + + assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); + } + + #[test] + fn packet_assembler_out_of_order_assemble() { + let mut p_assembler = PacketAssembler::::new(); + + let data = b"Hello World!"; + + p_assembler.set_total_size(data.len()).unwrap(); + + p_assembler.add(b"World!", b"Hello ".len()).unwrap(); + assert_eq!(p_assembler.assemble(), None); + + p_assembler.add(b"Hello ", 0).unwrap(); + + assert_eq!(p_assembler.assemble(), Some(&b"Hello World!"[..])); + } + + #[test] + fn packet_assembler_set() { + let key = Key { id: 1 }; + + let mut set = PacketAssemblerSet::new(); + + assert!(set.get(&key, Instant::ZERO).is_ok()); + } + + #[test] + fn packet_assembler_set_full() { + let mut set = PacketAssemblerSet::new(); + for i in 0..REASSEMBLY_BUFFER_COUNT { + set.get(&Key { id: i }, Instant::ZERO).unwrap(); + } + assert!(set.get(&Key { id: 4 }, Instant::ZERO).is_err()); + } + + #[test] + fn packet_assembler_set_assembling_many() { + let mut set = PacketAssemblerSet::new(); + + let key = Key { id: 0 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assert_eq!(assr.assemble(), None); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + // Test that `.assemble()` effectively deletes it. + let assr = set.get(&key, Instant::ZERO).unwrap(); + assert_eq!(assr.assemble(), None); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 1 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 2 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(0).unwrap(); + assr.assemble().unwrap(); + + let key = Key { id: 2 }; + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.set_total_size(2).unwrap(); + assr.add(&[0x00], 0).unwrap(); + assert_eq!(assr.assemble(), None); + let assr = set.get(&key, Instant::ZERO).unwrap(); + assr.add(&[0x01], 1).unwrap(); + assert_eq!(assr.assemble(), Some(&[0x00, 0x01][..])); + } +} diff --git a/src/iface/interface/ethernet.rs b/src/iface/interface/ethernet.rs new file mode 100644 index 000000000..c8ec3566d --- /dev/null +++ b/src/iface/interface/ethernet.rs @@ -0,0 +1,76 @@ +use super::check; +use super::DispatchError; +use super::EthernetPacket; +use super::FragmentsBuffer; +use super::InterfaceInner; +use super::SocketSet; +use core::result::Result; + +use crate::phy::TxToken; +use crate::wire::*; + +impl InterfaceInner { + #[cfg(feature = "medium-ethernet")] + pub(super) fn process_ethernet<'frame, T: AsRef<[u8]>>( + &mut self, + sockets: &mut SocketSet, + meta: crate::phy::PacketMeta, + frame: &'frame T, + fragments: &'frame mut FragmentsBuffer, + ) -> Option> { + let eth_frame = check!(EthernetFrame::new_checked(frame)); + + // Ignore any packets not directed to our hardware address or any of the multicast groups. + if !eth_frame.dst_addr().is_broadcast() + && !eth_frame.dst_addr().is_multicast() + && HardwareAddress::Ethernet(eth_frame.dst_addr()) != self.hardware_addr + { + return None; + } + + match eth_frame.ethertype() { + #[cfg(feature = "proto-ipv4")] + EthernetProtocol::Arp => self.process_arp(self.now, ð_frame), + #[cfg(feature = "proto-ipv4")] + EthernetProtocol::Ipv4 => { + let ipv4_packet = check!(Ipv4Packet::new_checked(eth_frame.payload())); + + self.process_ipv4(sockets, meta, &ipv4_packet, fragments) + .map(EthernetPacket::Ip) + } + #[cfg(feature = "proto-ipv6")] + EthernetProtocol::Ipv6 => { + let ipv6_packet = check!(Ipv6Packet::new_checked(eth_frame.payload())); + self.process_ipv6(sockets, meta, &ipv6_packet) + .map(EthernetPacket::Ip) + } + // Drop all other traffic. + _ => None, + } + } + + #[cfg(feature = "medium-ethernet")] + pub(super) fn dispatch_ethernet( + &mut self, + tx_token: Tx, + buffer_len: usize, + f: F, + ) -> Result<(), DispatchError> + where + Tx: TxToken, + F: FnOnce(EthernetFrame<&mut [u8]>), + { + let tx_len = EthernetFrame::<&[u8]>::buffer_len(buffer_len); + tx_token.consume(tx_len, |tx_buffer| { + debug_assert!(tx_buffer.as_ref().len() == tx_len); + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + + f(frame); + + Ok(()) + }) + } +} diff --git a/src/iface/interface/ieee802154.rs b/src/iface/interface/ieee802154.rs new file mode 100644 index 000000000..3e3fdf076 --- /dev/null +++ b/src/iface/interface/ieee802154.rs @@ -0,0 +1,94 @@ +use super::*; + +use crate::phy::TxToken; +use crate::wire::*; + +impl InterfaceInner { + pub(super) fn process_ieee802154<'output, 'payload: 'output, T: AsRef<[u8]> + ?Sized>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + sixlowpan_payload: &'payload T, + _fragments: &'output mut FragmentsBuffer, + ) -> Option> { + let ieee802154_frame = check!(Ieee802154Frame::new_checked(sixlowpan_payload)); + let ieee802154_repr = check!(Ieee802154Repr::parse(&ieee802154_frame)); + + if ieee802154_repr.frame_type != Ieee802154FrameType::Data { + return None; + } + + // Drop frames when the user has set a PAN id and the PAN id from frame is not equal to this + // When the user didn't set a PAN id (so it is None), then we accept all PAN id's. + // We always accept the broadcast PAN id. + if self.pan_id.is_some() + && ieee802154_repr.dst_pan_id != self.pan_id + && ieee802154_repr.dst_pan_id != Some(Ieee802154Pan::BROADCAST) + { + net_debug!( + "IEEE802.15.4: dropping {:?} because not our PAN id (or not broadcast)", + ieee802154_repr + ); + return None; + } + + match ieee802154_frame.payload() { + Some(payload) => { + self.process_sixlowpan(sockets, meta, &ieee802154_repr, payload, _fragments) + } + None => None, + } + } + + pub(super) fn dispatch_ieee802154( + &mut self, + ll_dst_a: Ieee802154Address, + tx_token: Tx, + meta: PacketMeta, + packet: IpPacket, + frag: &mut Fragmenter, + ) { + let ll_src_a = self.hardware_addr.ieee802154_or_panic(); + + // Create the IEEE802.15.4 header. + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(self.get_sequence_number()), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: self.pan_id, + dst_addr: Some(ll_dst_a), + src_pan_id: self.pan_id, + src_addr: Some(ll_src_a), + }; + + self.dispatch_sixlowpan(tx_token, meta, packet, ieee_repr, frag); + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + pub(super) fn dispatch_ieee802154_frag( + &mut self, + tx_token: Tx, + frag: &mut Fragmenter, + ) { + // Create the IEEE802.15.4 header. + let ieee_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(self.get_sequence_number()), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: self.pan_id, + dst_addr: Some(frag.sixlowpan.ll_dst_addr), + src_pan_id: self.pan_id, + src_addr: Some(frag.sixlowpan.ll_src_addr), + }; + + self.dispatch_sixlowpan_frag(tx_token, ieee_repr, frag); + } +} diff --git a/src/iface/interface/igmp.rs b/src/iface/interface/igmp.rs new file mode 100644 index 000000000..9bf6ad946 --- /dev/null +++ b/src/iface/interface/igmp.rs @@ -0,0 +1,289 @@ +use super::{check, IgmpReportState, Interface, InterfaceInner, IpPacket}; +use crate::phy::{Device, PacketMeta}; +use crate::time::{Duration, Instant}; +use crate::wire::*; + +use core::result::Result; + +/// Error type for `join_multicast_group`, `leave_multicast_group`. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum MulticastError { + /// The hardware device transmit buffer is full. Try again later. + Exhausted, + /// The table of joined multicast groups is already full. + GroupTableFull, + /// IPv6 multicast is not yet supported. + Ipv6NotSupported, +} + +impl core::fmt::Display for MulticastError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + MulticastError::Exhausted => write!(f, "Exhausted"), + MulticastError::GroupTableFull => write!(f, "GroupTableFull"), + MulticastError::Ipv6NotSupported => write!(f, "Ipv6NotSupported"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for MulticastError {} + +impl Interface { + /// Add an address to a list of subscribed multicast IP addresses. + /// + /// Returns `Ok(announce_sent)` if the address was added successfully, where `annouce_sent` + /// indicates whether an initial immediate announcement has been sent. + pub fn join_multicast_group>( + &mut self, + device: &mut D, + addr: T, + timestamp: Instant, + ) -> Result + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + match addr.into() { + IpAddress::Ipv4(addr) => { + let is_not_new = self + .inner + .ipv4_multicast_groups + .insert(addr, ()) + .map_err(|_| MulticastError::GroupTableFull)? + .is_some(); + if is_not_new { + Ok(false) + } else if let Some(pkt) = self.inner.igmp_report_packet(IgmpVersion::Version2, addr) + { + // Send initial membership report + let tx_token = device + .transmit(timestamp) + .ok_or(MulticastError::Exhausted)?; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + + Ok(true) + } else { + Ok(false) + } + } + // Multicast is not yet implemented for other address families + #[allow(unreachable_patterns)] + _ => Err(MulticastError::Ipv6NotSupported), + } + } + + /// Remove an address from the subscribed multicast IP addresses. + /// + /// Returns `Ok(leave_sent)` if the address was removed successfully, where `leave_sent` + /// indicates whether an immediate leave packet has been sent. + pub fn leave_multicast_group>( + &mut self, + device: &mut D, + addr: T, + timestamp: Instant, + ) -> Result + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + match addr.into() { + IpAddress::Ipv4(addr) => { + let was_not_present = self.inner.ipv4_multicast_groups.remove(&addr).is_none(); + if was_not_present { + Ok(false) + } else if let Some(pkt) = self.inner.igmp_leave_packet(addr) { + // Send group leave packet + let tx_token = device + .transmit(timestamp) + .ok_or(MulticastError::Exhausted)?; + + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + + Ok(true) + } else { + Ok(false) + } + } + // Multicast is not yet implemented for other address families + #[allow(unreachable_patterns)] + _ => Err(MulticastError::Ipv6NotSupported), + } + } + + /// Check whether the interface listens to given destination multicast IP address. + pub fn has_multicast_group>(&self, addr: T) -> bool { + self.inner.has_multicast_group(addr) + } + + /// Depending on `igmp_report_state` and the therein contained + /// timeouts, send IGMP membership reports. + pub(crate) fn igmp_egress(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + match self.inner.igmp_report_state { + IgmpReportState::ToSpecificQuery { + version, + timeout, + group, + } if self.inner.now >= timeout => { + if let Some(pkt) = self.inner.igmp_report_packet(version, group) { + // Send initial membership report + if let Some(tx_token) = device.transmit(self.inner.now) { + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip(tx_token, PacketMeta::default(), pkt, &mut self.fragmenter) + .unwrap(); + } else { + return false; + } + } + + self.inner.igmp_report_state = IgmpReportState::Inactive; + true + } + IgmpReportState::ToGeneralQuery { + version, + timeout, + interval, + next_index, + } if self.inner.now >= timeout => { + let addr = self + .inner + .ipv4_multicast_groups + .iter() + .nth(next_index) + .map(|(addr, ())| *addr); + + match addr { + Some(addr) => { + if let Some(pkt) = self.inner.igmp_report_packet(version, addr) { + // Send initial membership report + if let Some(tx_token) = device.transmit(self.inner.now) { + // NOTE(unwrap): packet destination is multicast, which is always routable and doesn't require neighbor discovery. + self.inner + .dispatch_ip( + tx_token, + PacketMeta::default(), + pkt, + &mut self.fragmenter, + ) + .unwrap(); + } else { + return false; + } + } + + let next_timeout = (timeout + interval).max(self.inner.now); + self.inner.igmp_report_state = IgmpReportState::ToGeneralQuery { + version, + timeout: next_timeout, + interval, + next_index: next_index + 1, + }; + true + } + + None => { + self.inner.igmp_report_state = IgmpReportState::Inactive; + false + } + } + } + _ => false, + } + } +} + +impl InterfaceInner { + /// Check whether the interface listens to given destination multicast IP address. + /// + /// If built without feature `proto-igmp` this function will + /// always return `false`. + pub fn has_multicast_group>(&self, addr: T) -> bool { + match addr.into() { + IpAddress::Ipv4(key) => { + key == Ipv4Address::MULTICAST_ALL_SYSTEMS + || self.ipv4_multicast_groups.get(&key).is_some() + } + #[allow(unreachable_patterns)] + _ => false, + } + } + + /// Host duties of the **IGMPv2** protocol. + /// + /// Sets up `igmp_report_state` for responding to IGMP general/specific membership queries. + /// Membership must not be reported immediately in order to avoid flooding the network + /// after a query is broadcasted by a router; this is not currently done. + pub(super) fn process_igmp<'frame>( + &mut self, + ipv4_repr: Ipv4Repr, + ip_payload: &'frame [u8], + ) -> Option> { + let igmp_packet = check!(IgmpPacket::new_checked(ip_payload)); + let igmp_repr = check!(IgmpRepr::parse(&igmp_packet)); + + // FIXME: report membership after a delay + match igmp_repr { + IgmpRepr::MembershipQuery { + group_addr, + version, + max_resp_time, + } => { + // General query + if group_addr.is_unspecified() + && ipv4_repr.dst_addr == Ipv4Address::MULTICAST_ALL_SYSTEMS + { + // Are we member in any groups? + if self.ipv4_multicast_groups.iter().next().is_some() { + let interval = match version { + IgmpVersion::Version1 => Duration::from_millis(100), + IgmpVersion::Version2 => { + // No dependence on a random generator + // (see [#24](https://github.com/m-labs/smoltcp/issues/24)) + // but at least spread reports evenly across max_resp_time. + let intervals = self.ipv4_multicast_groups.len() as u32 + 1; + max_resp_time / intervals + } + }; + self.igmp_report_state = IgmpReportState::ToGeneralQuery { + version, + timeout: self.now + interval, + interval, + next_index: 0, + }; + } + } else { + // Group-specific query + if self.has_multicast_group(group_addr) && ipv4_repr.dst_addr == group_addr { + // Don't respond immediately + let timeout = max_resp_time / 4; + self.igmp_report_state = IgmpReportState::ToSpecificQuery { + version, + timeout: self.now + timeout, + group: group_addr, + }; + } + } + } + // Ignore membership reports + IgmpRepr::MembershipReport { .. } => (), + // Ignore hosts leaving groups + IgmpRepr::LeaveGroup { .. } => (), + } + + None + } +} diff --git a/src/iface/interface/ipv4.rs b/src/iface/interface/ipv4.rs new file mode 100644 index 000000000..dc351c436 --- /dev/null +++ b/src/iface/interface/ipv4.rs @@ -0,0 +1,442 @@ +use super::*; + +#[cfg(feature = "socket-dhcpv4")] +use crate::socket::dhcpv4; +#[cfg(feature = "socket-icmp")] +use crate::socket::icmp; +use crate::socket::AnySocket; + +use crate::phy::{Medium, TxToken}; +use crate::time::Instant; +use crate::wire::*; + +impl InterfaceInner { + pub(super) fn process_ipv4<'a, T: AsRef<[u8]> + ?Sized>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv4_packet: &Ipv4Packet<&'a T>, + frag: &'a mut FragmentsBuffer, + ) -> Option> { + let ipv4_repr = check!(Ipv4Repr::parse(ipv4_packet, &self.caps.checksum)); + if !self.is_unicast_v4(ipv4_repr.src_addr) && !ipv4_repr.src_addr.is_unspecified() { + // Discard packets with non-unicast source addresses but allow unspecified + net_debug!("non-unicast or unspecified source address"); + return None; + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + let ip_payload = { + if ipv4_packet.more_frags() || ipv4_packet.frag_offset() != 0 { + let key = FragKey::Ipv4(ipv4_packet.get_key()); + + let f = match frag.assembler.get(&key, self.now + frag.reassembly_timeout) { + Ok(f) => f, + Err(_) => { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + }; + + if !ipv4_packet.more_frags() { + // This is the last fragment, so we know the total size + check!(f.set_total_size( + ipv4_packet.total_len() as usize - ipv4_packet.header_len() as usize + + ipv4_packet.frag_offset() as usize, + )); + } + + if let Err(e) = f.add(ipv4_packet.payload(), ipv4_packet.frag_offset() as usize) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + + // NOTE: according to the standard, the total length needs to be + // recomputed, as well as the checksum. However, we don't really use + // the IPv4 header after the packet is reassembled. + match f.assemble() { + Some(payload) => payload, + None => return None, + } + } else { + ipv4_packet.payload() + } + }; + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] + let ip_payload = ipv4_packet.payload(); + + let ip_repr = IpRepr::Ipv4(ipv4_repr); + + #[cfg(feature = "socket-raw")] + let handled_by_raw_socket = self.raw_socket_filter(sockets, &ip_repr, ip_payload); + #[cfg(not(feature = "socket-raw"))] + let handled_by_raw_socket = false; + + #[cfg(feature = "socket-dhcpv4")] + { + if ipv4_repr.next_header == IpProtocol::Udp + && matches!(self.caps.medium, Medium::Ethernet) + { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + if let Some(dhcp_socket) = sockets + .items_mut() + .find_map(|i| dhcpv4::Socket::downcast_mut(&mut i.socket)) + { + // First check for source and dest ports, then do `UdpRepr::parse` if they match. + // This way we avoid validating the UDP checksum twice for all non-DHCP UDP packets (one here, one in `process_udp`) + if udp_packet.src_port() == dhcp_socket.server_port + && udp_packet.dst_port() == dhcp_socket.client_port + { + let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &src_addr, + &dst_addr, + &self.caps.checksum + )); + let udp_payload = udp_packet.payload(); + + dhcp_socket.process(self, &ipv4_repr, &udp_repr, udp_payload); + return None; + } + } + } + } + + if !self.has_ip_addr(ipv4_repr.dst_addr) + && !self.has_multicast_group(ipv4_repr.dst_addr) + && !self.is_broadcast_v4(ipv4_repr.dst_addr) + { + // Ignore IP packets not directed at us, or broadcast, or any of the multicast groups. + // If AnyIP is enabled, also check if the packet is routed locally. + if !self.any_ip + || !ipv4_repr.dst_addr.is_unicast() + || self + .routes + .lookup(&IpAddress::Ipv4(ipv4_repr.dst_addr), self.now) + .map_or(true, |router_addr| !self.has_ip_addr(router_addr)) + { + return None; + } + } + + match ipv4_repr.next_header { + IpProtocol::Icmp => self.process_icmpv4(sockets, ip_repr, ip_payload), + + #[cfg(feature = "proto-igmp")] + IpProtocol::Igmp => self.process_igmp(ipv4_repr, ip_payload), + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpProtocol::Udp => { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &ipv4_repr.src_addr.into(), + &ipv4_repr.dst_addr.into(), + &self.checksum_caps(), + )); + + self.process_udp( + sockets, + meta, + ip_repr, + udp_repr, + handled_by_raw_socket, + udp_packet.payload(), + ip_payload, + ) + } + + #[cfg(feature = "socket-tcp")] + IpProtocol::Tcp => self.process_tcp(sockets, ip_repr, ip_payload), + + _ if handled_by_raw_socket => None, + + _ => { + // Send back as much of the original payload as we can. + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, ipv4_repr.buffer_len()); + let icmp_reply_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::ProtoUnreachable, + header: ipv4_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv4_reply(ipv4_repr, icmp_reply_repr) + } + } + } + + #[cfg(feature = "medium-ethernet")] + pub(super) fn process_arp<'frame, T: AsRef<[u8]>>( + &mut self, + timestamp: Instant, + eth_frame: &EthernetFrame<&'frame T>, + ) -> Option> { + let arp_packet = check!(ArpPacket::new_checked(eth_frame.payload())); + let arp_repr = check!(ArpRepr::parse(&arp_packet)); + + match arp_repr { + ArpRepr::EthernetIpv4 { + operation, + source_hardware_addr, + source_protocol_addr, + target_protocol_addr, + .. + } => { + // Only process ARP packets for us. + if !self.has_ip_addr(target_protocol_addr) { + return None; + } + + // Only process REQUEST and RESPONSE. + if let ArpOperation::Unknown(_) = operation { + net_debug!("arp: unknown operation code"); + return None; + } + + // Discard packets with non-unicast source addresses. + if !source_protocol_addr.is_unicast() || !source_hardware_addr.is_unicast() { + net_debug!("arp: non-unicast source address"); + return None; + } + + if !self.in_same_network(&IpAddress::Ipv4(source_protocol_addr)) { + net_debug!("arp: source IP address not in same network as us"); + return None; + } + + // Fill the ARP cache from any ARP packet aimed at us (both request or response). + // We fill from requests too because if someone is requesting our address they + // are probably going to talk to us, so we avoid having to request their address + // when we later reply to them. + self.neighbor_cache.fill( + source_protocol_addr.into(), + source_hardware_addr.into(), + timestamp, + ); + + if operation == ArpOperation::Request { + let src_hardware_addr = self.hardware_addr.ethernet_or_panic(); + + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: src_hardware_addr, + source_protocol_addr: target_protocol_addr, + target_hardware_addr: source_hardware_addr, + target_protocol_addr: source_protocol_addr, + })) + } else { + None + } + } + } + } + + pub(super) fn process_icmpv4<'frame>( + &mut self, + _sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option> { + let icmp_packet = check!(Icmpv4Packet::new_checked(ip_payload)); + let icmp_repr = check!(Icmpv4Repr::parse(&icmp_packet, &self.caps.checksum)); + + #[cfg(feature = "socket-icmp")] + let mut handled_by_icmp_socket = false; + + #[cfg(all(feature = "socket-icmp", feature = "proto-ipv4"))] + for icmp_socket in _sockets + .items_mut() + .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket)) + { + if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) { + icmp_socket.process(self, &ip_repr, &icmp_repr.into()); + handled_by_icmp_socket = true; + } + } + + match icmp_repr { + // Respond to echo requests. + #[cfg(feature = "proto-ipv4")] + Icmpv4Repr::EchoRequest { + ident, + seq_no, + data, + } => { + let icmp_reply_repr = Icmpv4Repr::EchoReply { + ident, + seq_no, + data, + }; + match ip_repr { + IpRepr::Ipv4(ipv4_repr) => self.icmpv4_reply(ipv4_repr, icmp_reply_repr), + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + } + + // Ignore any echo replies. + Icmpv4Repr::EchoReply { .. } => None, + + // Don't report an error if a packet with unknown type + // has been handled by an ICMP socket + #[cfg(feature = "socket-icmp")] + _ if handled_by_icmp_socket => None, + + // FIXME: do something correct here? + _ => None, + } + } + + pub(super) fn icmpv4_reply<'frame, 'icmp: 'frame>( + &self, + ipv4_repr: Ipv4Repr, + icmp_repr: Icmpv4Repr<'icmp>, + ) -> Option> { + if !self.is_unicast_v4(ipv4_repr.src_addr) { + // Do not send ICMP replies to non-unicast sources + None + } else if self.is_unicast_v4(ipv4_repr.dst_addr) { + // Reply as normal when src_addr and dst_addr are both unicast + let ipv4_reply_repr = Ipv4Repr { + src_addr: ipv4_repr.dst_addr, + dst_addr: ipv4_repr.src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(IpPacket::Icmpv4((ipv4_reply_repr, icmp_repr))) + } else if self.is_broadcast_v4(ipv4_repr.dst_addr) { + // Only reply to broadcasts for echo replies and not other ICMP messages + match icmp_repr { + Icmpv4Repr::EchoReply { .. } => match self.ipv4_addr() { + Some(src_addr) => { + let ipv4_reply_repr = Ipv4Repr { + src_addr, + dst_addr: ipv4_repr.src_addr, + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(IpPacket::Icmpv4((ipv4_reply_repr, icmp_repr))) + } + None => None, + }, + _ => None, + } + } else { + None + } + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + pub(super) fn dispatch_ipv4_frag(&mut self, tx_token: Tx, frag: &mut Fragmenter) { + let caps = self.caps.clone(); + + let mtu_max = self.ip_mtu(); + let ip_len = (frag.packet_len - frag.sent_bytes + frag.ipv4.repr.buffer_len()).min(mtu_max); + let payload_len = ip_len - frag.ipv4.repr.buffer_len(); + + let more_frags = (frag.packet_len - frag.sent_bytes) != payload_len; + frag.ipv4.repr.payload_len = payload_len; + frag.sent_bytes += payload_len; + + let mut tx_len = ip_len; + #[cfg(feature = "medium-ethernet")] + if matches!(caps.medium, Medium::Ethernet) { + tx_len += EthernetFrame::<&[u8]>::header_len(); + } + + // Emit function for the Ethernet header. + #[cfg(feature = "medium-ethernet")] + let emit_ethernet = |repr: &IpRepr, tx_buffer: &mut [u8]| { + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + frame.set_dst_addr(frag.ipv4.dst_hardware_addr); + + match repr.version() { + #[cfg(feature = "proto-ipv4")] + IpVersion::Ipv4 => frame.set_ethertype(EthernetProtocol::Ipv4), + #[cfg(feature = "proto-ipv6")] + IpVersion::Ipv6 => frame.set_ethertype(EthernetProtocol::Ipv6), + } + }; + + tx_token.consume(tx_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&IpRepr::Ipv4(frag.ipv4.repr), tx_buffer); + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + let mut packet = + Ipv4Packet::new_unchecked(&mut tx_buffer[..frag.ipv4.repr.buffer_len()]); + frag.ipv4.repr.emit(&mut packet, &caps.checksum); + packet.set_ident(frag.ipv4.ident); + packet.set_more_frags(more_frags); + packet.set_dont_frag(false); + packet.set_frag_offset(frag.ipv4.frag_offset); + + if caps.checksum.ipv4.tx() { + packet.fill_checksum(); + } + + tx_buffer[frag.ipv4.repr.buffer_len()..][..payload_len].copy_from_slice( + &frag.buffer[frag.ipv4.frag_offset as usize + frag.ipv4.repr.buffer_len()..] + [..payload_len], + ); + + // Update the frag offset for the next fragment. + frag.ipv4.frag_offset += payload_len as u16; + }) + } + + #[cfg(feature = "proto-igmp")] + pub(super) fn igmp_report_packet<'any>( + &self, + version: IgmpVersion, + group_addr: Ipv4Address, + ) -> Option> { + let iface_addr = self.ipv4_addr()?; + let igmp_repr = IgmpRepr::MembershipReport { + group_addr, + version, + }; + let pkt = IpPacket::Igmp(( + Ipv4Repr { + src_addr: iface_addr, + // Send to the group being reported + dst_addr: group_addr, + next_header: IpProtocol::Igmp, + payload_len: igmp_repr.buffer_len(), + hop_limit: 1, + // [#183](https://github.com/m-labs/smoltcp/issues/183). + }, + igmp_repr, + )); + Some(pkt) + } + + #[cfg(feature = "proto-igmp")] + pub(super) fn igmp_leave_packet<'any>( + &self, + group_addr: Ipv4Address, + ) -> Option> { + self.ipv4_addr().map(|iface_addr| { + let igmp_repr = IgmpRepr::LeaveGroup { group_addr }; + IpPacket::Igmp(( + Ipv4Repr { + src_addr: iface_addr, + dst_addr: Ipv4Address::MULTICAST_ALL_ROUTERS, + next_header: IpProtocol::Igmp, + payload_len: igmp_repr.buffer_len(), + hop_limit: 1, + }, + igmp_repr, + )) + }) + } +} diff --git a/src/iface/interface/ipv6.rs b/src/iface/interface/ipv6.rs new file mode 100644 index 000000000..5355316a4 --- /dev/null +++ b/src/iface/interface/ipv6.rs @@ -0,0 +1,309 @@ +use super::check; +use super::icmp_reply_payload_len; +use super::InterfaceInner; +use super::IpPacket; +use super::SocketSet; + +#[cfg(feature = "socket-icmp")] +use crate::socket::icmp; +use crate::socket::AnySocket; + +use crate::phy::PacketMeta; +use crate::wire::*; + +impl InterfaceInner { + #[cfg(feature = "proto-ipv6")] + pub(super) fn process_ipv6<'frame, T: AsRef<[u8]> + ?Sized>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv6_packet: &Ipv6Packet<&'frame T>, + ) -> Option> { + let ipv6_repr = check!(Ipv6Repr::parse(ipv6_packet)); + + if !ipv6_repr.src_addr.is_unicast() { + // Discard packets with non-unicast source addresses. + net_debug!("non-unicast source address"); + return None; + } + + let ip_payload = ipv6_packet.payload(); + + #[cfg(feature = "socket-raw")] + let handled_by_raw_socket = self.raw_socket_filter(sockets, &ipv6_repr.into(), ip_payload); + #[cfg(not(feature = "socket-raw"))] + let handled_by_raw_socket = false; + + self.process_nxt_hdr( + sockets, + meta, + ipv6_repr, + ipv6_repr.next_header, + handled_by_raw_socket, + ip_payload, + ) + } + + /// Given the next header value forward the payload onto the correct process + /// function. + #[cfg(feature = "proto-ipv6")] + pub(super) fn process_nxt_hdr<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv6_repr: Ipv6Repr, + nxt_hdr: IpProtocol, + handled_by_raw_socket: bool, + ip_payload: &'frame [u8], + ) -> Option> { + match nxt_hdr { + IpProtocol::Icmpv6 => self.process_icmpv6(sockets, ipv6_repr.into(), ip_payload), + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpProtocol::Udp => { + let udp_packet = check!(UdpPacket::new_checked(ip_payload)); + let udp_repr = check!(UdpRepr::parse( + &udp_packet, + &ipv6_repr.src_addr.into(), + &ipv6_repr.dst_addr.into(), + &self.checksum_caps(), + )); + + self.process_udp( + sockets, + meta, + ipv6_repr.into(), + udp_repr, + handled_by_raw_socket, + udp_packet.payload(), + ip_payload, + ) + } + + #[cfg(feature = "socket-tcp")] + IpProtocol::Tcp => self.process_tcp(sockets, ipv6_repr.into(), ip_payload), + + IpProtocol::HopByHop => { + self.process_hopbyhop(sockets, meta, ipv6_repr, handled_by_raw_socket, ip_payload) + } + + #[cfg(feature = "socket-raw")] + _ if handled_by_raw_socket => None, + + _ => { + // Send back as much of the original payload as we can. + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, ipv6_repr.buffer_len()); + let icmp_reply_repr = Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, + // The offending packet is after the IPv6 header. + pointer: ipv6_repr.buffer_len() as u32, + header: ipv6_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv6_reply(ipv6_repr, icmp_reply_repr) + } + } + } + + #[cfg(feature = "proto-ipv6")] + pub(super) fn process_icmpv6<'frame>( + &mut self, + _sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option> { + let icmp_packet = check!(Icmpv6Packet::new_checked(ip_payload)); + let icmp_repr = check!(Icmpv6Repr::parse( + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &icmp_packet, + &self.caps.checksum, + )); + + #[cfg(feature = "socket-icmp")] + let mut handled_by_icmp_socket = false; + + #[cfg(all(feature = "socket-icmp", feature = "proto-ipv6"))] + for icmp_socket in _sockets + .items_mut() + .filter_map(|i| icmp::Socket::downcast_mut(&mut i.socket)) + { + if icmp_socket.accepts(self, &ip_repr, &icmp_repr.into()) { + icmp_socket.process(self, &ip_repr, &icmp_repr.into()); + handled_by_icmp_socket = true; + } + } + + match icmp_repr { + // Respond to echo requests. + Icmpv6Repr::EchoRequest { + ident, + seq_no, + data, + } => match ip_repr { + IpRepr::Ipv6(ipv6_repr) => { + let icmp_reply_repr = Icmpv6Repr::EchoReply { + ident, + seq_no, + data, + }; + self.icmpv6_reply(ipv6_repr, icmp_reply_repr) + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + }, + + // Ignore any echo replies. + Icmpv6Repr::EchoReply { .. } => None, + + // Forward any NDISC packets to the ndisc packet handler + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + Icmpv6Repr::Ndisc(repr) if ip_repr.hop_limit() == 0xff => match ip_repr { + IpRepr::Ipv6(ipv6_repr) => self.process_ndisc(ipv6_repr, repr), + #[allow(unreachable_patterns)] + _ => unreachable!(), + }, + + // Don't report an error if a packet with unknown type + // has been handled by an ICMP socket + #[cfg(feature = "socket-icmp")] + _ if handled_by_icmp_socket => None, + + // FIXME: do something correct here? + _ => None, + } + } + + #[cfg(all( + any(feature = "medium-ethernet", feature = "medium-ieee802154"), + feature = "proto-ipv6" + ))] + pub(super) fn process_ndisc<'frame>( + &mut self, + ip_repr: Ipv6Repr, + repr: NdiscRepr<'frame>, + ) -> Option> { + match repr { + NdiscRepr::NeighborAdvert { + lladdr, + target_addr, + flags, + } => { + let ip_addr = ip_repr.src_addr.into(); + if let Some(lladdr) = lladdr { + let lladdr = check!(lladdr.parse(self.caps.medium)); + if !lladdr.is_unicast() || !target_addr.is_unicast() { + return None; + } + if flags.contains(NdiscNeighborFlags::OVERRIDE) + || !self.neighbor_cache.lookup(&ip_addr, self.now).found() + { + self.neighbor_cache.fill(ip_addr, lladdr, self.now) + } + } + None + } + NdiscRepr::NeighborSolicit { + target_addr, + lladdr, + .. + } => { + if let Some(lladdr) = lladdr { + let lladdr = check!(lladdr.parse(self.caps.medium)); + if !lladdr.is_unicast() || !target_addr.is_unicast() { + return None; + } + self.neighbor_cache + .fill(ip_repr.src_addr.into(), lladdr, self.now); + } + + if self.has_solicited_node(ip_repr.dst_addr) && self.has_ip_addr(target_addr) { + let advert = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr, + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + lladdr: Some(self.hardware_addr.into()), + }); + let ip_repr = Ipv6Repr { + src_addr: target_addr, + dst_addr: ip_repr.src_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: advert.buffer_len(), + }; + Some(IpPacket::Icmpv6((ip_repr, advert))) + } else { + None + } + } + _ => None, + } + } + + #[cfg(feature = "proto-ipv6")] + pub(super) fn process_hopbyhop<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ipv6_repr: Ipv6Repr, + handled_by_raw_socket: bool, + ip_payload: &'frame [u8], + ) -> Option> { + let hbh_hdr = check!(Ipv6HopByHopHeader::new_checked(ip_payload)); + let hbh_repr = check!(Ipv6HopByHopRepr::parse(&hbh_hdr)); + + let hbh_options = Ipv6OptionsIterator::new(hbh_repr.data); + for opt_repr in hbh_options { + let opt_repr = check!(opt_repr); + match opt_repr { + Ipv6OptionRepr::Pad1 | Ipv6OptionRepr::PadN(_) => (), + #[cfg(feature = "proto-rpl")] + Ipv6OptionRepr::Rpl(_) => {} + + Ipv6OptionRepr::Unknown { type_, .. } => { + match Ipv6OptionFailureType::from(type_) { + Ipv6OptionFailureType::Skip => (), + Ipv6OptionFailureType::Discard => { + return None; + } + _ => { + // FIXME(dlrobertson): Send an ICMPv6 parameter problem message + // here. + return None; + } + } + } + } + } + self.process_nxt_hdr( + sockets, + meta, + ipv6_repr, + hbh_repr.next_header, + handled_by_raw_socket, + &ip_payload[hbh_repr.header_len() + hbh_repr.data.len()..], + ) + } + + #[cfg(feature = "proto-ipv6")] + pub(super) fn icmpv6_reply<'frame, 'icmp: 'frame>( + &self, + ipv6_repr: Ipv6Repr, + icmp_repr: Icmpv6Repr<'icmp>, + ) -> Option> { + if ipv6_repr.dst_addr.is_unicast() { + let ipv6_reply_repr = Ipv6Repr { + src_addr: ipv6_repr.dst_addr, + dst_addr: ipv6_repr.src_addr, + next_header: IpProtocol::Icmpv6, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }; + Some(IpPacket::Icmpv6((ipv6_reply_repr, icmp_repr))) + } else { + // Do not send any ICMP replies to a broadcast destination address. + None + } + } +} diff --git a/src/iface/interface/mod.rs b/src/iface/interface/mod.rs new file mode 100644 index 000000000..6f0600875 --- /dev/null +++ b/src/iface/interface/mod.rs @@ -0,0 +1,1885 @@ +// Heads up! Before working on this file you should read the parts +// of RFC 1122 that discuss Ethernet, ARP and IP for any IPv4 work +// and RFCs 8200 and 4861 for any IPv6 and NDISC work. + +#[cfg(test)] +mod tests; + +#[cfg(feature = "medium-ethernet")] +mod ethernet; +#[cfg(feature = "medium-ieee802154")] +mod ieee802154; + +#[cfg(feature = "proto-ipv4")] +mod ipv4; +#[cfg(feature = "proto-ipv6")] +mod ipv6; +#[cfg(feature = "proto-sixlowpan")] +mod sixlowpan; + +#[cfg(feature = "proto-igmp")] +mod igmp; + +#[cfg(feature = "proto-igmp")] +pub use igmp::MulticastError; + +use core::cmp; +use core::result::Result; +use heapless::{LinearMap, Vec}; + +#[cfg(any(feature = "proto-ipv4", feature = "proto-sixlowpan"))] +use super::fragmentation::PacketAssemblerSet; +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +use super::neighbor::{Answer as NeighborAnswer, Cache as NeighborCache}; +use super::socket_set::SocketSet; +use crate::config::{ + FRAGMENTATION_BUFFER_SIZE, IFACE_MAX_ADDR_COUNT, IFACE_MAX_MULTICAST_GROUP_COUNT, + IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT, +}; +use crate::iface::Routes; +use crate::phy::PacketMeta; +use crate::phy::{ChecksumCapabilities, Device, DeviceCapabilities, Medium, RxToken, TxToken}; +use crate::rand::Rand; +#[cfg(feature = "socket-dns")] +use crate::socket::dns; +use crate::socket::*; +use crate::time::{Duration, Instant}; +use crate::wire::*; + +#[cfg(feature = "_proto-fragmentation")] +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum FragKey { + #[cfg(feature = "proto-ipv4-fragmentation")] + Ipv4(Ipv4FragKey), + #[cfg(feature = "proto-sixlowpan-fragmentation")] + Sixlowpan(SixlowpanFragKey), +} + +pub(crate) struct FragmentsBuffer { + #[cfg(feature = "proto-sixlowpan")] + decompress_buf: [u8; sixlowpan::MAX_DECOMPRESSED_LEN], + + #[cfg(feature = "_proto-fragmentation")] + pub(crate) assembler: PacketAssemblerSet, + + #[cfg(feature = "_proto-fragmentation")] + reassembly_timeout: Duration, +} + +#[cfg(not(feature = "_proto-fragmentation"))] +pub(crate) struct Fragmenter {} + +#[cfg(not(feature = "_proto-fragmentation"))] +impl Fragmenter { + pub(crate) fn new() -> Self { + Self {} + } +} + +#[cfg(feature = "_proto-fragmentation")] +pub(crate) struct Fragmenter { + /// The buffer that holds the unfragmented 6LoWPAN packet. + buffer: [u8; FRAGMENTATION_BUFFER_SIZE], + /// The size of the packet without the IEEE802.15.4 header and the fragmentation headers. + packet_len: usize, + /// The amount of bytes that already have been transmitted. + sent_bytes: usize, + + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4: Ipv4Fragmenter, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + sixlowpan: SixlowpanFragmenter, +} + +#[cfg(feature = "proto-ipv4-fragmentation")] +pub(crate) struct Ipv4Fragmenter { + /// The IPv4 representation. + repr: Ipv4Repr, + /// The destination hardware address. + #[cfg(feature = "medium-ethernet")] + dst_hardware_addr: EthernetAddress, + /// The offset of the next fragment. + frag_offset: u16, + /// The identifier of the stream. + ident: u16, +} + +#[cfg(feature = "proto-sixlowpan-fragmentation")] +pub(crate) struct SixlowpanFragmenter { + /// The datagram size that is used for the fragmentation headers. + datagram_size: u16, + /// The datagram tag that is used for the fragmentation headers. + datagram_tag: u16, + datagram_offset: usize, + + /// The size of the FRAG_N packets. + fragn_size: usize, + + /// The link layer IEEE802.15.4 source address. + ll_dst_addr: Ieee802154Address, + /// The link layer IEEE802.15.4 source address. + ll_src_addr: Ieee802154Address, +} + +#[cfg(feature = "_proto-fragmentation")] +impl Fragmenter { + pub(crate) fn new() -> Self { + Self { + buffer: [0u8; FRAGMENTATION_BUFFER_SIZE], + packet_len: 0, + sent_bytes: 0, + + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4: Ipv4Fragmenter { + repr: Ipv4Repr { + src_addr: Ipv4Address::default(), + dst_addr: Ipv4Address::default(), + next_header: IpProtocol::Unknown(0), + payload_len: 0, + hop_limit: 0, + }, + #[cfg(feature = "medium-ethernet")] + dst_hardware_addr: EthernetAddress::default(), + frag_offset: 0, + ident: 0, + }, + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + sixlowpan: SixlowpanFragmenter { + datagram_size: 0, + datagram_tag: 0, + datagram_offset: 0, + fragn_size: 0, + ll_dst_addr: Ieee802154Address::Absent, + ll_src_addr: Ieee802154Address::Absent, + }, + } + } + + /// Return `true` when everything is transmitted. + #[inline] + fn finished(&self) -> bool { + self.packet_len == self.sent_bytes + } + + /// Returns `true` when there is nothing to transmit. + #[inline] + fn is_empty(&self) -> bool { + self.packet_len == 0 + } + + // Reset the buffer. + fn reset(&mut self) { + self.packet_len = 0; + self.sent_bytes = 0; + + #[cfg(feature = "proto-ipv4-fragmentation")] + { + self.ipv4.repr = Ipv4Repr { + src_addr: Ipv4Address::default(), + dst_addr: Ipv4Address::default(), + next_header: IpProtocol::Unknown(0), + payload_len: 0, + hop_limit: 0, + }; + #[cfg(feature = "medium-ethernet")] + { + self.ipv4.dst_hardware_addr = EthernetAddress::default(); + } + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + { + self.sixlowpan.datagram_size = 0; + self.sixlowpan.datagram_tag = 0; + self.sixlowpan.fragn_size = 0; + self.sixlowpan.ll_dst_addr = Ieee802154Address::Absent; + self.sixlowpan.ll_src_addr = Ieee802154Address::Absent; + } + } +} + +macro_rules! check { + ($e:expr) => { + match $e { + Ok(x) => x, + Err(_) => { + // concat!/stringify! doesn't work with defmt macros + #[cfg(not(feature = "defmt"))] + net_trace!(concat!("iface: malformed ", stringify!($e))); + #[cfg(feature = "defmt")] + net_trace!("iface: malformed"); + return Default::default(); + } + } + }; +} +use check; + +/// A network interface. +/// +/// The network interface logically owns a number of other data structures; to avoid +/// a dependency on heap allocation, it instead owns a `BorrowMut<[T]>`, which can be +/// a `&mut [T]`, or `Vec` if a heap is available. +pub struct Interface { + inner: InterfaceInner, + fragments: FragmentsBuffer, + fragmenter: Fragmenter, +} + +/// The device independent part of an Ethernet network interface. +/// +/// Separating the device from the data required for processing and dispatching makes +/// it possible to borrow them independently. For example, the tx and rx tokens borrow +/// the `device` mutably until they're used, which makes it impossible to call other +/// methods on the `Interface` in this time (since its `device` field is borrowed +/// exclusively). However, it is still possible to call methods on its `inner` field. +pub struct InterfaceInner { + caps: DeviceCapabilities, + now: Instant, + rand: Rand, + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + neighbor_cache: NeighborCache, + hardware_addr: HardwareAddress, + #[cfg(feature = "medium-ieee802154")] + sequence_no: u8, + #[cfg(feature = "medium-ieee802154")] + pan_id: Option, + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4_id: u16, + #[cfg(feature = "proto-sixlowpan")] + sixlowpan_address_context: + Vec, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + tag: u16, + ip_addrs: Vec, + #[cfg(feature = "proto-ipv4")] + any_ip: bool, + routes: Routes, + #[cfg(feature = "proto-igmp")] + ipv4_multicast_groups: LinearMap, + /// When to report for (all or) the next multicast group membership via IGMP + #[cfg(feature = "proto-igmp")] + igmp_report_state: IgmpReportState, +} + +/// Configuration structure used for creating a network interface. +#[non_exhaustive] +pub struct Config { + /// Random seed. + /// + /// It is strongly recommended that the random seed is different on each boot, + /// to avoid problems with TCP port/sequence collisions. + /// + /// The seed doesn't have to be cryptographically secure. + pub random_seed: u64, + + /// Set the Hardware address the interface will use. + /// + /// # Panics + /// Creating the interface panics if the address is not unicast. + pub hardware_addr: HardwareAddress, + + /// Set the IEEE802.15.4 PAN ID the interface will use. + /// + /// **NOTE**: we use the same PAN ID for destination and source. + #[cfg(feature = "medium-ieee802154")] + pub pan_id: Option, +} + +impl Config { + pub fn new(hardware_addr: HardwareAddress) -> Self { + Config { + random_seed: 0, + hardware_addr, + #[cfg(feature = "medium-ieee802154")] + pan_id: None, + } + } +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[cfg(feature = "medium-ethernet")] +enum EthernetPacket<'a> { + #[cfg(feature = "proto-ipv4")] + Arp(ArpRepr), + Ip(IpPacket<'a>), +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) enum IpPacket<'a> { + #[cfg(feature = "proto-ipv4")] + Icmpv4((Ipv4Repr, Icmpv4Repr<'a>)), + #[cfg(feature = "proto-igmp")] + Igmp((Ipv4Repr, IgmpRepr)), + #[cfg(feature = "proto-ipv6")] + Icmpv6((Ipv6Repr, Icmpv6Repr<'a>)), + #[cfg(feature = "socket-raw")] + Raw((IpRepr, &'a [u8])), + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + Udp((IpRepr, UdpRepr, &'a [u8])), + #[cfg(feature = "socket-tcp")] + Tcp((IpRepr, TcpRepr<'a>)), + #[cfg(feature = "socket-dhcpv4")] + Dhcpv4((Ipv4Repr, UdpRepr, DhcpRepr<'a>)), +} + +impl<'a> IpPacket<'a> { + pub(crate) fn ip_repr(&self) -> IpRepr { + match self { + #[cfg(feature = "proto-ipv4")] + IpPacket::Icmpv4((ipv4_repr, _)) => IpRepr::Ipv4(*ipv4_repr), + #[cfg(feature = "proto-igmp")] + IpPacket::Igmp((ipv4_repr, _)) => IpRepr::Ipv4(*ipv4_repr), + #[cfg(feature = "proto-ipv6")] + IpPacket::Icmpv6((ipv6_repr, _)) => IpRepr::Ipv6(*ipv6_repr), + #[cfg(feature = "socket-raw")] + IpPacket::Raw((ip_repr, _)) => ip_repr.clone(), + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpPacket::Udp((ip_repr, _, _)) => ip_repr.clone(), + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp((ip_repr, _)) => ip_repr.clone(), + #[cfg(feature = "socket-dhcpv4")] + IpPacket::Dhcpv4((ipv4_repr, _, _)) => IpRepr::Ipv4(*ipv4_repr), + } + } + + pub(crate) fn emit_payload( + &self, + _ip_repr: &IpRepr, + payload: &mut [u8], + caps: &DeviceCapabilities, + ) { + match self { + #[cfg(feature = "proto-ipv4")] + IpPacket::Icmpv4((_, icmpv4_repr)) => { + icmpv4_repr.emit(&mut Icmpv4Packet::new_unchecked(payload), &caps.checksum) + } + #[cfg(feature = "proto-igmp")] + IpPacket::Igmp((_, igmp_repr)) => { + igmp_repr.emit(&mut IgmpPacket::new_unchecked(payload)) + } + #[cfg(feature = "proto-ipv6")] + IpPacket::Icmpv6((_, icmpv6_repr)) => icmpv6_repr.emit( + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + &mut Icmpv6Packet::new_unchecked(payload), + &caps.checksum, + ), + #[cfg(feature = "socket-raw")] + IpPacket::Raw((_, raw_packet)) => payload.copy_from_slice(raw_packet), + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + IpPacket::Udp((_, udp_repr, inner_payload)) => { + udp_repr.emit( + &mut UdpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + inner_payload.len(), + |buf| buf.copy_from_slice(inner_payload), + &caps.checksum, + ); + } + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp((_, mut tcp_repr)) => { + // This is a terrible hack to make TCP performance more acceptable on systems + // where the TCP buffers are significantly larger than network buffers, + // e.g. a 64 kB TCP receive buffer (and so, when empty, a 64k window) + // together with four 1500 B Ethernet receive buffers. If left untreated, + // this would result in our peer pushing our window and sever packet loss. + // + // I'm really not happy about this "solution" but I don't know what else to do. + if let Some(max_burst_size) = caps.max_burst_size { + let mut max_segment_size = caps.max_transmission_unit; + max_segment_size -= _ip_repr.header_len(); + max_segment_size -= tcp_repr.header_len(); + + let max_window_size = max_burst_size * max_segment_size; + if tcp_repr.window_len as usize > max_window_size { + tcp_repr.window_len = max_window_size as u16; + } + } + + tcp_repr.emit( + &mut TcpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + &caps.checksum, + ); + } + #[cfg(feature = "socket-dhcpv4")] + IpPacket::Dhcpv4((_, udp_repr, dhcp_repr)) => udp_repr.emit( + &mut UdpPacket::new_unchecked(payload), + &_ip_repr.src_addr(), + &_ip_repr.dst_addr(), + dhcp_repr.buffer_len(), + |buf| dhcp_repr.emit(&mut DhcpPacket::new_unchecked(buf)).unwrap(), + &caps.checksum, + ), + }; + } +} + +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] +fn icmp_reply_payload_len(len: usize, mtu: usize, header_len: usize) -> usize { + // Send back as much of the original payload as will fit within + // the minimum MTU required by IPv4. See RFC 1812 § 4.3.2.3 for + // more details. + // + // Since the entire network layer packet must fit within the minimum + // MTU supported, the payload must not exceed the following: + // + // - IP Header Size * 2 - ICMPv4 DstUnreachable hdr size + cmp::min(len, mtu - header_len * 2 - 8) +} + +#[cfg(feature = "proto-igmp")] +enum IgmpReportState { + Inactive, + ToGeneralQuery { + version: IgmpVersion, + timeout: Instant, + interval: Duration, + next_index: usize, + }, + ToSpecificQuery { + version: IgmpVersion, + timeout: Instant, + group: Ipv4Address, + }, +} + +impl Interface { + /// Create a network interface using the previously provided configuration. + /// + /// # Panics + /// This function panics if the [`Config::hardware_address`] does not match + /// the medium of the device. + pub fn new(config: Config, device: &mut D, now: Instant) -> Self + where + D: Device + ?Sized, + { + let caps = device.capabilities(); + assert_eq!( + config.hardware_addr.medium(), + caps.medium, + "The hardware address does not match the medium of the interface." + ); + + let mut rand = Rand::new(config.random_seed); + + #[cfg(feature = "medium-ieee802154")] + let mut sequence_no; + #[cfg(feature = "medium-ieee802154")] + loop { + sequence_no = (rand.rand_u32() & 0xff) as u8; + if sequence_no != 0 { + break; + } + } + + #[cfg(feature = "proto-sixlowpan")] + let mut tag; + + #[cfg(feature = "proto-sixlowpan")] + loop { + tag = rand.rand_u16(); + if tag != 0 { + break; + } + } + + #[cfg(feature = "proto-ipv4")] + let mut ipv4_id; + + #[cfg(feature = "proto-ipv4")] + loop { + ipv4_id = rand.rand_u16(); + if ipv4_id != 0 { + break; + } + } + + Interface { + fragments: FragmentsBuffer { + #[cfg(feature = "proto-sixlowpan")] + decompress_buf: [0u8; sixlowpan::MAX_DECOMPRESSED_LEN], + + #[cfg(feature = "_proto-fragmentation")] + assembler: PacketAssemblerSet::new(), + #[cfg(feature = "_proto-fragmentation")] + reassembly_timeout: Duration::from_secs(60), + }, + fragmenter: Fragmenter::new(), + inner: InterfaceInner { + now, + caps, + hardware_addr: config.hardware_addr, + ip_addrs: Vec::new(), + #[cfg(feature = "proto-ipv4")] + any_ip: false, + routes: Routes::new(), + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + neighbor_cache: NeighborCache::new(), + #[cfg(feature = "proto-igmp")] + ipv4_multicast_groups: LinearMap::new(), + #[cfg(feature = "proto-igmp")] + igmp_report_state: IgmpReportState::Inactive, + #[cfg(feature = "medium-ieee802154")] + sequence_no, + #[cfg(feature = "medium-ieee802154")] + pan_id: config.pan_id, + #[cfg(feature = "proto-sixlowpan-fragmentation")] + tag, + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4_id, + #[cfg(feature = "proto-sixlowpan")] + sixlowpan_address_context: Vec::new(), + rand, + }, + } + } + + /// Get the socket context. + /// + /// The context is needed for some socket methods. + pub fn context(&mut self) -> &mut InterfaceInner { + &mut self.inner + } + + /// Get the HardwareAddress address of the interface. + /// + /// # Panics + /// This function panics if the medium is not Ethernet or Ieee802154. + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + pub fn hardware_addr(&self) -> HardwareAddress { + #[cfg(all(feature = "medium-ethernet", not(feature = "medium-ieee802154")))] + assert!(self.inner.caps.medium == Medium::Ethernet); + #[cfg(all(feature = "medium-ieee802154", not(feature = "medium-ethernet")))] + assert!(self.inner.caps.medium == Medium::Ieee802154); + + #[cfg(all(feature = "medium-ieee802154", feature = "medium-ethernet"))] + assert!( + self.inner.caps.medium == Medium::Ethernet + || self.inner.caps.medium == Medium::Ieee802154 + ); + + self.inner.hardware_addr + } + + /// Set the HardwareAddress address of the interface. + /// + /// # Panics + /// This function panics if the address is not unicast, and if the medium is not Ethernet or + /// Ieee802154. + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + pub fn set_hardware_addr(&mut self, addr: HardwareAddress) { + #[cfg(all(feature = "medium-ethernet", not(feature = "medium-ieee802154")))] + assert!(self.inner.caps.medium == Medium::Ethernet); + #[cfg(all(feature = "medium-ieee802154", not(feature = "medium-ethernet")))] + assert!(self.inner.caps.medium == Medium::Ieee802154); + + #[cfg(all(feature = "medium-ieee802154", feature = "medium-ethernet"))] + assert!( + self.inner.caps.medium == Medium::Ethernet + || self.inner.caps.medium == Medium::Ieee802154 + ); + + InterfaceInner::check_hardware_addr(&addr); + self.inner.hardware_addr = addr; + } + + /// Get the IP addresses of the interface. + pub fn ip_addrs(&self) -> &[IpCidr] { + self.inner.ip_addrs.as_ref() + } + + /// Get the first IPv4 address if present. + #[cfg(feature = "proto-ipv4")] + pub fn ipv4_addr(&self) -> Option { + self.inner.ipv4_addr() + } + + /// Get the first IPv6 address if present. + #[cfg(feature = "proto-ipv6")] + pub fn ipv6_addr(&self) -> Option { + self.inner.ipv6_addr() + } + + /// Update the IP addresses of the interface. + /// + /// # Panics + /// This function panics if any of the addresses are not unicast. + pub fn update_ip_addrs)>(&mut self, f: F) { + f(&mut self.inner.ip_addrs); + InterfaceInner::flush_cache(&mut self.inner); + InterfaceInner::check_ip_addrs(&self.inner.ip_addrs) + } + + /// Check whether the interface has the given IP address assigned. + pub fn has_ip_addr>(&self, addr: T) -> bool { + self.inner.has_ip_addr(addr) + } + + pub fn routes(&self) -> &Routes { + &self.inner.routes + } + + pub fn routes_mut(&mut self) -> &mut Routes { + &mut self.inner.routes + } + + /// Enable or disable the AnyIP capability. + /// + /// AnyIP allowins packets to be received + /// locally on IPv4 addresses other than the interface's configured [ip_addrs]. + /// When AnyIP is enabled and a route prefix in [`routes`](Self::routes) specifies one of + /// the interface's [`ip_addrs`](Self::ip_addrs) as its gateway, the interface will accept + /// packets addressed to that prefix. + /// + /// # IPv6 + /// + /// This option is not available or required for IPv6 as packets sent to + /// the interface are not filtered by IPv6 address. + #[cfg(feature = "proto-ipv4")] + pub fn set_any_ip(&mut self, any_ip: bool) { + self.inner.any_ip = any_ip; + } + + /// Get whether AnyIP is enabled. + /// + /// See [`set_any_ip`](Self::set_any_ip) for details on AnyIP + #[cfg(feature = "proto-ipv4")] + pub fn any_ip(&self) -> bool { + self.inner.any_ip + } + + /// Get the 6LoWPAN address contexts. + #[cfg(feature = "proto-sixlowpan")] + pub fn sixlowpan_address_context( + &self, + ) -> &Vec { + &self.inner.sixlowpan_address_context + } + + /// Get a mutable reference to the 6LoWPAN address contexts. + #[cfg(feature = "proto-sixlowpan")] + pub fn sixlowpan_address_context_mut( + &mut self, + ) -> &mut Vec { + &mut self.inner.sixlowpan_address_context + } + + /// Get the packet reassembly timeout. + #[cfg(feature = "_proto-fragmentation")] + pub fn reassembly_timeout(&self) -> Duration { + self.fragments.reassembly_timeout + } + + /// Set the packet reassembly timeout. + #[cfg(feature = "_proto-fragmentation")] + pub fn set_reassembly_timeout(&mut self, timeout: Duration) { + if timeout > Duration::from_secs(60) { + net_debug!("RFC 4944 specifies that the reassembly timeout MUST be set to a maximum of 60 seconds"); + } + self.fragments.reassembly_timeout = timeout; + } + + /// Transmit packets queued in the given sockets, and receive packets queued + /// in the device. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + pub fn poll( + &mut self, + timestamp: Instant, + device: &mut D, + sockets: &mut SocketSet<'_>, + ) -> bool + where + D: Device + ?Sized, + { + self.inner.now = timestamp; + + #[cfg(feature = "_proto-fragmentation")] + self.fragments.assembler.remove_expired(timestamp); + + match self.inner.caps.medium { + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => + { + #[cfg(feature = "proto-sixlowpan-fragmentation")] + if self.sixlowpan_egress(device) { + return true; + } + } + #[cfg(any(feature = "medium-ethernet", feature = "medium-ip"))] + _ => + { + #[cfg(feature = "proto-ipv4-fragmentation")] + if self.ipv4_egress(device) { + return true; + } + } + } + + let mut readiness_may_have_changed = false; + + loop { + let mut did_something = false; + did_something |= self.socket_ingress(device, sockets); + did_something |= self.socket_egress(device, sockets); + + #[cfg(feature = "proto-igmp")] + { + did_something |= self.igmp_egress(device); + } + + if did_something { + readiness_may_have_changed = true; + } else { + break; + } + } + + readiness_may_have_changed + } + + /// Return a _soft deadline_ for calling [poll] the next time. + /// The [Instant] returned is the time at which you should call [poll] next. + /// It is harmless (but wastes energy) to call it before the [Instant], and + /// potentially harmful (impacting quality of service) to call it after the + /// [Instant] + /// + /// [poll]: #method.poll + /// [Instant]: struct.Instant.html + pub fn poll_at(&mut self, timestamp: Instant, sockets: &SocketSet<'_>) -> Option { + self.inner.now = timestamp; + + #[cfg(feature = "_proto-fragmentation")] + if !self.fragmenter.is_empty() { + return Some(Instant::from_millis(0)); + } + + let inner = &mut self.inner; + + sockets + .items() + .filter_map(move |item| { + let socket_poll_at = item.socket.poll_at(inner); + match item + .meta + .poll_at(socket_poll_at, |ip_addr| inner.has_neighbor(&ip_addr)) + { + PollAt::Ingress => None, + PollAt::Time(instant) => Some(instant), + PollAt::Now => Some(Instant::from_millis(0)), + } + }) + .min() + } + + /// Return an _advisory wait time_ for calling [poll] the next time. + /// The [Duration] returned is the time left to wait before calling [poll] next. + /// It is harmless (but wastes energy) to call it before the [Duration] has passed, + /// and potentially harmful (impacting quality of service) to call it after the + /// [Duration] has passed. + /// + /// [poll]: #method.poll + /// [Duration]: struct.Duration.html + pub fn poll_delay(&mut self, timestamp: Instant, sockets: &SocketSet<'_>) -> Option { + match self.poll_at(timestamp, sockets) { + Some(poll_at) if timestamp < poll_at => Some(poll_at - timestamp), + Some(_) => Some(Duration::from_millis(0)), + _ => None, + } + } + + fn socket_ingress(&mut self, device: &mut D, sockets: &mut SocketSet<'_>) -> bool + where + D: Device + ?Sized, + { + let mut processed_any = false; + + while let Some((rx_token, tx_token)) = device.receive(self.inner.now) { + let rx_meta = rx_token.meta(); + rx_token.consume(|frame| { + match self.inner.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + if let Some(packet) = self.inner.process_ethernet( + sockets, + rx_meta, + &frame, + &mut self.fragments, + ) { + if let Err(err) = + self.inner.dispatch(tx_token, packet, &mut self.fragmenter) + { + net_debug!("Failed to send response: {:?}", err); + } + } + } + #[cfg(feature = "medium-ip")] + Medium::Ip => { + if let Some(packet) = + self.inner + .process_ip(sockets, rx_meta, &frame, &mut self.fragments) + { + if let Err(err) = self.inner.dispatch_ip( + tx_token, + PacketMeta::default(), + packet, + &mut self.fragmenter, + ) { + net_debug!("Failed to send response: {:?}", err); + } + } + } + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + if let Some(packet) = self.inner.process_ieee802154( + sockets, + rx_meta, + &frame, + &mut self.fragments, + ) { + if let Err(err) = self.inner.dispatch_ip( + tx_token, + PacketMeta::default(), + packet, + &mut self.fragmenter, + ) { + net_debug!("Failed to send response: {:?}", err); + } + } + } + } + processed_any = true; + }); + } + + processed_any + } + + fn socket_egress(&mut self, device: &mut D, sockets: &mut SocketSet<'_>) -> bool + where + D: Device + ?Sized, + { + let _caps = device.capabilities(); + + enum EgressError { + Exhausted, + Dispatch(DispatchError), + } + + let mut emitted_any = false; + for item in sockets.items_mut() { + if !item + .meta + .egress_permitted(self.inner.now, |ip_addr| self.inner.has_neighbor(&ip_addr)) + { + continue; + } + + let mut neighbor_addr = None; + let mut respond = |inner: &mut InterfaceInner, meta: PacketMeta, response: IpPacket| { + neighbor_addr = Some(response.ip_repr().dst_addr()); + let t = device.transmit(inner.now).ok_or_else(|| { + net_debug!("failed to transmit IP: device exhausted"); + EgressError::Exhausted + })?; + + inner + .dispatch_ip(t, meta, response, &mut self.fragmenter) + .map_err(EgressError::Dispatch)?; + + emitted_any = true; + + Ok(()) + }; + + let result = match &mut item.socket { + #[cfg(feature = "socket-raw")] + Socket::Raw(socket) => socket.dispatch(&mut self.inner, |inner, response| { + respond(inner, PacketMeta::default(), IpPacket::Raw(response)) + }), + #[cfg(feature = "socket-icmp")] + Socket::Icmp(socket) => { + socket.dispatch(&mut self.inner, |inner, response| match response { + #[cfg(feature = "proto-ipv4")] + (IpRepr::Ipv4(ipv4_repr), IcmpRepr::Ipv4(icmpv4_repr)) => respond( + inner, + PacketMeta::default(), + IpPacket::Icmpv4((ipv4_repr, icmpv4_repr)), + ), + #[cfg(feature = "proto-ipv6")] + (IpRepr::Ipv6(ipv6_repr), IcmpRepr::Ipv6(icmpv6_repr)) => respond( + inner, + PacketMeta::default(), + IpPacket::Icmpv6((ipv6_repr, icmpv6_repr)), + ), + #[allow(unreachable_patterns)] + _ => unreachable!(), + }) + } + #[cfg(feature = "socket-udp")] + Socket::Udp(socket) => socket.dispatch(&mut self.inner, |inner, meta, response| { + respond(inner, meta, IpPacket::Udp(response)) + }), + #[cfg(feature = "socket-tcp")] + Socket::Tcp(socket) => socket.dispatch(&mut self.inner, |inner, response| { + respond(inner, PacketMeta::default(), IpPacket::Tcp(response)) + }), + #[cfg(feature = "socket-dhcpv4")] + Socket::Dhcpv4(socket) => socket.dispatch(&mut self.inner, |inner, response| { + respond(inner, PacketMeta::default(), IpPacket::Dhcpv4(response)) + }), + #[cfg(feature = "socket-dns")] + Socket::Dns(socket) => { + socket.dispatch(&mut self.inner, |inner, (ip, udp, payload)| { + respond( + inner, + PacketMeta::default(), + IpPacket::Udp((ip, udp, payload)), + ) + }) + } + }; + + match result { + Err(EgressError::Exhausted) => break, // Device buffer full. + Err(EgressError::Dispatch(_)) => { + // `NeighborCache` already takes care of rate limiting the neighbor discovery + // requests from the socket. However, without an additional rate limiting + // mechanism, we would spin on every socket that has yet to discover its + // neighbor. + item.meta.neighbor_missing( + self.inner.now, + neighbor_addr.expect("non-IP response packet"), + ); + } + Ok(()) => {} + } + } + emitted_any + } + + /// Process fragments that still need to be sent for IPv4 packets. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + #[cfg(feature = "proto-ipv4-fragmentation")] + fn ipv4_egress(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + // Reset the buffer when we transmitted everything. + if self.fragmenter.finished() { + self.fragmenter.reset(); + } + + if self.fragmenter.is_empty() { + return false; + } + + let pkt = &self.fragmenter; + if pkt.packet_len > pkt.sent_bytes { + if let Some(tx_token) = device.transmit(self.inner.now) { + self.inner + .dispatch_ipv4_frag(tx_token, &mut self.fragmenter); + return true; + } + } + false + } + + /// Process fragments that still need to be sent for 6LoWPAN packets. + /// + /// This function returns a boolean value indicating whether any packets were + /// processed or emitted, and thus, whether the readiness of any socket might + /// have changed. + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn sixlowpan_egress(&mut self, device: &mut D) -> bool + where + D: Device + ?Sized, + { + // Reset the buffer when we transmitted everything. + if self.fragmenter.finished() { + self.fragmenter.reset(); + } + + if self.fragmenter.is_empty() { + return false; + } + + let pkt = &self.fragmenter; + if pkt.packet_len > pkt.sent_bytes { + if let Some(tx_token) = device.transmit(self.inner.now) { + self.inner + .dispatch_ieee802154_frag(tx_token, &mut self.fragmenter); + return true; + } + } + false + } +} + +impl InterfaceInner { + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn now(&self) -> Instant { + self.now + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn hardware_addr(&self) -> HardwareAddress { + self.hardware_addr + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn checksum_caps(&self) -> ChecksumCapabilities { + self.caps.checksum.clone() + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn ip_mtu(&self) -> usize { + self.caps.ip_mtu() + } + + #[allow(unused)] // unused depending on which sockets are enabled, and in tests + pub(crate) fn rand(&mut self) -> &mut Rand { + &mut self.rand + } + + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn get_source_address(&mut self, dst_addr: IpAddress) -> Option { + let v = dst_addr.version(); + for cidr in self.ip_addrs.iter() { + let addr = cidr.address(); + if addr.version() == v { + return Some(addr); + } + } + None + } + + #[cfg(feature = "proto-ipv4")] + #[allow(unused)] + pub(crate) fn get_source_address_ipv4( + &mut self, + _dst_addr: Ipv4Address, + ) -> Option { + for cidr in self.ip_addrs.iter() { + #[allow(irrefutable_let_patterns)] // if only ipv4 is enabled + if let IpCidr::Ipv4(cidr) = cidr { + return Some(cidr.address()); + } + } + None + } + + #[cfg(feature = "proto-ipv6")] + #[allow(unused)] + pub(crate) fn get_source_address_ipv6( + &mut self, + _dst_addr: Ipv6Address, + ) -> Option { + for cidr in self.ip_addrs.iter() { + #[allow(irrefutable_let_patterns)] // if only ipv6 is enabled + if let IpCidr::Ipv6(cidr) = cidr { + return Some(cidr.address()); + } + } + None + } + + #[cfg(test)] + pub(crate) fn mock() -> Self { + Self { + caps: DeviceCapabilities { + #[cfg(feature = "medium-ethernet")] + medium: crate::phy::Medium::Ethernet, + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ip"))] + medium: crate::phy::Medium::Ip, + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ieee802154"))] + medium: crate::phy::Medium::Ieee802154, + + checksum: crate::phy::ChecksumCapabilities { + #[cfg(feature = "proto-ipv4")] + icmpv4: crate::phy::Checksum::Both, + #[cfg(feature = "proto-ipv6")] + icmpv6: crate::phy::Checksum::Both, + ipv4: crate::phy::Checksum::Both, + tcp: crate::phy::Checksum::Both, + udp: crate::phy::Checksum::Both, + }, + max_burst_size: None, + #[cfg(feature = "medium-ethernet")] + max_transmission_unit: 1514, + #[cfg(not(feature = "medium-ethernet"))] + max_transmission_unit: 1500, + }, + now: Instant::from_millis_const(0), + + ip_addrs: Vec::from_slice(&[ + #[cfg(feature = "proto-ipv4")] + IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::new(192, 168, 1, 1), 24)), + #[cfg(feature = "proto-ipv6")] + IpCidr::Ipv6(Ipv6Cidr::new( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]), + 64, + )), + ]) + .unwrap(), + rand: Rand::new(1234), + routes: Routes::new(), + + #[cfg(feature = "proto-ipv4")] + any_ip: false, + + #[cfg(feature = "medium-ieee802154")] + pan_id: Some(crate::wire::Ieee802154Pan(0xabcd)), + #[cfg(feature = "medium-ieee802154")] + sequence_no: 1, + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + tag: 1, + + #[cfg(feature = "proto-sixlowpan")] + sixlowpan_address_context: Vec::new(), + + #[cfg(feature = "proto-ipv4-fragmentation")] + ipv4_id: 1, + + #[cfg(all( + feature = "medium-ip", + not(feature = "medium-ethernet"), + not(feature = "medium-ieee802154") + ))] + hardware_addr: crate::wire::HardwareAddress::Ip, + + #[cfg(feature = "medium-ethernet")] + hardware_addr: crate::wire::HardwareAddress::Ethernet(crate::wire::EthernetAddress([ + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, + ])), + + #[cfg(all( + not(feature = "medium-ip"), + not(feature = "medium-ethernet"), + feature = "medium-ieee802154" + ))] + hardware_addr: crate::wire::HardwareAddress::Ieee802154( + crate::wire::Ieee802154Address::Extended([ + 0x02, 0x02, 0x02, 0x02, 0x02, 0x02, 0x2, 0x2, + ]), + ), + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + neighbor_cache: NeighborCache::new(), + + #[cfg(feature = "proto-igmp")] + igmp_report_state: IgmpReportState::Inactive, + #[cfg(feature = "proto-igmp")] + ipv4_multicast_groups: LinearMap::new(), + } + } + + #[cfg(test)] + #[allow(unused)] // unused depending on which sockets are enabled + pub(crate) fn set_now(&mut self, now: Instant) { + self.now = now + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + fn check_hardware_addr(addr: &HardwareAddress) { + if !addr.is_unicast() { + panic!("Hardware address {addr} is not unicast") + } + } + + fn check_ip_addrs(addrs: &[IpCidr]) { + for cidr in addrs { + if !cidr.address().is_unicast() && !cidr.address().is_unspecified() { + panic!("IP address {} is not unicast", cidr.address()) + } + } + } + + #[cfg(feature = "medium-ieee802154")] + fn get_sequence_number(&mut self) -> u8 { + let no = self.sequence_no; + self.sequence_no = self.sequence_no.wrapping_add(1); + no + } + + #[cfg(feature = "proto-ipv4-fragmentation")] + fn get_ipv4_ident(&mut self) -> u16 { + let ipv4_id = self.ipv4_id; + self.ipv4_id = self.ipv4_id.wrapping_add(1); + ipv4_id + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn get_sixlowpan_fragment_tag(&mut self) -> u16 { + let tag = self.tag; + self.tag = self.tag.wrapping_add(1); + tag + } + + /// Determine if the given `Ipv6Address` is the solicited node + /// multicast address for a IPv6 addresses assigned to the interface. + /// See [RFC 4291 § 2.7.1] for more details. + /// + /// [RFC 4291 § 2.7.1]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + #[cfg(feature = "proto-ipv6")] + pub fn has_solicited_node(&self, addr: Ipv6Address) -> bool { + self.ip_addrs.iter().any(|cidr| { + match *cidr { + IpCidr::Ipv6(cidr) if cidr.address() != Ipv6Address::LOOPBACK => { + // Take the lower order 24 bits of the IPv6 address and + // append those bits to FF02:0:0:0:0:1:FF00::/104. + addr.as_bytes()[14..] == cidr.address().as_bytes()[14..] + } + _ => false, + } + }) + } + + /// Check whether the interface has the given IP address assigned. + fn has_ip_addr>(&self, addr: T) -> bool { + let addr = addr.into(); + self.ip_addrs.iter().any(|probe| probe.address() == addr) + } + + /// Get the first IPv4 address of the interface. + #[cfg(feature = "proto-ipv4")] + pub fn ipv4_addr(&self) -> Option { + self.ip_addrs.iter().find_map(|addr| match *addr { + IpCidr::Ipv4(cidr) => Some(cidr.address()), + #[allow(unreachable_patterns)] + _ => None, + }) + } + + /// Get the first IPv6 address if present. + #[cfg(feature = "proto-ipv6")] + pub fn ipv6_addr(&self) -> Option { + self.ip_addrs.iter().find_map(|addr| match *addr { + IpCidr::Ipv6(cidr) => Some(cidr.address()), + #[allow(unreachable_patterns)] + _ => None, + }) + } + + #[cfg(not(feature = "proto-igmp"))] + fn has_multicast_group>(&self, addr: T) -> bool { + false + } + + #[cfg(feature = "medium-ip")] + fn process_ip<'frame, T: AsRef<[u8]>>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ip_payload: &'frame T, + frag: &'frame mut FragmentsBuffer, + ) -> Option> { + match IpVersion::of_packet(ip_payload.as_ref()) { + #[cfg(feature = "proto-ipv4")] + Ok(IpVersion::Ipv4) => { + let ipv4_packet = check!(Ipv4Packet::new_checked(ip_payload)); + + self.process_ipv4(sockets, meta, &ipv4_packet, frag) + } + #[cfg(feature = "proto-ipv6")] + Ok(IpVersion::Ipv6) => { + let ipv6_packet = check!(Ipv6Packet::new_checked(ip_payload)); + self.process_ipv6(sockets, meta, &ipv6_packet) + } + // Drop all other traffic. + _ => None, + } + } + + #[cfg(feature = "socket-raw")] + fn raw_socket_filter( + &mut self, + sockets: &mut SocketSet, + ip_repr: &IpRepr, + ip_payload: &[u8], + ) -> bool { + let mut handled_by_raw_socket = false; + + // Pass every IP packet to all raw sockets we have registered. + for raw_socket in sockets + .items_mut() + .filter_map(|i| raw::Socket::downcast_mut(&mut i.socket)) + { + if raw_socket.accepts(ip_repr) { + raw_socket.process(self, ip_repr, ip_payload); + handled_by_raw_socket = true; + } + } + handled_by_raw_socket + } + + /// Checks if an address is broadcast, taking into account ipv4 subnet-local + /// broadcast addresses. + pub(crate) fn is_broadcast(&self, address: &IpAddress) -> bool { + match address { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(address) => self.is_broadcast_v4(*address), + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_) => false, + } + } + + /// Checks if an address is broadcast, taking into account ipv4 subnet-local + /// broadcast addresses. + #[cfg(feature = "proto-ipv4")] + pub(crate) fn is_broadcast_v4(&self, address: Ipv4Address) -> bool { + if address.is_broadcast() { + return true; + } + + self.ip_addrs + .iter() + .filter_map(|own_cidr| match own_cidr { + IpCidr::Ipv4(own_ip) => Some(own_ip.broadcast()?), + #[cfg(feature = "proto-ipv6")] + IpCidr::Ipv6(_) => None, + }) + .any(|broadcast_address| address == broadcast_address) + } + + /// Checks if an ipv4 address is unicast, taking into account subnet broadcast addresses + #[cfg(feature = "proto-ipv4")] + fn is_unicast_v4(&self, address: Ipv4Address) -> bool { + address.is_unicast() && !self.is_broadcast_v4(address) + } + + #[cfg(any(feature = "socket-udp", feature = "socket-dns"))] + #[allow(clippy::too_many_arguments)] + fn process_udp<'frame>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ip_repr: IpRepr, + udp_repr: UdpRepr, + handled_by_raw_socket: bool, + udp_payload: &'frame [u8], + ip_payload: &'frame [u8], + ) -> Option> { + #[cfg(feature = "socket-udp")] + for udp_socket in sockets + .items_mut() + .filter_map(|i| udp::Socket::downcast_mut(&mut i.socket)) + { + if udp_socket.accepts(self, &ip_repr, &udp_repr) { + udp_socket.process(self, meta, &ip_repr, &udp_repr, udp_payload); + return None; + } + } + + #[cfg(feature = "socket-dns")] + for dns_socket in sockets + .items_mut() + .filter_map(|i| dns::Socket::downcast_mut(&mut i.socket)) + { + if dns_socket.accepts(&ip_repr, &udp_repr) { + dns_socket.process(self, &ip_repr, &udp_repr, udp_payload); + return None; + } + } + + // The packet wasn't handled by a socket, send an ICMP port unreachable packet. + match ip_repr { + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(_) if handled_by_raw_socket => None, + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(_) if handled_by_raw_socket => None, + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(ipv4_repr) => { + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV4_MIN_MTU, ipv4_repr.buffer_len()); + let icmpv4_reply_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: ipv4_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv4_reply(ipv4_repr, icmpv4_reply_repr) + } + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(ipv6_repr) => { + let payload_len = + icmp_reply_payload_len(ip_payload.len(), IPV6_MIN_MTU, ipv6_repr.buffer_len()); + let icmpv6_reply_repr = Icmpv6Repr::DstUnreachable { + reason: Icmpv6DstUnreachable::PortUnreachable, + header: ipv6_repr, + data: &ip_payload[0..payload_len], + }; + self.icmpv6_reply(ipv6_repr, icmpv6_reply_repr) + } + } + } + + #[cfg(feature = "socket-tcp")] + pub(crate) fn process_tcp<'frame>( + &mut self, + sockets: &mut SocketSet, + ip_repr: IpRepr, + ip_payload: &'frame [u8], + ) -> Option> { + let (src_addr, dst_addr) = (ip_repr.src_addr(), ip_repr.dst_addr()); + let tcp_packet = check!(TcpPacket::new_checked(ip_payload)); + let tcp_repr = check!(TcpRepr::parse( + &tcp_packet, + &src_addr, + &dst_addr, + &self.caps.checksum + )); + + for tcp_socket in sockets + .items_mut() + .filter_map(|i| tcp::Socket::downcast_mut(&mut i.socket)) + { + if tcp_socket.accepts(self, &ip_repr, &tcp_repr) { + return tcp_socket + .process(self, &ip_repr, &tcp_repr) + .map(IpPacket::Tcp); + } + } + + if tcp_repr.control == TcpControl::Rst { + // Never reply to a TCP RST packet with another TCP RST packet. + None + } else { + // The packet wasn't handled by a socket, send a TCP RST packet. + Some(IpPacket::Tcp(tcp::Socket::rst_reply(&ip_repr, &tcp_repr))) + } + } + + #[cfg(feature = "medium-ethernet")] + fn dispatch( + &mut self, + tx_token: Tx, + packet: EthernetPacket, + frag: &mut Fragmenter, + ) -> Result<(), DispatchError> + where + Tx: TxToken, + { + match packet { + #[cfg(feature = "proto-ipv4")] + EthernetPacket::Arp(arp_repr) => { + let dst_hardware_addr = match arp_repr { + ArpRepr::EthernetIpv4 { + target_hardware_addr, + .. + } => target_hardware_addr, + }; + + self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { + frame.set_dst_addr(dst_hardware_addr); + frame.set_ethertype(EthernetProtocol::Arp); + + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + arp_repr.emit(&mut packet); + }) + } + EthernetPacket::Ip(packet) => { + self.dispatch_ip(tx_token, PacketMeta::default(), packet, frag) + } + } + } + + fn in_same_network(&self, addr: &IpAddress) -> bool { + self.ip_addrs.iter().any(|cidr| cidr.contains_addr(addr)) + } + + fn route(&self, addr: &IpAddress, timestamp: Instant) -> Option { + // Send directly. + // note: no need to use `self.is_broadcast()` to check for subnet-local broadcast addrs + // here because `in_same_network` will already return true. + if self.in_same_network(addr) || addr.is_broadcast() { + return Some(*addr); + } + + // Route via a router. + self.routes.lookup(addr, timestamp) + } + + fn has_neighbor(&self, addr: &IpAddress) -> bool { + match self.route(addr, self.now) { + Some(_routed_addr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => self.neighbor_cache.lookup(&_routed_addr, self.now).found(), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => self.neighbor_cache.lookup(&_routed_addr, self.now).found(), + #[cfg(feature = "medium-ip")] + Medium::Ip => true, + }, + None => false, + } + } + + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + fn lookup_hardware_addr( + &mut self, + tx_token: Tx, + src_addr: &IpAddress, + dst_addr: &IpAddress, + fragmenter: &mut Fragmenter, + ) -> Result<(HardwareAddress, Tx), DispatchError> + where + Tx: TxToken, + { + if self.is_broadcast(dst_addr) { + let hardware_addr = match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress::BROADCAST), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => HardwareAddress::Ieee802154(Ieee802154Address::BROADCAST), + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + }; + + return Ok((hardware_addr, tx_token)); + } + + if dst_addr.is_multicast() { + let b = dst_addr.as_bytes(); + let hardware_addr = match *dst_addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(_addr) => { + HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0x01, + 0x00, + 0x5e, + b[1] & 0x7F, + b[2], + b[3], + ])) + } + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_addr) => match self.caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0x33, 0x33, b[12], b[13], b[14], b[15], + ])), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + // Not sure if this is correct + HardwareAddress::Ieee802154(Ieee802154Address::BROADCAST) + } + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + }, + }; + + return Ok((hardware_addr, tx_token)); + } + + let dst_addr = self + .route(dst_addr, self.now) + .ok_or(DispatchError::NoRoute)?; + + match self.neighbor_cache.lookup(&dst_addr, self.now) { + NeighborAnswer::Found(hardware_addr) => return Ok((hardware_addr, tx_token)), + NeighborAnswer::RateLimited => return Err(DispatchError::NeighborPending), + _ => (), // XXX + } + + match (src_addr, dst_addr) { + #[cfg(feature = "proto-ipv4")] + (&IpAddress::Ipv4(src_addr), IpAddress::Ipv4(dst_addr)) => { + net_debug!( + "address {} not in neighbor cache, sending ARP request", + dst_addr + ); + let src_hardware_addr = self.hardware_addr.ethernet_or_panic(); + + let arp_repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: src_hardware_addr, + source_protocol_addr: src_addr, + target_hardware_addr: EthernetAddress::BROADCAST, + target_protocol_addr: dst_addr, + }; + + if let Err(e) = + self.dispatch_ethernet(tx_token, arp_repr.buffer_len(), |mut frame| { + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_ethertype(EthernetProtocol::Arp); + + arp_repr.emit(&mut ArpPacket::new_unchecked(frame.payload_mut())) + }) + { + net_debug!("Failed to dispatch ARP request: {:?}", e); + return Err(DispatchError::NeighborPending); + } + } + + #[cfg(feature = "proto-ipv6")] + (&IpAddress::Ipv6(src_addr), IpAddress::Ipv6(dst_addr)) => { + net_debug!( + "address {} not in neighbor cache, sending Neighbor Solicitation", + dst_addr + ); + + let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { + target_addr: dst_addr, + lladdr: Some(self.hardware_addr.into()), + }); + + let packet = IpPacket::Icmpv6(( + Ipv6Repr { + src_addr, + dst_addr: dst_addr.solicited_node(), + next_header: IpProtocol::Icmpv6, + payload_len: solicit.buffer_len(), + hop_limit: 0xff, + }, + solicit, + )); + + if let Err(e) = + self.dispatch_ip(tx_token, PacketMeta::default(), packet, fragmenter) + { + net_debug!("Failed to dispatch NDISC solicit: {:?}", e); + return Err(DispatchError::NeighborPending); + } + } + + #[allow(unreachable_patterns)] + _ => (), + } + + // The request got dispatched, limit the rate on the cache. + self.neighbor_cache.limit_rate(self.now); + Err(DispatchError::NeighborPending) + } + + fn flush_cache(&mut self) { + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + self.neighbor_cache.flush() + } + + fn dispatch_ip( + &mut self, + // NOTE(unused_mut): tx_token isn't always mutated, depending on + // the feature set that is used. + #[allow(unused_mut)] mut tx_token: Tx, + meta: PacketMeta, + packet: IpPacket, + frag: &mut Fragmenter, + ) -> Result<(), DispatchError> { + let mut ip_repr = packet.ip_repr(); + assert!(!ip_repr.dst_addr().is_unspecified()); + + // Dispatch IEEE802.15.4: + + #[cfg(feature = "medium-ieee802154")] + if matches!(self.caps.medium, Medium::Ieee802154) { + let (addr, tx_token) = self.lookup_hardware_addr( + tx_token, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + frag, + )?; + let addr = addr.ieee802154_or_panic(); + + self.dispatch_ieee802154(addr, tx_token, meta, packet, frag); + return Ok(()); + } + + // Dispatch IP/Ethernet: + + let caps = self.caps.clone(); + + #[cfg(feature = "proto-ipv4-fragmentation")] + let ipv4_id = self.get_ipv4_ident(); + + // First we calculate the total length that we will have to emit. + let mut total_len = ip_repr.buffer_len(); + + // Add the size of the Ethernet header if the medium is Ethernet. + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + total_len = EthernetFrame::<&[u8]>::buffer_len(total_len); + } + + // If the medium is Ethernet, then we need to retrieve the destination hardware address. + #[cfg(feature = "medium-ethernet")] + let (dst_hardware_addr, mut tx_token) = match self.caps.medium { + Medium::Ethernet => { + match self.lookup_hardware_addr( + tx_token, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + frag, + )? { + (HardwareAddress::Ethernet(addr), tx_token) => (addr, tx_token), + (_, _) => unreachable!(), + } + } + _ => (EthernetAddress([0; 6]), tx_token), + }; + + // Emit function for the Ethernet header. + #[cfg(feature = "medium-ethernet")] + let emit_ethernet = |repr: &IpRepr, tx_buffer: &mut [u8]| { + let mut frame = EthernetFrame::new_unchecked(tx_buffer); + + let src_addr = self.hardware_addr.ethernet_or_panic(); + frame.set_src_addr(src_addr); + frame.set_dst_addr(dst_hardware_addr); + + match repr.version() { + #[cfg(feature = "proto-ipv4")] + IpVersion::Ipv4 => frame.set_ethertype(EthernetProtocol::Ipv4), + #[cfg(feature = "proto-ipv6")] + IpVersion::Ipv6 => frame.set_ethertype(EthernetProtocol::Ipv6), + } + + Ok(()) + }; + + // Emit function for the IP header and payload. + let emit_ip = |repr: &IpRepr, mut tx_buffer: &mut [u8]| { + repr.emit(&mut tx_buffer, &self.caps.checksum); + + let payload = &mut tx_buffer[repr.header_len()..]; + packet.emit_payload(repr, payload, &caps) + }; + + let total_ip_len = ip_repr.buffer_len(); + + match &mut ip_repr { + #[cfg(feature = "proto-ipv4")] + IpRepr::Ipv4(repr) => { + // If we have an IPv4 packet, then we need to check if we need to fragment it. + if total_ip_len > self.caps.max_transmission_unit { + #[cfg(feature = "proto-ipv4-fragmentation")] + { + net_debug!("start fragmentation"); + + // Calculate how much we will send now (including the Ethernet header). + let tx_len = self.caps.max_transmission_unit; + + let ip_header_len = repr.buffer_len(); + let first_frag_ip_len = self.caps.ip_mtu(); + + if frag.buffer.len() < total_ip_len { + net_debug!( + "Fragmentation buffer is too small, at least {} needed. Dropping", + total_ip_len + ); + return Ok(()); + } + + #[cfg(feature = "medium-ethernet")] + { + frag.ipv4.dst_hardware_addr = dst_hardware_addr; + } + + // Save the total packet len (without the Ethernet header, but with the first + // IP header). + frag.packet_len = total_ip_len; + + // Save the IP header for other fragments. + frag.ipv4.repr = *repr; + + // Save how much bytes we will send now. + frag.sent_bytes = first_frag_ip_len; + + // Modify the IP header + repr.payload_len = first_frag_ip_len - repr.buffer_len(); + + // Emit the IP header to the buffer. + emit_ip(&ip_repr, &mut frag.buffer); + + let mut ipv4_packet = Ipv4Packet::new_unchecked(&mut frag.buffer[..]); + frag.ipv4.ident = ipv4_id; + ipv4_packet.set_ident(ipv4_id); + ipv4_packet.set_more_frags(true); + ipv4_packet.set_dont_frag(false); + ipv4_packet.set_frag_offset(0); + + if caps.checksum.ipv4.tx() { + ipv4_packet.fill_checksum(); + } + + // Transmit the first packet. + tx_token.consume(tx_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + // Change the offset for the next packet. + frag.ipv4.frag_offset = (first_frag_ip_len - ip_header_len) as u16; + + // Copy the IP header and the payload. + tx_buffer[..first_frag_ip_len] + .copy_from_slice(&frag.buffer[..first_frag_ip_len]); + + Ok(()) + }) + } + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] + { + net_debug!("Enable the `proto-ipv4-fragmentation` feature for fragmentation support."); + Ok(()) + } + } else { + tx_token.set_meta(meta); + + // No fragmentation is required. + tx_token.consume(total_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + emit_ip(&ip_repr, tx_buffer); + Ok(()) + }) + } + } + // We don't support IPv6 fragmentation yet. + #[cfg(feature = "proto-ipv6")] + IpRepr::Ipv6(_) => tx_token.consume(total_len, |mut tx_buffer| { + #[cfg(feature = "medium-ethernet")] + if matches!(self.caps.medium, Medium::Ethernet) { + emit_ethernet(&ip_repr, tx_buffer)?; + tx_buffer = &mut tx_buffer[EthernetFrame::<&[u8]>::header_len()..]; + } + + emit_ip(&ip_repr, tx_buffer); + Ok(()) + }), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum DispatchError { + /// No route to dispatch this packet. Retrying won't help unless + /// configuration is changed. + NoRoute, + /// We do have a route to dispatch this packet, but we haven't discovered + /// the neighbor for it yet. Discovery has been initiated, dispatch + /// should be retried later. + NeighborPending, +} diff --git a/src/iface/interface/sixlowpan.rs b/src/iface/interface/sixlowpan.rs new file mode 100644 index 000000000..838866911 --- /dev/null +++ b/src/iface/interface/sixlowpan.rs @@ -0,0 +1,528 @@ +use super::*; + +use crate::phy::ChecksumCapabilities; +use crate::wire::*; + +// Max len of non-fragmented packets after decompression (including ipv6 header and payload) +// TODO: lower. Should be (6lowpan mtu) - (min 6lowpan header size) + (max ipv6 header size) +pub(crate) const MAX_DECOMPRESSED_LEN: usize = 1500; + +impl InterfaceInner { + pub(super) fn process_sixlowpan<'output, 'payload: 'output, T: AsRef<[u8]> + ?Sized>( + &mut self, + sockets: &mut SocketSet, + meta: PacketMeta, + ieee802154_repr: &Ieee802154Repr, + payload: &'payload T, + f: &'output mut FragmentsBuffer, + ) -> Option> { + let payload = match check!(SixlowpanPacket::dispatch(payload)) { + #[cfg(not(feature = "proto-sixlowpan-fragmentation"))] + SixlowpanPacket::FragmentHeader => { + net_debug!( + "Fragmentation is not supported, \ + use the `proto-sixlowpan-fragmentation` feature to add support." + ); + return None; + } + #[cfg(feature = "proto-sixlowpan-fragmentation")] + SixlowpanPacket::FragmentHeader => { + match self.process_sixlowpan_fragment(ieee802154_repr, payload, f) { + Some(payload) => payload, + None => return None, + } + } + SixlowpanPacket::IphcHeader => { + match self.decompress_sixlowpan( + ieee802154_repr, + payload.as_ref(), + None, + &mut f.decompress_buf, + ) { + Ok(len) => &f.decompress_buf[..len], + Err(e) => { + net_debug!("sixlowpan decompress failed: {:?}", e); + return None; + } + } + } + }; + + self.process_ipv6(sockets, meta, &check!(Ipv6Packet::new_checked(payload))) + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + fn process_sixlowpan_fragment<'output, 'payload: 'output, T: AsRef<[u8]> + ?Sized>( + &mut self, + ieee802154_repr: &Ieee802154Repr, + payload: &'payload T, + f: &'output mut FragmentsBuffer, + ) -> Option<&'output [u8]> { + use crate::iface::fragmentation::{AssemblerError, AssemblerFullError}; + + // We have a fragment header, which means we cannot process the 6LoWPAN packet, + // unless we have a complete one after processing this fragment. + let frag = check!(SixlowpanFragPacket::new_checked(payload)); + + // The key specifies to which 6LoWPAN fragment it belongs too. + // It is based on the link layer addresses, the tag and the size. + let key = FragKey::Sixlowpan(frag.get_key(ieee802154_repr)); + + // The offset of this fragment in increments of 8 octets. + let offset = frag.datagram_offset() as usize * 8; + + // We reserve a spot in the packet assembler set and add the required + // information to the packet assembler. + // This information is the total size of the packet when it is fully assmbled. + // We also pass the header size, since this is needed when other fragments + // (other than the first one) are added. + let frag_slot = match f.assembler.get(&key, self.now + f.reassembly_timeout) { + Ok(frag) => frag, + Err(AssemblerFullError) => { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + }; + + if frag.is_first_fragment() { + // The first fragment contains the total size of the IPv6 packet. + // However, we received a packet that is compressed following the 6LoWPAN + // standard. This means we need to convert the IPv6 packet size to a 6LoWPAN + // packet size. The packet size can be different because of first the + // compression of the IP header and when UDP is used (because the UDP header + // can also be compressed). Other headers are not compressed by 6LoWPAN. + + // First segment tells us the total size. + let total_size = frag.datagram_size() as usize; + if frag_slot.set_total_size(total_size).is_err() { + net_debug!("No available packet assembler for fragmented packet"); + return None; + } + + // Decompress headers+payload into the assembler. + if let Err(e) = frag_slot.add_with(0, |buffer| { + self.decompress_sixlowpan(ieee802154_repr, frag.payload(), Some(total_size), buffer) + .map_err(|_| AssemblerError) + }) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + } else { + // Add the fragment to the packet assembler. + if let Err(e) = frag_slot.add(frag.payload(), offset) { + net_debug!("fragmentation error: {:?}", e); + return None; + } + } + + match frag_slot.assemble() { + Some(payload) => { + net_trace!("6LoWPAN: fragmented packet now complete"); + Some(payload) + } + None => None, + } + } + + fn decompress_sixlowpan( + &self, + ieee802154_repr: &Ieee802154Repr, + iphc_payload: &[u8], + total_size: Option, + buffer: &mut [u8], + ) -> core::result::Result { + let iphc = SixlowpanIphcPacket::new_checked(iphc_payload)?; + let iphc_repr = SixlowpanIphcRepr::parse( + &iphc, + ieee802154_repr.src_addr, + ieee802154_repr.dst_addr, + &self.sixlowpan_address_context, + )?; + + let mut decompressed_size = 40 + iphc.payload().len(); + + let next_header = match iphc_repr.next_header { + SixlowpanNextHeader::Compressed => { + match SixlowpanNhcPacket::dispatch(iphc.payload())? { + SixlowpanNhcPacket::ExtHeader => { + net_debug!("Extension headers are currently not supported for 6LoWPAN"); + IpProtocol::Unknown(0) + } + SixlowpanNhcPacket::UdpHeader => { + let udp_packet = SixlowpanUdpNhcPacket::new_checked(iphc.payload())?; + let udp_repr = SixlowpanUdpNhcRepr::parse( + &udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + &crate::phy::ChecksumCapabilities::ignored(), + )?; + + decompressed_size += 8; + decompressed_size -= udp_repr.header_len(); + IpProtocol::Udp + } + } + } + SixlowpanNextHeader::Uncompressed(proto) => proto, + }; + + if buffer.len() < decompressed_size { + net_debug!("sixlowpan decompress: buffer too short"); + return Err(crate::wire::Error); + } + let buffer = &mut buffer[..decompressed_size]; + + let total_size = if let Some(size) = total_size { + size + } else { + decompressed_size + }; + + let ipv6_repr = Ipv6Repr { + src_addr: iphc_repr.src_addr, + dst_addr: iphc_repr.dst_addr, + next_header, + payload_len: total_size - 40, + hop_limit: iphc_repr.hop_limit, + }; + + // Emit the decompressed IPHC header (decompressed to an IPv6 header). + let mut ipv6_packet = Ipv6Packet::new_unchecked(&mut buffer[..ipv6_repr.buffer_len()]); + ipv6_repr.emit(&mut ipv6_packet); + let buffer = &mut buffer[ipv6_repr.buffer_len()..]; + + match iphc_repr.next_header { + SixlowpanNextHeader::Compressed => { + match SixlowpanNhcPacket::dispatch(iphc.payload())? { + SixlowpanNhcPacket::ExtHeader => todo!(), + SixlowpanNhcPacket::UdpHeader => { + // We need to uncompress the UDP packet and emit it to the + // buffer. + let udp_packet = SixlowpanUdpNhcPacket::new_checked(iphc.payload())?; + let udp_repr = SixlowpanUdpNhcRepr::parse( + &udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + &ChecksumCapabilities::ignored(), + )?; + + let mut udp = UdpPacket::new_unchecked( + &mut buffer[..udp_repr.0.header_len() + iphc.payload().len() + - udp_repr.header_len()], + ); + udp_repr.0.emit_header(&mut udp, ipv6_repr.payload_len - 8); + + buffer[8..].copy_from_slice(&iphc.payload()[udp_repr.header_len()..]); + } + } + } + SixlowpanNextHeader::Uncompressed(_) => { + // For uncompressed headers we just copy the slice. + let len = iphc.payload().len(); + buffer[..len].copy_from_slice(iphc.payload()); + } + }; + + Ok(decompressed_size) + } + + pub(super) fn dispatch_sixlowpan( + &mut self, + mut tx_token: Tx, + meta: PacketMeta, + packet: IpPacket, + ieee_repr: Ieee802154Repr, + frag: &mut Fragmenter, + ) { + let ip_repr = packet.ip_repr(); + + let (src_addr, dst_addr) = match (ip_repr.src_addr(), ip_repr.dst_addr()) { + (IpAddress::Ipv6(src_addr), IpAddress::Ipv6(dst_addr)) => (src_addr, dst_addr), + #[allow(unreachable_patterns)] + _ => { + unreachable!() + } + }; + + // Create the 6LoWPAN IPHC header. + let iphc_repr = SixlowpanIphcRepr { + src_addr, + ll_src_addr: ieee_repr.src_addr, + dst_addr, + ll_dst_addr: ieee_repr.dst_addr, + next_header: match &packet { + IpPacket::Icmpv6(_) => SixlowpanNextHeader::Uncompressed(IpProtocol::Icmpv6), + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp(_) => SixlowpanNextHeader::Uncompressed(IpProtocol::Tcp), + #[cfg(feature = "socket-udp")] + IpPacket::Udp(_) => SixlowpanNextHeader::Compressed, + #[allow(unreachable_patterns)] + _ => { + net_debug!("dispatch_ieee802154: dropping, unhandled protocol."); + return; + } + }, + hop_limit: ip_repr.hop_limit(), + ecn: None, + dscp: None, + flow_label: None, + }; + + // Now we calculate the total size of the packet. + // We need to know this, such that we know when to do the fragmentation. + let mut total_size = 0; + total_size += iphc_repr.buffer_len(); + let mut _compressed_headers_len = iphc_repr.buffer_len(); + let mut _uncompressed_headers_len = ip_repr.header_len(); + + match packet { + #[cfg(feature = "socket-udp")] + IpPacket::Udp((_, udpv6_repr, payload)) => { + let udp_repr = SixlowpanUdpNhcRepr(udpv6_repr); + _compressed_headers_len += udp_repr.header_len(); + _uncompressed_headers_len += udpv6_repr.header_len(); + total_size += udp_repr.header_len() + payload.len(); + } + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp((_, tcp_repr)) => { + total_size += tcp_repr.buffer_len(); + } + #[cfg(feature = "proto-ipv6")] + IpPacket::Icmpv6((_, icmp_repr)) => { + total_size += icmp_repr.buffer_len(); + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + + let ieee_len = ieee_repr.buffer_len(); + + if total_size + ieee_len > 125 { + #[cfg(feature = "proto-sixlowpan-fragmentation")] + { + // The packet does not fit in one Ieee802154 frame, so we need fragmentation. + // We do this by emitting everything in the `frag.buffer` from the interface. + // After emitting everything into that buffer, we send the first fragment heere. + // When `poll` is called again, we check if frag was fully sent, otherwise we + // call `dispatch_ieee802154_frag`, which will transmit the other fragments. + + // `dispatch_ieee802154_frag` requires some information about the total packet size, + // the link local source and destination address... + let pkt = frag; + + if pkt.buffer.len() < total_size { + net_debug!( + "dispatch_ieee802154: dropping, \ + fragmentation buffer is too small, at least {} needed", + total_size + ); + return; + } + + pkt.sixlowpan.ll_dst_addr = ieee_repr.dst_addr.unwrap(); + pkt.sixlowpan.ll_src_addr = ieee_repr.src_addr.unwrap(); + + let mut iphc_packet = + SixlowpanIphcPacket::new_unchecked(&mut pkt.buffer[..iphc_repr.buffer_len()]); + iphc_repr.emit(&mut iphc_packet); + + let b = &mut pkt.buffer[iphc_repr.buffer_len()..]; + + match packet { + #[cfg(feature = "socket-udp")] + IpPacket::Udp((_, udpv6_repr, payload)) => { + let udp_repr = SixlowpanUdpNhcRepr(udpv6_repr); + let mut udp_packet = SixlowpanUdpNhcPacket::new_unchecked( + &mut b[..udp_repr.header_len() + payload.len()], + ); + udp_repr.emit( + &mut udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + payload.len(), + |buf| buf.copy_from_slice(payload), + ); + } + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp((_, tcp_repr)) => { + let mut tcp_packet = + TcpPacket::new_unchecked(&mut b[..tcp_repr.buffer_len()]); + tcp_repr.emit( + &mut tcp_packet, + &iphc_repr.src_addr.into(), + &iphc_repr.dst_addr.into(), + &self.caps.checksum, + ); + } + #[cfg(feature = "proto-ipv6")] + IpPacket::Icmpv6((_, icmp_repr)) => { + let mut icmp_packet = + Icmpv6Packet::new_unchecked(&mut b[..icmp_repr.buffer_len()]); + icmp_repr.emit( + &iphc_repr.src_addr.into(), + &iphc_repr.dst_addr.into(), + &mut icmp_packet, + &self.caps.checksum, + ); + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + + pkt.packet_len = total_size; + + // The datagram size that we need to set in the first fragment header is equal to the + // IPv6 payload length + 40. + pkt.sixlowpan.datagram_size = (packet.ip_repr().payload_len() + 40) as u16; + + // We generate a random tag. + let tag = self.get_sixlowpan_fragment_tag(); + // We save the tag for the other fragments that will be created when calling `poll` + // multiple times. + pkt.sixlowpan.datagram_tag = tag; + + let frag1 = SixlowpanFragRepr::FirstFragment { + size: pkt.sixlowpan.datagram_size, + tag, + }; + let fragn = SixlowpanFragRepr::Fragment { + size: pkt.sixlowpan.datagram_size, + tag, + offset: 0, + }; + + // We calculate how much data we can send in the first fragment and the other + // fragments. The eventual IPv6 sizes of these fragments need to be a multiple of eight + // (except for the last fragment) since the offset field in the fragment is an offset + // in multiples of 8 octets. This is explained in [RFC 4944 § 5.3]. + // + // [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 + + let header_diff = _uncompressed_headers_len - _compressed_headers_len; + let frag1_size = + (125 - ieee_len - frag1.buffer_len() + header_diff) / 8 * 8 - (header_diff); + + pkt.sixlowpan.fragn_size = (125 - ieee_len - fragn.buffer_len()) / 8 * 8; + + pkt.sent_bytes = frag1_size; + pkt.sixlowpan.datagram_offset = frag1_size + header_diff; + + tx_token.consume(ieee_len + frag1.buffer_len() + frag1_size, |mut tx_buf| { + // Add the IEEE header. + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + // Add the first fragment header + let mut frag1_packet = SixlowpanFragPacket::new_unchecked(&mut tx_buf); + frag1.emit(&mut frag1_packet); + tx_buf = &mut tx_buf[frag1.buffer_len()..]; + + // Add the buffer part. + tx_buf[..frag1_size].copy_from_slice(&pkt.buffer[..frag1_size]); + }); + } + + #[cfg(not(feature = "proto-sixlowpan-fragmentation"))] + { + net_debug!( + "Enable the `proto-sixlowpan-fragmentation` feature for fragmentation support." + ); + return; + } + } else { + tx_token.set_meta(meta); + + // We don't need fragmentation, so we emit everything to the TX token. + tx_token.consume(total_size + ieee_len, |mut tx_buf| { + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + let mut iphc_packet = + SixlowpanIphcPacket::new_unchecked(&mut tx_buf[..iphc_repr.buffer_len()]); + iphc_repr.emit(&mut iphc_packet); + tx_buf = &mut tx_buf[iphc_repr.buffer_len()..]; + + match packet { + #[cfg(feature = "socket-udp")] + IpPacket::Udp((_, udpv6_repr, payload)) => { + let udp_repr = SixlowpanUdpNhcRepr(udpv6_repr); + let mut udp_packet = SixlowpanUdpNhcPacket::new_unchecked( + &mut tx_buf[..udp_repr.header_len() + payload.len()], + ); + udp_repr.emit( + &mut udp_packet, + &iphc_repr.src_addr, + &iphc_repr.dst_addr, + payload.len(), + |buf| buf.copy_from_slice(payload), + ); + } + #[cfg(feature = "socket-tcp")] + IpPacket::Tcp((_, tcp_repr)) => { + let mut tcp_packet = + TcpPacket::new_unchecked(&mut tx_buf[..tcp_repr.buffer_len()]); + tcp_repr.emit( + &mut tcp_packet, + &iphc_repr.src_addr.into(), + &iphc_repr.dst_addr.into(), + &self.caps.checksum, + ); + } + #[cfg(feature = "proto-ipv6")] + IpPacket::Icmpv6((_, icmp_repr)) => { + let mut icmp_packet = + Icmpv6Packet::new_unchecked(&mut tx_buf[..icmp_repr.buffer_len()]); + icmp_repr.emit( + &iphc_repr.src_addr.into(), + &iphc_repr.dst_addr.into(), + &mut icmp_packet, + &self.caps.checksum, + ); + } + #[allow(unreachable_patterns)] + _ => unreachable!(), + } + }); + } + } + + #[cfg(feature = "proto-sixlowpan-fragmentation")] + pub(super) fn dispatch_sixlowpan_frag( + &mut self, + tx_token: Tx, + ieee_repr: Ieee802154Repr, + frag: &mut Fragmenter, + ) { + // Create the FRAG_N header. + let fragn = SixlowpanFragRepr::Fragment { + size: frag.sixlowpan.datagram_size, + tag: frag.sixlowpan.datagram_tag, + offset: (frag.sixlowpan.datagram_offset / 8) as u8, + }; + + let ieee_len = ieee_repr.buffer_len(); + let frag_size = (frag.packet_len - frag.sent_bytes).min(frag.sixlowpan.fragn_size); + + tx_token.consume( + ieee_repr.buffer_len() + fragn.buffer_len() + frag_size, + |mut tx_buf| { + let mut ieee_packet = Ieee802154Frame::new_unchecked(&mut tx_buf[..ieee_len]); + ieee_repr.emit(&mut ieee_packet); + tx_buf = &mut tx_buf[ieee_len..]; + + let mut frag_packet = + SixlowpanFragPacket::new_unchecked(&mut tx_buf[..fragn.buffer_len()]); + fragn.emit(&mut frag_packet); + tx_buf = &mut tx_buf[fragn.buffer_len()..]; + + // Add the buffer part + tx_buf[..frag_size].copy_from_slice(&frag.buffer[frag.sent_bytes..][..frag_size]); + + frag.sent_bytes += frag_size; + frag.sixlowpan.datagram_offset += frag_size; + }, + ); + } +} diff --git a/src/iface/interface/tests/ipv4.rs b/src/iface/interface/tests/ipv4.rs new file mode 100644 index 000000000..d109dc4f3 --- /dev/null +++ b/src/iface/interface/tests/ipv4.rs @@ -0,0 +1,961 @@ +use super::*; + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_no_icmp_no_unicast(#[case] medium: Medium) { + let (mut iface, mut sockets, _) = setup(medium); + + // Unknown Ipv4 Protocol + // + // Because the destination is the broadcast address + // this should not trigger and Destination Unreachable + // response. See RFC 1122 § 3.2.2. + let repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + hop_limit: 0x40, + }); + + let mut bytes = vec![0u8; 54]; + repr.emit(&mut bytes, &ChecksumCapabilities::default()); + let frame = Ipv4Packet::new_unchecked(&bytes); + + // Ensure that the unknown protocol frame does not trigger an + // ICMP error response when the destination address is a + // broadcast address + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_icmp_error_no_payload(#[case] medium: Medium) { + static NO_BYTES: [u8; 0] = []; + let (mut iface, mut sockets, _device) = setup(medium); + + // Unknown Ipv4 Protocol with no payload + let repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + hop_limit: 0x40, + }); + + let mut bytes = vec![0u8; 34]; + repr.emit(&mut bytes, &ChecksumCapabilities::default()); + let frame = Ipv4Packet::new_unchecked(&bytes); + + // The expected Destination Unreachable response due to the + // unknown protocol + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::ProtoUnreachable, + header: Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Unknown(12), + payload_len: 0, + hop_limit: 64, + }, + data: &NO_BYTES, + }; + + let expected_repr = IpPacket::Icmpv4(( + Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }, + icmp_repr, + )); + + // Ensure that the unknown protocol triggers an error response. + // And we correctly handle no payload. + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + Some(expected_repr) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_local_subnet_broadcasts(#[case] medium: Medium) { + let (mut iface, _, _device) = setup(medium); + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 23]), 24)); + }); + }); + + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(iface.inner.is_broadcast_v4(Ipv4Address([192, 168, 1, 255]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 168, 1, 254]))); + + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 23, 24]), 16)); + }); + }); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 23, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 23, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 255, 254]))); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([192, 168, 255, 255]))); + + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address([192, 168, 23, 24]), 8)); + }); + }); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 255]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([255, 255, 255, 254]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 23, 1, 255]))); + assert!(!iface.inner.is_broadcast_v4(Ipv4Address([192, 23, 1, 254]))); + assert!(!iface + .inner + .is_broadcast_v4(Ipv4Address([192, 255, 255, 254]))); + assert!(iface + .inner + .is_broadcast_v4(Ipv4Address([192, 255, 255, 255]))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "medium-ip", feature = "socket-udp"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "medium-ethernet", feature = "socket-udp"))] +fn test_icmp_error_port_unreachable(#[case] medium: Medium) { + static UDP_PAYLOAD: [u8; 12] = [ + 0x48, 0x65, 0x6c, 0x6c, 0x6f, 0x2c, 0x20, 0x57, 0x6f, 0x6c, 0x64, 0x21, + ]; + let (mut iface, mut sockets, _device) = setup(medium); + + let mut udp_bytes_unicast = vec![0u8; 20]; + let mut udp_bytes_broadcast = vec![0u8; 20]; + let mut packet_unicast = UdpPacket::new_unchecked(&mut udp_bytes_unicast); + let mut packet_broadcast = UdpPacket::new_unchecked(&mut udp_bytes_broadcast); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }); + + // Emit the representations to a packet + udp_repr.emit( + &mut packet_unicast, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + let data = packet_unicast.into_inner(); + + // The expected Destination Unreachable ICMPv4 error response due + // to no sockets listening on the destination port. + let icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }, + data, + }; + let expected_repr = IpPacket::Icmpv4(( + Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + dst_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + next_header: IpProtocol::Icmp, + payload_len: icmp_repr.buffer_len(), + hop_limit: 64, + }, + icmp_repr, + )); + + // Ensure that the unknown protocol triggers an error response. + // And we correctly handle no payload. + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + data + ), + Some(expected_repr) + ); + + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x02]), + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 64, + }); + + // Emit the representations to a packet + udp_repr.emit( + &mut packet_broadcast, + &ip_repr.src_addr(), + &IpAddress::Ipv4(Ipv4Address::BROADCAST), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + // Ensure that the port unreachable error does not trigger an + // ICMP error response when the destination address is a + // broadcast address and no socket is bound to the port. + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + packet_broadcast.into_inner(), + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_ipv4_broadcast(#[case] medium: Medium) { + use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Packet}; + + let (mut iface, mut sockets, _device) = setup(medium); + + let our_ipv4_addr = iface.ipv4_addr().unwrap(); + let src_ipv4_addr = Ipv4Address([127, 0, 0, 2]); + + // ICMPv4 echo request + let icmpv4_data: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + let icmpv4_repr = Icmpv4Repr::EchoRequest { + ident: 0x1234, + seq_no: 0xabcd, + data: &icmpv4_data, + }; + + // Send to IPv4 broadcast address + let ipv4_repr = Ipv4Repr { + src_addr: src_ipv4_addr, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: icmpv4_repr.buffer_len(), + }; + + // Emit to ip frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + icmpv4_repr.buffer_len()]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes), + &ChecksumCapabilities::default(), + ); + icmpv4_repr.emit( + &mut Icmpv4Packet::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes) + }; + + // Expected ICMPv4 echo reply + let expected_icmpv4_repr = Icmpv4Repr::EchoReply { + ident: 0x1234, + seq_no: 0xabcd, + data: &icmpv4_data, + }; + let expected_ipv4_repr = Ipv4Repr { + src_addr: our_ipv4_addr, + dst_addr: src_ipv4_addr, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: expected_icmpv4_repr.buffer_len(), + }; + let expected_packet = IpPacket::Icmpv4((expected_ipv4_repr, expected_icmpv4_repr)); + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + Some(expected_packet) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_valid_arp_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let local_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let local_hw_addr = EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: local_ip_addr, + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + + // Ensure an ARP Request for us triggers an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: local_hw_addr, + source_protocol_addr: local_ip_addr, + target_hardware_addr: remote_hw_addr, + target_protocol_addr: remote_ip_addr + })) + ); + + // Ensure the address of the requestor was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(local_ip_addr), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_other_arp_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x03]), + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + + // Ensure an ARP Request for someone else does not trigger an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + None + ); + + // Ensure the address of the requestor was NOT entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(Ipv4Address([0x7f, 0x00, 0x00, 0x01])), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Err(DispatchError::NeighborPending) + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_arp_flush_after_update_ip(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 42]; + + let local_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + let remote_ip_addr = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); + let local_hw_addr = EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let repr = ArpRepr::EthernetIpv4 { + operation: ArpOperation::Request, + source_hardware_addr: remote_hw_addr, + source_protocol_addr: remote_ip_addr, + target_hardware_addr: EthernetAddress::default(), + target_protocol_addr: Ipv4Address([0x7f, 0x00, 0x00, 0x01]), + }; + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress::BROADCAST); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Arp); + { + let mut packet = ArpPacket::new_unchecked(frame.payload_mut()); + repr.emit(&mut packet); + } + + // Ensure an ARP Request for us triggers an ARP Reply + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Arp(ArpRepr::EthernetIpv4 { + operation: ArpOperation::Reply, + source_hardware_addr: local_hw_addr, + source_protocol_addr: local_ip_addr, + target_hardware_addr: remote_hw_addr, + target_protocol_addr: remote_ip_addr + })) + ); + + // Ensure the address of the requestor was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv4(local_ip_addr), + &IpAddress::Ipv4(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); + + // Update IP addrs to trigger ARP cache flush + let local_ip_addr_new = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); + iface.update_ip_addrs(|addrs| { + addrs.iter_mut().next().map(|addr| { + *addr = IpCidr::Ipv4(Ipv4Cidr::new(local_ip_addr_new, 24)); + }); + }); + + // ARP cache flush after address change + assert!(!iface.inner.has_neighbor(&IpAddress::Ipv4(remote_ip_addr))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-icmp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-icmp", feature = "medium-ethernet"))] +fn test_icmpv4_socket(#[case] medium: Medium) { + use crate::wire::Icmpv4Packet; + + let (mut iface, mut sockets, _device) = setup(medium); + + let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 24]); + let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 24]); + + let icmpv4_socket = icmp::Socket::new(rx_buffer, tx_buffer); + + let socket_handle = sockets.add(icmpv4_socket); + + let ident = 0x1234; + let seq_no = 0x5432; + let echo_data = &[0xff; 16]; + + let socket = sockets.get_mut::(socket_handle); + // Bind to the ID 0x1234 + assert_eq!(socket.bind(icmp::Endpoint::Ident(ident)), Ok(())); + + // Ensure the ident we bound to and the ident of the packet are the same. + let mut bytes = [0xff; 24]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); + let echo_repr = Icmpv4Repr::EchoRequest { + ident, + seq_no, + data: echo_data, + }; + echo_repr.emit(&mut packet, &ChecksumCapabilities::default()); + let icmp_data = &*packet.into_inner(); + + let ipv4_repr = Ipv4Repr { + src_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x02), + dst_addr: Ipv4Address::new(0x7f, 0x00, 0x00, 0x01), + next_header: IpProtocol::Icmp, + payload_len: 24, + hop_limit: 64, + }; + let ip_repr = IpRepr::Ipv4(ipv4_repr); + + // Open a socket and ensure the packet is handled due to the listening + // socket. + assert!(!sockets.get_mut::(socket_handle).can_recv()); + + // Confirm we still get EchoReply from `smoltcp` even with the ICMP socket listening + let echo_reply = Icmpv4Repr::EchoReply { + ident, + seq_no, + data: echo_data, + }; + let ipv4_reply = Ipv4Repr { + src_addr: ipv4_repr.dst_addr, + dst_addr: ipv4_repr.src_addr, + ..ipv4_repr + }; + assert_eq!( + iface.inner.process_icmpv4(&mut sockets, ip_repr, icmp_data), + Some(IpPacket::Icmpv4((ipv4_reply, echo_reply))) + ); + + let socket = sockets.get_mut::(socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok(( + icmp_data, + IpAddress::Ipv4(Ipv4Address::new(0x7f, 0x00, 0x00, 0x02)) + )) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "proto-igmp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "proto-igmp", feature = "medium-ethernet"))] +fn test_handle_igmp(#[case] medium: Medium) { + fn recv_igmp(device: &mut Loopback, timestamp: Instant) -> Vec<(Ipv4Repr, IgmpRepr)> { + let caps = device.capabilities(); + let checksum_caps = &caps.checksum; + recv_all(device, timestamp) + .iter() + .filter_map(|frame| { + let ipv4_packet = match caps.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + let eth_frame = EthernetFrame::new_checked(frame).ok()?; + Ipv4Packet::new_checked(eth_frame.payload()).ok()? + } + #[cfg(feature = "medium-ip")] + Medium::Ip => Ipv4Packet::new_checked(&frame[..]).ok()?, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + let ipv4_repr = Ipv4Repr::parse(&ipv4_packet, checksum_caps).ok()?; + let ip_payload = ipv4_packet.payload(); + let igmp_packet = IgmpPacket::new_checked(ip_payload).ok()?; + let igmp_repr = IgmpRepr::parse(&igmp_packet).ok()?; + Some((ipv4_repr, igmp_repr)) + }) + .collect::>() + } + + let groups = [ + Ipv4Address::new(224, 0, 0, 22), + Ipv4Address::new(224, 0, 0, 56), + ]; + + let (mut iface, mut sockets, mut device) = setup(medium); + + // Join multicast groups + let timestamp = Instant::now(); + for group in &groups { + iface + .join_multicast_group(&mut device, *group, timestamp) + .unwrap(); + } + + let reports = recv_igmp(&mut device, timestamp); + assert_eq!(reports.len(), 2); + for (i, group_addr) in groups.iter().enumerate() { + assert_eq!(reports[i].0.next_header, IpProtocol::Igmp); + assert_eq!(reports[i].0.dst_addr, *group_addr); + assert_eq!( + reports[i].1, + IgmpRepr::MembershipReport { + group_addr: *group_addr, + version: IgmpVersion::Version2, + } + ); + } + + // General query + let timestamp = Instant::now(); + const GENERAL_QUERY_BYTES: &[u8] = &[ + 0x46, 0xc0, 0x00, 0x24, 0xed, 0xb4, 0x00, 0x00, 0x01, 0x02, 0x47, 0x43, 0xac, 0x16, 0x63, + 0x04, 0xe0, 0x00, 0x00, 0x01, 0x94, 0x04, 0x00, 0x00, 0x11, 0x64, 0xec, 0x8f, 0x00, 0x00, + 0x00, 0x00, 0x02, 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, + ]; + { + // Transmit GENERAL_QUERY_BYTES into loopback + let tx_token = device.transmit(timestamp).unwrap(); + tx_token.consume(GENERAL_QUERY_BYTES.len(), |buffer| { + buffer.copy_from_slice(GENERAL_QUERY_BYTES); + }); + } + // Trigger processing until all packets received through the + // loopback have been processed, including responses to + // GENERAL_QUERY_BYTES. Therefore `recv_all()` would return 0 + // pkts that could be checked. + iface.socket_ingress(&mut device, &mut sockets); + + // Leave multicast groups + let timestamp = Instant::now(); + for group in &groups { + iface + .leave_multicast_group(&mut device, *group, timestamp) + .unwrap(); + } + + let leaves = recv_igmp(&mut device, timestamp); + assert_eq!(leaves.len(), 2); + for (i, group_addr) in groups.iter().cloned().enumerate() { + assert_eq!(leaves[i].0.next_header, IpProtocol::Igmp); + assert_eq!(leaves[i].0.dst_addr, Ipv4Address::MULTICAST_ALL_ROUTERS); + assert_eq!(leaves[i].1, IgmpRepr::LeaveGroup { group_addr }); + } +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-raw", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-raw", feature = "medium-ethernet"))] +fn test_raw_socket_no_reply(#[case] medium: Medium) { + use crate::wire::{IpVersion, Ipv4Packet, UdpPacket, UdpRepr}; + + let (mut iface, mut sockets, _) = setup(medium); + + let packets = 1; + let rx_buffer = + raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; packets], vec![0; 48 * 1]); + let tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; 48 * packets], + ); + let raw_socket = raw::Socket::new(IpVersion::Ipv4, IpProtocol::Udp, rx_buffer, tx_buffer); + sockets.add(raw_socket); + + let src_addr = Ipv4Address([127, 0, 0, 2]); + let dst_addr = Ipv4Address([127, 0, 0, 1]); + + const PAYLOAD_LEN: usize = 10; + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + let ipv4_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + PAYLOAD_LEN, + }; + + // Emit to frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + udp_repr.header_len() + PAYLOAD_LEN]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes), + &ChecksumCapabilities::default(), + ); + udp_repr.emit( + &mut UdpPacket::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &src_addr.into(), + &dst_addr.into(), + PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes) + }; + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-raw", feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all( + feature = "socket-raw", + feature = "socket-udp", + feature = "medium-ethernet" +))] +fn test_raw_socket_with_udp_socket(#[case] medium: Medium) { + use crate::wire::{IpEndpoint, IpVersion, Ipv4Packet, UdpPacket, UdpRepr}; + + static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; + + let (mut iface, mut sockets, _) = setup(medium); + + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + let udp_socket_handle = sockets.add(udp_socket); + + // Bind the socket to port 68 + let socket = sockets.get_mut::(udp_socket_handle); + assert_eq!(socket.bind(68), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + + let packets = 1; + let raw_rx_buffer = + raw::PacketBuffer::new(vec![raw::PacketMetadata::EMPTY; packets], vec![0; 48 * 1]); + let raw_tx_buffer = raw::PacketBuffer::new( + vec![raw::PacketMetadata::EMPTY; packets], + vec![0; 48 * packets], + ); + let raw_socket = raw::Socket::new( + IpVersion::Ipv4, + IpProtocol::Udp, + raw_rx_buffer, + raw_tx_buffer, + ); + sockets.add(raw_socket); + + let src_addr = Ipv4Address([127, 0, 0, 2]); + let dst_addr = Ipv4Address([127, 0, 0, 1]); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + UDP_PAYLOAD.len()]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + let ipv4_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + }; + + // Emit to frame + let mut bytes = vec![0u8; ipv4_repr.buffer_len() + udp_repr.header_len() + UDP_PAYLOAD.len()]; + let frame = { + ipv4_repr.emit( + &mut Ipv4Packet::new_unchecked(&mut bytes), + &ChecksumCapabilities::default(), + ); + udp_repr.emit( + &mut UdpPacket::new_unchecked(&mut bytes[ipv4_repr.buffer_len()..]), + &src_addr.into(), + &dst_addr.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + Ipv4Packet::new_unchecked(&bytes) + }; + + assert_eq!( + iface.inner.process_ipv4( + &mut sockets, + PacketMeta::default(), + &frame, + &mut iface.fragments + ), + None + ); + + // Make sure the UDP socket can still receive in presence of a Raw socket that handles UDP + let socket = sockets.get_mut::(udp_socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok(( + &UDP_PAYLOAD[..], + IpEndpoint::new(src_addr.into(), 67).into() + )) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-udp", feature = "medium-ethernet"))] +fn test_icmp_reply_size(#[case] medium: Medium) { + use crate::wire::IPV4_MIN_MTU as MIN_MTU; + const MAX_PAYLOAD_LEN: usize = 528; + + let (mut iface, mut sockets, _device) = setup(medium); + + let src_addr = Ipv4Address([192, 168, 1, 1]); + let dst_addr = Ipv4Address([192, 168, 1, 2]); + + // UDP packet that if not tructated will cause a icmp port unreachable reply + // to exeed the minimum mtu bytes in length. + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + MAX_PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + MAX_PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + + let ip_repr = Ipv4Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + MAX_PAYLOAD_LEN, + }; + let payload = packet.into_inner(); + + let expected_icmp_repr = Icmpv4Repr::DstUnreachable { + reason: Icmpv4DstUnreachable::PortUnreachable, + header: ip_repr, + data: &payload[..MAX_PAYLOAD_LEN], + }; + + let expected_ip_repr = Ipv4Repr { + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmp, + hop_limit: 64, + payload_len: expected_icmp_repr.buffer_len(), + }; + + assert_eq!( + expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), + MIN_MTU + ); + + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr.into(), + udp_repr, + false, + &vec![0x2a; MAX_PAYLOAD_LEN], + payload, + ), + Some(IpPacket::Icmpv4((expected_ip_repr, expected_icmp_repr))) + ); +} diff --git a/src/iface/interface/tests/ipv6.rs b/src/iface/interface/tests/ipv6.rs new file mode 100644 index 000000000..920d6da3b --- /dev/null +++ b/src/iface/interface/tests/ipv6.rs @@ -0,0 +1,732 @@ +use super::*; + +fn parse_ipv6(data: &[u8]) -> crate::wire::Result> { + let ipv6_header = Ipv6Packet::new_checked(data)?; + let ipv6 = Ipv6Repr::parse(&ipv6_header)?; + + match ipv6.next_header { + IpProtocol::HopByHop => todo!(), + IpProtocol::Icmp => todo!(), + IpProtocol::Igmp => todo!(), + IpProtocol::Tcp => todo!(), + IpProtocol::Udp => todo!(), + IpProtocol::Ipv6Route => todo!(), + IpProtocol::Ipv6Frag => todo!(), + IpProtocol::Icmpv6 => { + let icmp = Icmpv6Repr::parse( + &ipv6.src_addr.into(), + &ipv6.dst_addr.into(), + &Icmpv6Packet::new_checked(ipv6_header.payload())?, + &Default::default(), + )?; + Ok(IpPacket::Icmpv6((ipv6, icmp))) + } + IpProtocol::Ipv6NoNxt => todo!(), + IpProtocol::Ipv6Opts => todo!(), + IpProtocol::Unknown(_) => todo!(), + } +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn multicast_source_address(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xff, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x1, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn hop_by_hop_skip_with_icmp(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (skipped) + // - ICMP echo request + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0x1, 0x0, 0xf, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + Icmpv6Repr::EchoReply { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + }, + ))); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn hop_by_hop_discard_with_icmp(#[case] medium: Medium) { + // The following contains: + // - IPv6 header + // - Hop-by-hop, with options: + // - PADN (skipped) + // - Unknown option (discard) + // - ICMP echo request + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x1b, 0x0, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x3a, 0x0, 0x1, 0x0, 0x40, 0x0, 0x1, 0x0, 0x80, 0x0, 0x2c, 0x88, + 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn imcp_empty_echo_request(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x8, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x84, 0x3c, 0x0, 0x0, 0x0, 0x0, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 8, + }, + Icmpv6Repr::EchoRequest { + ident: 0, + seq_no: 0, + data: b"", + } + ))) + ); + + let response = Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 8, + }, + Icmpv6Repr::EchoReply { + ident: 0, + seq_no: 0, + data: b"", + }, + ))); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_request(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x13, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x80, 0x0, 0x2c, 0x88, 0x0, 0x2a, 0x1, 0xa4, 0x4c, 0x6f, 0x72, + 0x65, 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + Icmpv6Repr::EchoRequest { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + } + ))) + ); + + let response = Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + Icmpv6Repr::EchoReply { + ident: 42, + seq_no: 420, + data: b"Lorem Ipsum", + }, + ))); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_reply_as_input(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x13, 0x3a, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x81, 0x0, 0x2d, 0x56, 0x0, 0x0, 0x0, 0x0, 0x4c, 0x6f, 0x72, 0x65, + 0x6d, 0x20, 0x49, 0x70, 0x73, 0x75, 0x6d, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 19, + }, + Icmpv6Repr::EchoReply { + ident: 0, + seq_no: 0, + data: b"Lorem Ipsum", + } + ))) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn unknown_proto_with_multicast_dst_address(#[case] medium: Medium) { + // Since the destination address is multicast, we should not answer with an ICMPv6 message. + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xff, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ip(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn unknown_proto(#[case] medium: Medium) { + // Since the destination address is multicast, we should not answer with an ICMPv6 message. + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x0, 0xc, 0x40, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, + ]; + + let response = Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 48, + }, + Icmpv6Repr::ParamProblem { + reason: Icmpv6ParamProblem::UnrecognizedNxtHdr, + pointer: 40, + header: Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 64, + next_header: IpProtocol::Unknown(0x0c), + payload_len: 0, + }, + data: &[], + }, + ))); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); +} + +#[rstest] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn ndsic_neighbor_advertisement_ethernet(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x20, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0x9f, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x1, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 32, + }, + Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[0, 0, 0, 0, 0, 1])), + }) + ))) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::Found(HardwareAddress::Ethernet(EthernetAddress::from_bytes(&[ + 0, 0, 0, 0, 0, 1 + ]))), + ); +} + +#[rstest] +#[case::ethernet(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn ndsic_neighbor_advertisement_ethernet_multicast_addr(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x20, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0xa0, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x1, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 32, + }, + Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[ + 0xff, 0xff, 0xff, 0xff, 0xff, 0xff + ])), + }) + ))) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::NotFound, + ); +} + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn ndsic_neighbor_advertisement_ieee802154(#[case] medium: Medium) { + let data = [ + 0x60, 0x0, 0x0, 0x0, 0x0, 0x28, 0x3a, 0xff, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, 0xbe, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x1, 0x88, 0x0, 0x3b, 0x96, 0x40, 0x0, 0x0, 0x0, 0xfe, 0x80, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0x2, 0x2, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + assert_eq!( + parse_ipv6(&data), + Ok(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002]), + dst_addr: Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0001]), + hop_limit: 255, + next_header: IpProtocol::Icmpv6, + payload_len: 40, + }, + Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 0x0002]), + lladdr: Some(RawHardwareAddress::from_bytes(&[0, 0, 0, 0, 0, 0, 0, 1])), + }) + ))) + ); + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ipv6( + &mut sockets, + PacketMeta::default(), + &Ipv6Packet::new_checked(&data).unwrap() + ), + response + ); + + assert_eq!( + iface.inner.neighbor_cache.lookup( + &IpAddress::Ipv6(Ipv6Address::from_parts(&[0xfdbe, 0, 0, 0, 0, 0, 0, 0x0002])), + iface.inner.now, + ), + NeighborAnswer::Found(HardwareAddress::Ieee802154(Ieee802154Address::from_bytes( + &[0, 0, 0, 0, 0, 0, 0, 1] + ))), + ); +} + +#[rstest] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +fn test_handle_valid_ndisc_request(#[case] medium: Medium) { + let (mut iface, mut sockets, _device) = setup(medium); + + let mut eth_bytes = vec![0u8; 86]; + + let local_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 1); + let remote_ip_addr = Ipv6Address::new(0xfdbe, 0, 0, 0, 0, 0, 0, 2); + let local_hw_addr = EthernetAddress([0x00, 0x00, 0x00, 0x00, 0x00, 0x00]); + let remote_hw_addr = EthernetAddress([0x52, 0x54, 0x00, 0x00, 0x00, 0x00]); + + let solicit = Icmpv6Repr::Ndisc(NdiscRepr::NeighborSolicit { + target_addr: local_ip_addr, + lladdr: Some(remote_hw_addr.into()), + }); + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: remote_ip_addr, + dst_addr: local_ip_addr.solicited_node(), + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: solicit.buffer_len(), + }); + + let mut frame = EthernetFrame::new_unchecked(&mut eth_bytes); + frame.set_dst_addr(EthernetAddress([0x33, 0x33, 0x00, 0x00, 0x00, 0x00])); + frame.set_src_addr(remote_hw_addr); + frame.set_ethertype(EthernetProtocol::Ipv6); + ip_repr.emit(frame.payload_mut(), &ChecksumCapabilities::default()); + solicit.emit( + &remote_ip_addr.into(), + &local_ip_addr.solicited_node().into(), + &mut Icmpv6Packet::new_unchecked(&mut frame.payload_mut()[ip_repr.header_len()..]), + &ChecksumCapabilities::default(), + ); + + let icmpv6_expected = Icmpv6Repr::Ndisc(NdiscRepr::NeighborAdvert { + flags: NdiscNeighborFlags::SOLICITED, + target_addr: local_ip_addr, + lladdr: Some(local_hw_addr.into()), + }); + + let ipv6_expected = Ipv6Repr { + src_addr: local_ip_addr, + dst_addr: remote_ip_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 0xff, + payload_len: icmpv6_expected.buffer_len(), + }; + + // Ensure an Neighbor Solicitation triggers a Neighbor Advertisement + assert_eq!( + iface.inner.process_ethernet( + &mut sockets, + PacketMeta::default(), + frame.into_inner(), + &mut iface.fragments + ), + Some(EthernetPacket::Ip(IpPacket::Icmpv6(( + ipv6_expected, + icmpv6_expected + )))) + ); + + // Ensure the address of the requestor was entered in the cache + assert_eq!( + iface.inner.lookup_hardware_addr( + MockTxToken, + &IpAddress::Ipv6(local_ip_addr), + &IpAddress::Ipv6(remote_ip_addr), + &mut iface.fragmenter, + ), + Ok((HardwareAddress::Ethernet(remote_hw_addr), MockTxToken)) + ); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(feature = "medium-ip")] +#[case(Medium::Ethernet)] +#[cfg(feature = "medium-ethernet")] +#[case(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn test_solicited_node_addrs(#[case] medium: Medium) { + let (mut iface, _, _) = setup(medium); + let mut new_addrs = heapless::Vec::::new(); + new_addrs + .push(IpCidr::new(IpAddress::v6(0xfe80, 0, 0, 0, 1, 2, 0, 2), 64)) + .unwrap(); + new_addrs + .push(IpCidr::new( + IpAddress::v6(0xfe80, 0, 0, 0, 3, 4, 0, 0xffff), + 64, + )) + .unwrap(); + iface.update_ip_addrs(|addrs| { + new_addrs.extend(addrs.to_vec()); + *addrs = new_addrs; + }); + assert!(iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0002))); + assert!(iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0xffff))); + assert!(!iface + .inner + .has_solicited_node(Ipv6Address::new(0xff02, 0, 0, 0, 0, 1, 0xff00, 0x0003))); +} + +#[rstest] +#[case(Medium::Ip)] +#[cfg(all(feature = "socket-udp", feature = "medium-ip"))] +#[case(Medium::Ethernet)] +#[cfg(all(feature = "socket-udp", feature = "medium-ethernet"))] +#[case(Medium::Ieee802154)] +#[cfg(all(feature = "socket-udp", feature = "medium-ieee802154"))] +fn test_icmp_reply_size(#[case] medium: Medium) { + use crate::wire::Icmpv6DstUnreachable; + use crate::wire::IPV6_MIN_MTU as MIN_MTU; + const MAX_PAYLOAD_LEN: usize = 1192; + + let (mut iface, mut sockets, _device) = setup(medium); + + let src_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + let dst_addr = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 2); + + // UDP packet that if not tructated will cause a icmp port unreachable reply + // to exeed the minimum mtu bytes in length. + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + let mut bytes = vec![0xff; udp_repr.header_len() + MAX_PAYLOAD_LEN]; + let mut packet = UdpPacket::new_unchecked(&mut bytes[..]); + udp_repr.emit( + &mut packet, + &src_addr.into(), + &dst_addr.into(), + MAX_PAYLOAD_LEN, + |buf| fill_slice(buf, 0x2a), + &ChecksumCapabilities::default(), + ); + + let ip_repr = Ipv6Repr { + src_addr, + dst_addr, + next_header: IpProtocol::Udp, + hop_limit: 64, + payload_len: udp_repr.header_len() + MAX_PAYLOAD_LEN, + }; + let payload = packet.into_inner(); + + let expected_icmp_repr = Icmpv6Repr::DstUnreachable { + reason: Icmpv6DstUnreachable::PortUnreachable, + header: ip_repr, + data: &payload[..MAX_PAYLOAD_LEN], + }; + + let expected_ip_repr = Ipv6Repr { + src_addr: dst_addr, + dst_addr: src_addr, + next_header: IpProtocol::Icmpv6, + hop_limit: 64, + payload_len: expected_icmp_repr.buffer_len(), + }; + + assert_eq!( + expected_ip_repr.buffer_len() + expected_icmp_repr.buffer_len(), + MIN_MTU + ); + + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr.into(), + udp_repr, + false, + &vec![0x2a; MAX_PAYLOAD_LEN], + payload, + ), + Some(IpPacket::Icmpv6((expected_ip_repr, expected_icmp_repr))) + ); +} diff --git a/src/iface/interface/tests/mod.rs b/src/iface/interface/tests/mod.rs new file mode 100644 index 000000000..c781f0c76 --- /dev/null +++ b/src/iface/interface/tests/mod.rs @@ -0,0 +1,183 @@ +#[cfg(feature = "proto-ipv4")] +mod ipv4; +#[cfg(feature = "proto-ipv6")] +mod ipv6; +#[cfg(feature = "proto-sixlowpan")] +mod sixlowpan; + +#[cfg(feature = "proto-igmp")] +use std::vec::Vec; + +use rstest::*; + +use super::*; + +use crate::iface::Interface; +use crate::phy::{ChecksumCapabilities, Loopback}; +use crate::time::Instant; + +#[allow(unused)] +fn fill_slice(s: &mut [u8], val: u8) { + for x in s.iter_mut() { + *x = val + } +} + +fn setup<'a>(medium: Medium) -> (Interface, SocketSet<'a>, Loopback) { + let mut device = Loopback::new(medium); + + let config = Config::new(match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => HardwareAddress::Ethernet(Default::default()), + #[cfg(feature = "medium-ip")] + Medium::Ip => HardwareAddress::Ip, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => HardwareAddress::Ieee802154(Default::default()), + }); + + let mut iface = Interface::new(config, &mut device, Instant::ZERO); + + #[cfg(feature = "proto-ipv4")] + { + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v4(127, 0, 0, 1), 8)) + .unwrap(); + }); + } + + #[cfg(feature = "proto-ipv6")] + { + iface.update_ip_addrs(|ip_addrs| { + ip_addrs + .push(IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 1), 128)) + .unwrap(); + ip_addrs + .push(IpCidr::new(IpAddress::v6(0xfdbe, 0, 0, 0, 0, 0, 0, 1), 64)) + .unwrap(); + }); + } + + (iface, SocketSet::new(vec![]), device) +} + +#[cfg(feature = "proto-igmp")] +fn recv_all(device: &mut Loopback, timestamp: Instant) -> Vec> { + let mut pkts = Vec::new(); + while let Some((rx, _tx)) = device.receive(timestamp) { + rx.consume(|pkt| { + pkts.push(pkt.to_vec()); + }); + } + pkts +} + +#[derive(Debug, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct MockTxToken; + +impl TxToken for MockTxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut junk = [0; 1536]; + f(&mut junk[..len]) + } +} + +#[test] +#[should_panic(expected = "The hardware address does not match the medium of the interface.")] +#[cfg(all(feature = "medium-ip", feature = "medium-ethernet"))] +fn test_new_panic() { + let mut device = Loopback::new(Medium::Ethernet); + let config = Config::new(HardwareAddress::Ip); + Interface::new(config, &mut device, Instant::ZERO); +} + +#[rstest] +#[cfg(feature = "default")] +fn test_handle_udp_broadcast( + #[values(Medium::Ip, Medium::Ethernet, Medium::Ieee802154)] medium: Medium, +) { + use crate::wire::IpEndpoint; + + static UDP_PAYLOAD: [u8; 5] = [0x48, 0x65, 0x6c, 0x6c, 0x6f]; + + let (mut iface, mut sockets, _device) = setup(medium); + + let rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + let tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 15]); + + let udp_socket = udp::Socket::new(rx_buffer, tx_buffer); + + let mut udp_bytes = vec![0u8; 13]; + let mut packet = UdpPacket::new_unchecked(&mut udp_bytes); + + let socket_handle = sockets.add(udp_socket); + + #[cfg(feature = "proto-ipv6")] + let src_ip = Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1); + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + let src_ip = Ipv4Address::new(0x7f, 0x00, 0x00, 0x02); + + let udp_repr = UdpRepr { + src_port: 67, + dst_port: 68, + }; + + #[cfg(feature = "proto-ipv6")] + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: src_ip, + dst_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 0x40, + }); + #[cfg(all(not(feature = "proto-ipv6"), feature = "proto-ipv4"))] + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: src_ip, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: udp_repr.header_len() + UDP_PAYLOAD.len(), + hop_limit: 0x40, + }); + + // Bind the socket to port 68 + let socket = sockets.get_mut::(socket_handle); + assert_eq!(socket.bind(68), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + + udp_repr.emit( + &mut packet, + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(&UDP_PAYLOAD), + &ChecksumCapabilities::default(), + ); + + // Packet should be handled by bound UDP socket + assert_eq!( + iface.inner.process_udp( + &mut sockets, + PacketMeta::default(), + ip_repr, + udp_repr, + false, + &UDP_PAYLOAD, + packet.into_inner(), + ), + None + ); + + // Make sure the payload to the UDP packet processed by process_udp is + // appended to the bound sockets rx_buffer + let socket = sockets.get_mut::(socket_handle); + assert!(socket.can_recv()); + assert_eq!( + socket.recv(), + Ok((&UDP_PAYLOAD[..], IpEndpoint::new(src_ip.into(), 67).into())) + ); +} diff --git a/src/iface/interface/tests/sixlowpan.rs b/src/iface/interface/tests/sixlowpan.rs new file mode 100644 index 000000000..011f45a8b --- /dev/null +++ b/src/iface/interface/tests/sixlowpan.rs @@ -0,0 +1,411 @@ +use super::*; + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn ieee802154_wrong_pan_id(#[case] medium: Medium) { + let data = [ + 0x41, 0xcc, 0x3b, 0xff, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0x62, 0x3a, + 0xa6, 0x34, 0x57, 0x29, 0x1c, 0x26, + ]; + + let response = None; + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ieee802154( + &mut sockets, + PacketMeta::default(), + &data[..], + &mut iface.fragments + ), + response, + ); +} + +#[rstest] +#[case::ieee802154(Medium::Ieee802154)] +#[cfg(feature = "medium-ieee802154")] +fn icmp_echo_request(#[case] medium: Medium) { + let data = [ + 0x41, 0xcc, 0x3b, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0x62, 0x3a, + 0xa6, 0x34, 0x57, 0x29, 0x1c, 0x26, 0x6a, 0x33, 0x0a, 0x62, 0x17, 0x3a, 0x80, 0x00, 0xb0, + 0xe3, 0x00, 0x04, 0x00, 0x01, 0x82, 0xf2, 0x82, 0x64, 0x00, 0x00, 0x00, 0x00, 0x66, 0x23, + 0x0c, 0x00, 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, + 0x19, 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + 0x37, + ]; + + let response = Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0x180b, 0x4242, 0x4242, 0x4242]), + dst_addr: Ipv6Address::from_parts(&[0xfe80, 0, 0, 0, 0x241c, 0x2957, 0x34a6, 0x3a62]), + hop_limit: 64, + next_header: IpProtocol::Icmpv6, + payload_len: 64, + }, + Icmpv6Repr::EchoReply { + ident: 4, + seq_no: 1, + data: &[ + 0x82, 0xf2, 0x82, 0x64, 0x00, 0x00, 0x00, 0x00, 0x66, 0x23, 0x0c, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, + 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, + ], + }, + ))); + + let (mut iface, mut sockets, _device) = setup(medium); + + assert_eq!( + iface.inner.process_ieee802154( + &mut sockets, + PacketMeta::default(), + &data[..], + &mut iface.fragments + ), + response, + ); +} + +#[test] +#[cfg(feature = "proto-sixlowpan-fragmentation")] +fn test_echo_request_sixlowpan_128_bytes() { + use crate::phy::Checksum; + + let (mut iface, mut sockets, mut device) = setup(Medium::Ieee802154); + // TODO: modify the example, such that we can also test if the checksum is correctly + // computed. + iface.inner.caps.checksum.icmpv6 = Checksum::None; + + assert_eq!(iface.inner.caps.medium, Medium::Ieee802154); + let now = iface.inner.now(); + + iface.inner.neighbor_cache.fill( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0x2, 0, 0, 0, 0, 0, 0, 0]).into(), + HardwareAddress::Ieee802154(Ieee802154Address::default()), + now, + ); + + let mut ieee802154_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(5), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: Some(Ieee802154Pan(0xbeef)), + dst_addr: Some(Ieee802154Address::Extended([ + 0x90, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, 0x76, + ])), + src_pan_id: Some(Ieee802154Pan(0xbeef)), + src_addr: Some(Ieee802154Address::Extended([ + 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, + ])), + }; + + // NOTE: this data is retrieved from tests with Contiki-NG + + let request_first_part_packet = SixlowpanFragPacket::new_checked(&[ + 0xc0, 0xb0, 0x00, 0x8e, 0x6a, 0x33, 0x05, 0x25, 0x2c, 0x3a, 0x80, 0x00, 0xe0, 0x71, 0x00, + 0x27, 0x00, 0x02, 0xa2, 0xc2, 0x2d, 0x63, 0x00, 0x00, 0x00, 0x00, 0xd9, 0x5e, 0x0c, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, + 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, + 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, + 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, + 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, + ]) + .unwrap(); + + let request_first_part_iphc_packet = + SixlowpanIphcPacket::new_checked(request_first_part_packet.payload()).unwrap(); + + let request_first_part_iphc_repr = SixlowpanIphcRepr::parse( + &request_first_part_iphc_packet, + ieee802154_repr.src_addr, + ieee802154_repr.dst_addr, + &iface.inner.sixlowpan_address_context, + ) + .unwrap(); + + assert_eq!( + request_first_part_iphc_repr.src_addr, + Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, + 0x1a, + ]), + ); + assert_eq!( + request_first_part_iphc_repr.dst_addr, + Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x92, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, + 0x76, + ]), + ); + + let request_second_part = [ + 0xe0, 0xb0, 0x00, 0x8e, 0x10, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + &request_first_part_packet.into_inner(), + &mut iface.fragments + ), + None + ); + + ieee802154_repr.sequence_number = Some(6); + + // data that was generated when using `ping -s 128` + let data = &[ + 0xa2, 0xc2, 0x2d, 0x63, 0x00, 0x00, 0x00, 0x00, 0xd9, 0x5e, 0x0c, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, 0x1a, 0x1b, 0x1c, 0x1d, + 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, 0x28, 0x29, 0x2a, 0x2b, 0x2c, + 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, + 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, 0x44, 0x45, 0x46, 0x47, 0x48, 0x49, 0x4a, + 0x4b, 0x4c, 0x4d, 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, + 0x5a, 0x5b, 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, + 0x69, 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ]; + + let result = iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + &request_second_part, + &mut iface.fragments, + ); + + assert_eq!( + result, + Some(IpPacket::Icmpv6(( + Ipv6Repr { + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x92, 0xfc, 0x48, 0xc2, 0xa4, 0x41, + 0xfc, 0x76, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, + 0xb, 0x1a, + ]), + next_header: IpProtocol::Icmpv6, + payload_len: 136, + hop_limit: 64, + }, + Icmpv6Repr::EchoReply { + ident: 39, + seq_no: 2, + data, + } + ))) + ); + + iface.inner.neighbor_cache.fill( + IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, 0x1a, + ])), + HardwareAddress::Ieee802154(Ieee802154Address::default()), + Instant::now(), + ); + + let tx_token = device.transmit(Instant::now()).unwrap(); + iface.inner.dispatch_ieee802154( + Ieee802154Address::default(), + tx_token, + PacketMeta::default(), + result.unwrap(), + &mut iface.fragmenter, + ); + + assert_eq!( + device.queue[0], + &[ + 0x41, 0xcc, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0xc0, 0xb0, 0x5, 0x4e, 0x7a, 0x11, 0x3a, 0x92, 0xfc, 0x48, 0xc2, + 0xa4, 0x41, 0xfc, 0x76, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, 0xb, 0x1a, 0x81, 0x0, 0x0, + 0x0, 0x0, 0x27, 0x0, 0x2, 0xa2, 0xc2, 0x2d, 0x63, 0x0, 0x0, 0x0, 0x0, 0xd9, 0x5e, 0xc, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x10, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x17, 0x18, 0x19, + 0x1a, 0x1b, 0x1c, 0x1d, 0x1e, 0x1f, 0x20, 0x21, 0x22, 0x23, 0x24, 0x25, 0x26, 0x27, + 0x28, 0x29, 0x2a, 0x2b, 0x2c, 0x2d, 0x2e, 0x2f, 0x30, 0x31, 0x32, 0x33, 0x34, 0x35, + 0x36, 0x37, 0x38, 0x39, 0x3a, 0x3b, 0x3c, 0x3d, 0x3e, 0x3f, 0x40, 0x41, 0x42, 0x43, + 0x44, 0x45, 0x46, 0x47, + ] + ); + + iface.poll(Instant::now(), &mut device, &mut sockets); + + assert_eq!( + device.queue[1], + &[ + 0x41, 0xcc, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0xe0, 0xb0, 0x5, 0x4e, 0xf, 0x48, 0x49, 0x4a, 0x4b, 0x4c, 0x4d, + 0x4e, 0x4f, 0x50, 0x51, 0x52, 0x53, 0x54, 0x55, 0x56, 0x57, 0x58, 0x59, 0x5a, 0x5b, + 0x5c, 0x5d, 0x5e, 0x5f, 0x60, 0x61, 0x62, 0x63, 0x64, 0x65, 0x66, 0x67, 0x68, 0x69, + 0x6a, 0x6b, 0x6c, 0x6d, 0x6e, 0x6f, 0x70, 0x71, 0x72, 0x73, 0x74, 0x75, 0x76, 0x77, + 0x78, 0x79, 0x7a, 0x7b, 0x7c, 0x7d, 0x7e, 0x7f, + ] + ); +} + +#[test] +#[cfg(feature = "proto-sixlowpan-fragmentation")] +fn test_sixlowpan_udp_with_fragmentation() { + use crate::phy::Checksum; + + let mut ieee802154_repr = Ieee802154Repr { + frame_type: Ieee802154FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: false, + sequence_number: Some(5), + pan_id_compression: true, + frame_version: Ieee802154FrameVersion::Ieee802154_2003, + dst_pan_id: Some(Ieee802154Pan(0xbeef)), + dst_addr: Some(Ieee802154Address::Extended([ + 0x90, 0xfc, 0x48, 0xc2, 0xa4, 0x41, 0xfc, 0x76, + ])), + src_pan_id: Some(Ieee802154Pan(0xbeef)), + src_addr: Some(Ieee802154Address::Extended([ + 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, + ])), + }; + + let (mut iface, mut sockets, mut device) = setup(Medium::Ieee802154); + iface.inner.caps.checksum.udp = Checksum::None; + + let udp_rx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1024 * 4]); + let udp_tx_buffer = udp::PacketBuffer::new(vec![udp::PacketMetadata::EMPTY], vec![0; 1024 * 4]); + let udp_socket = udp::Socket::new(udp_rx_buffer, udp_tx_buffer); + let udp_socket_handle = sockets.add(udp_socket); + + { + let socket = sockets.get_mut::(udp_socket_handle); + assert_eq!(socket.bind(6969), Ok(())); + assert!(!socket.can_recv()); + assert!(socket.can_send()); + } + + let udp_first_part = &[ + 0xc0, 0xbc, 0x00, 0x92, 0x6e, 0x33, 0x07, 0xe7, 0xdc, 0xf0, 0xd3, 0xc9, 0x1b, 0x39, 0xbf, + 0xa0, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, + 0x6c, 0x6f, 0x72, 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, + 0x6f, 0x6e, 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, + 0x69, 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, 0x2e, 0x20, 0x49, 0x6e, + 0x20, 0x61, 0x74, 0x20, 0x72, 0x68, 0x6f, 0x6e, 0x63, 0x75, 0x73, 0x20, 0x74, 0x6f, 0x72, + 0x74, 0x6f, 0x72, 0x2e, 0x20, 0x43, 0x72, 0x61, 0x73, 0x20, 0x62, 0x6c, 0x61, 0x6e, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + udp_first_part, + &mut iface.fragments + ), + None + ); + + ieee802154_repr.sequence_number = Some(6); + + let udp_second_part = &[ + 0xe0, 0xbc, 0x00, 0x92, 0x11, 0x64, 0x69, 0x74, 0x20, 0x74, 0x65, 0x6c, 0x6c, 0x75, 0x73, + 0x20, 0x64, 0x69, 0x61, 0x6d, 0x2c, 0x20, 0x76, 0x61, 0x72, 0x69, 0x75, 0x73, 0x20, 0x76, + 0x65, 0x73, 0x74, 0x69, 0x62, 0x75, 0x6c, 0x75, 0x6d, 0x20, 0x6e, 0x69, 0x62, 0x68, 0x20, + 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x64, 0x6f, 0x20, 0x6e, 0x65, 0x63, 0x2e, + ]; + + assert_eq!( + iface.inner.process_sixlowpan( + &mut sockets, + PacketMeta::default(), + &ieee802154_repr, + udp_second_part, + &mut iface.fragments + ), + None + ); + + let socket = sockets.get_mut::(udp_socket_handle); + + let udp_data = b"Lorem ipsum dolor sit amet, consectetur adipiscing elit. \ +In at rhoncus tortor. Cras blandit tellus diam, varius vestibulum nibh commodo nec."; + assert_eq!( + socket.recv(), + Ok(( + &udp_data[..], + IpEndpoint { + addr: IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x40, 0x42, 0x42, 0x42, 0x42, 0x42, + 0xb, 0x1a, + ])), + port: 54217, + } + .into() + )) + ); + + let tx_token = device.transmit(Instant::now()).unwrap(); + iface.inner.dispatch_ieee802154( + Ieee802154Address::default(), + tx_token, + PacketMeta::default(), + IpPacket::Udp(( + IpRepr::Ipv6(Ipv6Repr { + src_addr: Ipv6Address::default(), + dst_addr: Ipv6Address::default(), + next_header: IpProtocol::Udp, + payload_len: udp_data.len(), + hop_limit: 64, + }), + UdpRepr { + src_port: 1234, + dst_port: 1234, + }, + udp_data, + )), + &mut iface.fragmenter, + ); + + iface.poll(Instant::now(), &mut device, &mut sockets); + + assert_eq!( + device.queue[0], + &[ + 0x41, 0xcc, 0x3, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0xc0, 0xb4, 0x5, 0x4e, 0x7e, 0x40, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xf0, 0x4, 0xd2, 0x4, 0xd2, 0xf6, + 0x4d, 0x4c, 0x6f, 0x72, 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, + 0x6f, 0x6c, 0x6f, 0x72, 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, + 0x20, 0x63, 0x6f, 0x6e, 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, + 0x64, 0x69, 0x70, 0x69, 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, + 0x2e, 0x20, 0x49, 0x6e, 0x20, 0x61, 0x74, 0x20, 0x72, 0x68, 0x6f, 0x6e, 0x63, 0x75, + 0x73, 0x20, 0x74, + ] + ); + + assert_eq!( + device.queue[1], + &[ + 0x41, 0xcc, 0x4, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x0, 0xe0, 0xb4, 0x5, 0x4e, 0xf, 0x6f, 0x72, 0x74, 0x6f, 0x72, 0x2e, + 0x20, 0x43, 0x72, 0x61, 0x73, 0x20, 0x62, 0x6c, 0x61, 0x6e, 0x64, 0x69, 0x74, 0x20, + 0x74, 0x65, 0x6c, 0x6c, 0x75, 0x73, 0x20, 0x64, 0x69, 0x61, 0x6d, 0x2c, 0x20, 0x76, + 0x61, 0x72, 0x69, 0x75, 0x73, 0x20, 0x76, 0x65, 0x73, 0x74, 0x69, 0x62, 0x75, 0x6c, + 0x75, 0x6d, 0x20, 0x6e, 0x69, 0x62, 0x68, 0x20, 0x63, 0x6f, 0x6d, 0x6d, 0x6f, 0x64, + 0x6f, 0x20, 0x6e, 0x65, 0x63, 0x2e, + ] + ); +} diff --git a/src/iface/mod.rs b/src/iface/mod.rs index d38593f70..710587ae0 100644 --- a/src/iface/mod.rs +++ b/src/iface/mod.rs @@ -4,19 +4,20 @@ The `iface` module deals with the *network interfaces*. It filters incoming fram provides lookup and caching of hardware addresses, and handles management packets. */ -#[cfg(feature = "ethernet")] +#[cfg(any(feature = "proto-ipv4", feature = "proto-sixlowpan"))] +mod fragmentation; +mod interface; +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] mod neighbor; mod route; -#[cfg(feature = "ethernet")] -mod ethernet; +#[cfg(feature = "proto-rpl")] +mod rpl; +mod socket_meta; +mod socket_set; -#[cfg(feature = "ethernet")] -pub use self::neighbor::Neighbor as Neighbor; -#[cfg(feature = "ethernet")] -pub(crate) use self::neighbor::Answer as NeighborAnswer; -#[cfg(feature = "ethernet")] -pub use self::neighbor::Cache as NeighborCache; -pub use self::route::{Route, Routes}; -#[cfg(feature = "ethernet")] -pub use self::ethernet::{Interface as EthernetInterface, - InterfaceBuilder as EthernetInterfaceBuilder}; +#[cfg(feature = "proto-igmp")] +pub use self::interface::MulticastError; +pub use self::interface::{Config, Interface, InterfaceInner as Context}; + +pub use self::route::{Route, RouteTableFull, Routes}; +pub use self::socket_set::{SocketHandle, SocketSet, SocketStorage}; diff --git a/src/iface/neighbor.rs b/src/iface/neighbor.rs index dd042b991..710454edf 100644 --- a/src/iface/neighbor.rs +++ b/src/iface/neighbor.rs @@ -1,279 +1,293 @@ // Heads up! Before working on this file you should read, at least, // the parts of RFC 1122 that discuss ARP. -use managed::ManagedMap; +use heapless::LinearMap; -use wire::{EthernetAddress, IpAddress}; -use time::{Duration, Instant}; +use crate::config::IFACE_NEIGHBOR_CACHE_COUNT; +use crate::time::{Duration, Instant}; +use crate::wire::{HardwareAddress, IpAddress}; /// A cached neighbor. /// /// A neighbor mapping translates from a protocol address to a hardware address, /// and contains the timestamp past which the mapping should be discarded. #[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Neighbor { - hardware_addr: EthernetAddress, - expires_at: Instant, + hardware_addr: HardwareAddress, + expires_at: Instant, } /// An answer to a neighbor cache lookup. #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub(crate) enum Answer { /// The neighbor address is in the cache and not expired. - Found(EthernetAddress), + Found(HardwareAddress), /// The neighbor address is not in the cache, or has expired. NotFound, /// The neighbor address is not in the cache, or has expired, /// and a lookup has been made recently. - RateLimited + RateLimited, +} + +impl Answer { + /// Returns whether a valid address was found. + pub(crate) fn found(&self) -> bool { + match self { + Answer::Found(_) => true, + _ => false, + } + } } /// A neighbor cache backed by a map. -/// -/// # Examples -/// -/// On systems with heap, this cache can be created with: -/// -/// ```rust -/// use std::collections::BTreeMap; -/// use smoltcp::iface::NeighborCache; -/// let mut neighbor_cache = NeighborCache::new(BTreeMap::new()); -/// ``` -/// -/// On systems without heap, use: -/// -/// ```rust -/// use smoltcp::iface::NeighborCache; -/// let mut neighbor_cache_storage = [None; 8]; -/// let mut neighbor_cache = NeighborCache::new(&mut neighbor_cache_storage[..]); -/// ``` #[derive(Debug)] -pub struct Cache<'a> { - storage: ManagedMap<'a, IpAddress, Neighbor>, +pub struct Cache { + storage: LinearMap, silent_until: Instant, - gc_threshold: usize - } -impl<'a> Cache<'a> { +impl Cache { /// Minimum delay between discovery requests, in milliseconds. - pub(crate) const SILENT_TIME: Duration = Duration { millis: 1_000 }; + pub(crate) const SILENT_TIME: Duration = Duration::from_millis(1_000); /// Neighbor entry lifetime, in milliseconds. - pub(crate) const ENTRY_LIFETIME: Duration = Duration { millis: 60_000 }; - - /// Default number of entries in the cache before GC kicks in - pub(crate) const GC_THRESHOLD: usize = 1024; + pub(crate) const ENTRY_LIFETIME: Duration = Duration::from_millis(60_000); - /// Create a cache. The backing storage is cleared upon creation. - /// - /// # Panics - /// This function panics if `storage.len() == 0`. - pub fn new(storage: T) -> Cache<'a> - where T: Into> { - - Cache::new_with_limit(storage, Cache::GC_THRESHOLD) - } - - pub fn new_with_limit(storage: T, gc_threshold: usize) -> Cache<'a> - where T: Into> { - let mut storage = storage.into(); - storage.clear(); - - Cache { storage, gc_threshold, silent_until: Instant::from_millis(0) } + /// Create a cache. + pub fn new() -> Self { + Self { + storage: LinearMap::new(), + silent_until: Instant::from_millis(0), + } } - pub fn fill(&mut self, protocol_addr: IpAddress, hardware_addr: EthernetAddress, - timestamp: Instant) { + pub fn fill( + &mut self, + protocol_addr: IpAddress, + hardware_addr: HardwareAddress, + timestamp: Instant, + ) { debug_assert!(protocol_addr.is_unicast()); debug_assert!(hardware_addr.is_unicast()); - #[cfg(any(feature = "std", feature = "alloc"))] - let current_storage_size = self.storage.len(); - - match self.storage { - ManagedMap::Borrowed(_) => (), - #[cfg(any(feature = "std", feature = "alloc"))] - ManagedMap::Owned(ref mut map) => { - if current_storage_size >= self.gc_threshold { - let new_btree_map = map.into_iter() - .map(|(key, value)| (*key, *value)) - .filter(|(_, v)| timestamp < v.expires_at) - .collect(); - - *map = new_btree_map; - } - } - }; let neighbor = Neighbor { - expires_at: timestamp + Self::ENTRY_LIFETIME, hardware_addr + expires_at: timestamp + Self::ENTRY_LIFETIME, + hardware_addr, }; match self.storage.insert(protocol_addr, neighbor) { Ok(Some(old_neighbor)) => { if old_neighbor.hardware_addr != hardware_addr { - net_trace!("replaced {} => {} (was {})", - protocol_addr, hardware_addr, old_neighbor.hardware_addr); + net_trace!( + "replaced {} => {} (was {})", + protocol_addr, + hardware_addr, + old_neighbor.hardware_addr + ); } } Ok(None) => { net_trace!("filled {} => {} (was empty)", protocol_addr, hardware_addr); } Err((protocol_addr, neighbor)) => { - // If we're going down this branch, it means that a fixed-size cache storage - // is full, and we need to evict an entry. - let old_protocol_addr = match self.storage { - ManagedMap::Borrowed(ref mut pairs) => { - pairs - .iter() - .min_by_key(|pair_opt| { - let (_protocol_addr, neighbor) = pair_opt.unwrap(); - neighbor.expires_at - }) - .expect("empty neighbor cache storage") // unwraps min_by_key - .unwrap() // unwraps pair - .0 - } - // Owned maps can extend themselves. - #[cfg(any(feature = "std", feature = "alloc"))] - ManagedMap::Owned(_) => unreachable!() - }; - - let _old_neighbor = - self.storage.remove(&old_protocol_addr).unwrap(); + // If we're going down this branch, it means the cache is full, and we need to evict an entry. + let old_protocol_addr = *self + .storage + .iter() + .min_by_key(|(_, neighbor)| neighbor.expires_at) + .expect("empty neighbor cache storage") + .0; + + let _old_neighbor = self.storage.remove(&old_protocol_addr).unwrap(); match self.storage.insert(protocol_addr, neighbor) { Ok(None) => { - net_trace!("filled {} => {} (evicted {} => {})", - protocol_addr, hardware_addr, - old_protocol_addr, _old_neighbor.hardware_addr); + net_trace!( + "filled {} => {} (evicted {} => {})", + protocol_addr, + hardware_addr, + old_protocol_addr, + _old_neighbor.hardware_addr + ); } // We've covered everything else above. - _ => unreachable!() + _ => unreachable!(), } - } } } - pub(crate) fn lookup_pure(&self, protocol_addr: &IpAddress, timestamp: Instant) -> - Option { - if protocol_addr.is_broadcast() { - return Some(EthernetAddress::BROADCAST) - } + pub(crate) fn lookup(&self, protocol_addr: &IpAddress, timestamp: Instant) -> Answer { + assert!(protocol_addr.is_unicast()); - match self.storage.get(protocol_addr) { - Some(&Neighbor { expires_at, hardware_addr }) => { - if timestamp < expires_at { - return Some(hardware_addr) - } + if let Some(&Neighbor { + expires_at, + hardware_addr, + }) = self.storage.get(protocol_addr) + { + if timestamp < expires_at { + return Answer::Found(hardware_addr); } - None => () } - None + if timestamp < self.silent_until { + Answer::RateLimited + } else { + Answer::NotFound + } } - pub(crate) fn lookup(&mut self, protocol_addr: &IpAddress, timestamp: Instant) -> Answer { - match self.lookup_pure(protocol_addr, timestamp) { - Some(hardware_addr) => - Answer::Found(hardware_addr), - None if timestamp < self.silent_until => - Answer::RateLimited, - None => { - self.silent_until = timestamp + Self::SILENT_TIME; - Answer::NotFound - } - } + pub(crate) fn limit_rate(&mut self, timestamp: Instant) { + self.silent_until = timestamp + Self::SILENT_TIME; + } + + pub(crate) fn flush(&mut self) { + self.storage.clear() } } +#[cfg(any(feature = "medium-ethernet", feature = "medium-ip"))] #[cfg(test)] mod test { use super::*; - use std::collections::BTreeMap; - use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3, MOCK_IP_ADDR_4}; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3, MOCK_IP_ADDR_4}; + use crate::wire::EthernetAddress; - const HADDR_A: EthernetAddress = EthernetAddress([0, 0, 0, 0, 0, 1]); - const HADDR_B: EthernetAddress = EthernetAddress([0, 0, 0, 0, 0, 2]); - const HADDR_C: EthernetAddress = EthernetAddress([0, 0, 0, 0, 0, 3]); - const HADDR_D: EthernetAddress = EthernetAddress([0, 0, 0, 0, 0, 4]); + const HADDR_A: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 1])); + const HADDR_B: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 2])); + const HADDR_C: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 3])); + const HADDR_D: HardwareAddress = HardwareAddress::Ethernet(EthernetAddress([0, 0, 0, 0, 0, 4])); #[test] fn test_fill() { - let mut cache_storage = [Default::default(); 3]; - let mut cache = Cache::new(&mut cache_storage[..]); + let mut cache = Cache::new(); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0)), None); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_2, Instant::from_millis(0)), None); + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0)), Some(HADDR_A)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_2, Instant::from_millis(0)), None); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2), - None); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup( + &MOCK_IP_ADDR_1, + Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2 + ) + .found(),); cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_2, Instant::from_millis(0)), None); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); } #[test] fn test_expire() { - let mut cache_storage = [Default::default(); 3]; - let mut cache = Cache::new(&mut cache_storage[..]); + let mut cache = Cache::new(); cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0)), Some(HADDR_A)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2), - None); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup( + &MOCK_IP_ADDR_1, + Instant::from_millis(0) + Cache::ENTRY_LIFETIME * 2 + ) + .found(),); } #[test] fn test_replace() { - let mut cache_storage = [Default::default(); 3]; - let mut cache = Cache::new(&mut cache_storage[..]); + let mut cache = Cache::new(); cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0)), Some(HADDR_A)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); cache.fill(MOCK_IP_ADDR_1, HADDR_B, Instant::from_millis(0)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_1, Instant::from_millis(0)), Some(HADDR_B)); - } - - #[test] - fn test_cache_gc() { - let mut cache = Cache::new_with_limit(BTreeMap::new(), 2); - cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(100)); - cache.fill(MOCK_IP_ADDR_2, HADDR_B, Instant::from_millis(50)); - // Adding third item after the expiration of the previous - // two should garbage collect - cache.fill(MOCK_IP_ADDR_3, HADDR_C, Instant::from_millis(50) + Cache::ENTRY_LIFETIME * 2); - - assert_eq!(cache.storage.len(), 1); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_3, Instant::from_millis(50) + Cache::ENTRY_LIFETIME * 2), Some(HADDR_C)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_B) + ); } #[test] fn test_evict() { - let mut cache_storage = [Default::default(); 3]; - let mut cache = Cache::new(&mut cache_storage[..]); + let mut cache = Cache::new(); cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(100)); cache.fill(MOCK_IP_ADDR_2, HADDR_B, Instant::from_millis(50)); cache.fill(MOCK_IP_ADDR_3, HADDR_C, Instant::from_millis(200)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_2, Instant::from_millis(1000)), Some(HADDR_B)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_4, Instant::from_millis(1000)), None); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_2, Instant::from_millis(1000)), + Answer::Found(HADDR_B) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_4, Instant::from_millis(1000)) + .found()); cache.fill(MOCK_IP_ADDR_4, HADDR_D, Instant::from_millis(300)); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_2, Instant::from_millis(1000)), None); - assert_eq!(cache.lookup_pure(&MOCK_IP_ADDR_4, Instant::from_millis(1000)), Some(HADDR_D)); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(1000)) + .found()); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_4, Instant::from_millis(1000)), + Answer::Found(HADDR_D) + ); } #[test] fn test_hush() { - let mut cache_storage = [Default::default(); 3]; - let mut cache = Cache::new(&mut cache_storage[..]); + let mut cache = Cache::new(); + + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::NotFound + ); + + cache.limit_rate(Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(100)), + Answer::RateLimited + ); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(2000)), + Answer::NotFound + ); + } - assert_eq!(cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), Answer::NotFound); - assert_eq!(cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(100)), Answer::RateLimited); - assert_eq!(cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(2000)), Answer::NotFound); + #[test] + fn test_flush() { + let mut cache = Cache::new(); + + cache.fill(MOCK_IP_ADDR_1, HADDR_A, Instant::from_millis(0)); + assert_eq!( + cache.lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)), + Answer::Found(HADDR_A) + ); + assert!(!cache + .lookup(&MOCK_IP_ADDR_2, Instant::from_millis(0)) + .found()); + + cache.flush(); + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); + assert!(!cache + .lookup(&MOCK_IP_ADDR_1, Instant::from_millis(0)) + .found()); } } diff --git a/src/iface/route.rs b/src/iface/route.rs index ad3e6d4e0..123c6950c 100644 --- a/src/iface/route.rs +++ b/src/iface/route.rs @@ -1,17 +1,31 @@ -use managed::ManagedMap; -use time::Instant; -use core::ops::Bound; +use heapless::Vec; -use {Error, Result}; -use wire::{IpCidr, IpAddress}; +use crate::config::IFACE_MAX_ROUTE_COUNT; +use crate::time::Instant; +use crate::wire::{IpAddress, IpCidr}; #[cfg(feature = "proto-ipv4")] -use wire::{Ipv4Address, Ipv4Cidr}; +use crate::wire::{Ipv4Address, Ipv4Cidr}; #[cfg(feature = "proto-ipv6")] -use wire::{Ipv6Address, Ipv6Cidr}; +use crate::wire::{Ipv6Address, Ipv6Cidr}; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RouteTableFull; + +impl core::fmt::Display for RouteTableFull { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Route table full") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RouteTableFull {} /// A prefix of addresses that should be routed via a router #[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Route { + pub cidr: IpCidr, pub via_router: IpAddress, /// `None` means "forever". pub preferred_until: Option, @@ -19,11 +33,18 @@ pub struct Route { pub expires_at: Option, } +#[cfg(feature = "proto-ipv4")] +const IPV4_DEFAULT: IpCidr = IpCidr::Ipv4(Ipv4Cidr::new(Ipv4Address::new(0, 0, 0, 0), 0)); +#[cfg(feature = "proto-ipv6")] +const IPV6_DEFAULT: IpCidr = + IpCidr::Ipv6(Ipv6Cidr::new(Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0), 0)); + impl Route { /// Returns a route to 0.0.0.0/0 via the `gateway`, with no expiry. #[cfg(feature = "proto-ipv4")] pub fn new_ipv4_gateway(gateway: Ipv4Address) -> Route { Route { + cidr: IPV4_DEFAULT, via_router: gateway.into(), preferred_until: None, expires_at: None, @@ -34,6 +55,7 @@ impl Route { #[cfg(feature = "proto-ipv6")] pub fn new_ipv6_gateway(gateway: Ipv6Address) -> Route { Route { + cidr: IPV6_DEFAULT, via_router: gateway.into(), preferred_until: None, expires_at: None, @@ -42,40 +64,21 @@ impl Route { } /// A routing table. -/// -/// # Examples -/// -/// On systems with heap, this table can be created with: -/// -/// ```rust -/// use std::collections::BTreeMap; -/// use smoltcp::iface::Routes; -/// let mut routes = Routes::new(BTreeMap::new()); -/// ``` -/// -/// On systems without heap, use: -/// -/// ```rust -/// use smoltcp::iface::Routes; -/// let mut routes_storage = []; -/// let mut routes = Routes::new(&mut routes_storage[..]); -/// ``` #[derive(Debug)] -pub struct Routes<'a> { - storage: ManagedMap<'a, IpCidr, Route>, +pub struct Routes { + storage: Vec, } -impl<'a> Routes<'a> { - /// Creates a routing tables. The backing storage is **not** cleared - /// upon creation. - pub fn new(storage: T) -> Routes<'a> - where T: Into> { - let storage = storage.into(); - Routes { storage } +impl Routes { + /// Creates a new empty routing table. + pub fn new() -> Self { + Self { + storage: Vec::new(), + } } /// Update the routes of this node. - pub fn update)>(&mut self, f: F) { + pub fn update)>(&mut self, f: F) { f(&mut self.storage); } @@ -83,54 +86,83 @@ impl<'a> Routes<'a> { /// /// On success, returns the previous default route, if any. #[cfg(feature = "proto-ipv4")] - pub fn add_default_ipv4_route(&mut self, gateway: Ipv4Address) -> Result> { - let cidr = IpCidr::new(IpAddress::v4(0, 0, 0, 0), 0); - let route = Route::new_ipv4_gateway(gateway); - match self.storage.insert(cidr, route) { - Ok(route) => Ok(route), - Err((_cidr, _route)) => Err(Error::Exhausted) - } + pub fn add_default_ipv4_route( + &mut self, + gateway: Ipv4Address, + ) -> Result, RouteTableFull> { + let old = self.remove_default_ipv4_route(); + self.storage + .push(Route::new_ipv4_gateway(gateway)) + .map_err(|_| RouteTableFull)?; + Ok(old) } /// Add a default ipv6 gateway (ie. "ip -6 route add ::/0 via `gateway`"). /// /// On success, returns the previous default route, if any. #[cfg(feature = "proto-ipv6")] - pub fn add_default_ipv6_route(&mut self, gateway: Ipv6Address) -> Result> { - let cidr = IpCidr::new(IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 0), 0); - let route = Route::new_ipv6_gateway(gateway); - match self.storage.insert(cidr, route) { - Ok(route) => Ok(route), - Err((_cidr, _route)) => Err(Error::Exhausted) + pub fn add_default_ipv6_route( + &mut self, + gateway: Ipv6Address, + ) -> Result, RouteTableFull> { + let old = self.remove_default_ipv6_route(); + self.storage + .push(Route::new_ipv6_gateway(gateway)) + .map_err(|_| RouteTableFull)?; + Ok(old) + } + + /// Remove the default ipv4 gateway + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv4")] + pub fn remove_default_ipv4_route(&mut self) -> Option { + if let Some((i, _)) = self + .storage + .iter() + .enumerate() + .find(|(_, r)| r.cidr == IPV4_DEFAULT) + { + Some(self.storage.remove(i)) + } else { + None } } - pub(crate) fn lookup(&self, addr: &IpAddress, timestamp: Instant) -> - Option { - assert!(addr.is_unicast()); + /// Remove the default ipv6 gateway + /// + /// On success, returns the previous default route, if any. + #[cfg(feature = "proto-ipv6")] + pub fn remove_default_ipv6_route(&mut self) -> Option { + if let Some((i, _)) = self + .storage + .iter() + .enumerate() + .find(|(_, r)| r.cidr == IPV6_DEFAULT) + { + Some(self.storage.remove(i)) + } else { + None + } + } - let cidr = match addr { - #[cfg(feature = "proto-ipv4")] - IpAddress::Ipv4(addr) => IpCidr::Ipv4(Ipv4Cidr::new(*addr, 32)), - #[cfg(feature = "proto-ipv6")] - IpAddress::Ipv6(addr) => IpCidr::Ipv6(Ipv6Cidr::new(*addr, 128)), - _ => unimplemented!() - }; + pub(crate) fn lookup(&self, addr: &IpAddress, timestamp: Instant) -> Option { + assert!(addr.is_unicast()); - for (prefix, route) in self.storage.range((Bound::Unbounded::, Bound::Included(cidr))).rev() { - // TODO: do something with route.preferred_until - if let Some(expires_at) = route.expires_at { - if timestamp > expires_at { - continue; + self.storage + .iter() + // Keep only matching routes + .filter(|route| { + if let Some(expires_at) = route.expires_at { + if timestamp > expires_at { + return false; + } } - } - - if prefix.contains_addr(addr) { - return Some(route.via_router); - } - } - - None + route.cidr.contains_addr(addr) + }) + // pick the most specific one (highest prefix_len) + .max_by_key(|route| route.cidr.prefix_len()) + .map(|route| route.via_router) } } @@ -140,24 +172,28 @@ mod test { #[cfg(feature = "proto-ipv6")] mod mock { use super::super::*; - pub const ADDR_1A: Ipv6Address = Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1]); - pub const ADDR_1B: Ipv6Address = Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 13]); - pub const ADDR_1C: Ipv6Address = Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 42]); + pub const ADDR_1A: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 1]); + pub const ADDR_1B: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 13]); + pub const ADDR_1C: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 42]); pub fn cidr_1() -> Ipv6Cidr { - Ipv6Cidr::new(Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]), 64) + Ipv6Cidr::new( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 0, 0, 0, 0]), + 64, + ) } - pub const ADDR_2A: Ipv6Address = Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 1]); - pub const ADDR_2B: Ipv6Address = Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 21]); + pub const ADDR_2A: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 1]); + pub const ADDR_2B: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 21]); pub fn cidr_2() -> Ipv6Cidr { - Ipv6Cidr::new(Ipv6Address( - [0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 0]), 64) + Ipv6Cidr::new( + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 51, 100, 0, 0, 0, 0, 0, 0, 0, 0]), + 64, + ) } } @@ -182,48 +218,110 @@ mod test { #[test] fn test_fill() { - let mut routes_storage = [None, None, None]; - let mut routes = Routes::new(&mut routes_storage[..]); + let mut routes = Routes::new(); - assert_eq!(routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), None); - assert_eq!(routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), None); - assert_eq!(routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), None); - assert_eq!(routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), None); - assert_eq!(routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), None); + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + None + ); let route = Route { + cidr: cidr_1().into(), via_router: ADDR_1A.into(), - preferred_until: None, expires_at: None, + preferred_until: None, + expires_at: None, }; routes.update(|storage| { - storage.insert(cidr_1().into(), route).unwrap(); + storage.push(route).unwrap(); }); - assert_eq!(routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), None); - assert_eq!(routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), None); + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + None + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + None + ); let route2 = Route { + cidr: cidr_2().into(), via_router: ADDR_2A.into(), preferred_until: Some(Instant::from_millis(10)), expires_at: Some(Instant::from_millis(10)), }; routes.update(|storage| { - storage.insert(cidr_2().into(), route2).unwrap(); + storage.push(route2).unwrap(); }); - assert_eq!(routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), Some(ADDR_2A.into())); - assert_eq!(routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), Some(ADDR_2A.into())); - - assert_eq!(routes.lookup(&ADDR_1A.into(), Instant::from_millis(10)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1B.into(), Instant::from_millis(10)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_1C.into(), Instant::from_millis(10)), Some(ADDR_1A.into())); - assert_eq!(routes.lookup(&ADDR_2A.into(), Instant::from_millis(10)), Some(ADDR_2A.into())); - assert_eq!(routes.lookup(&ADDR_2B.into(), Instant::from_millis(10)), Some(ADDR_2A.into())); + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(0)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(0)), + Some(ADDR_2A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(0)), + Some(ADDR_2A.into()) + ); + + assert_eq!( + routes.lookup(&ADDR_1A.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1B.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_1C.into(), Instant::from_millis(10)), + Some(ADDR_1A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2A.into(), Instant::from_millis(10)), + Some(ADDR_2A.into()) + ); + assert_eq!( + routes.lookup(&ADDR_2B.into(), Instant::from_millis(10)), + Some(ADDR_2A.into()) + ); } } diff --git a/src/iface/rpl/consts.rs b/src/iface/rpl/consts.rs new file mode 100644 index 000000000..70a66138f --- /dev/null +++ b/src/iface/rpl/consts.rs @@ -0,0 +1,8 @@ +pub const SEQUENCE_WINDOW: u8 = 16; + +pub const DEFAULT_MIN_HOP_RANK_INCREASE: u16 = 256; + +pub const DEFAULT_DIO_INTERVAL_MIN: u32 = 12; +pub const DEFAULT_DIO_REDUNDANCY_CONSTANT: usize = 10; +/// This is 20 in the standard, but in Contiki they use: +pub const DEFAULT_DIO_INTERVAL_DOUBLINGS: u32 = 8; diff --git a/src/iface/rpl/lollipop.rs b/src/iface/rpl/lollipop.rs new file mode 100644 index 000000000..4785c7725 --- /dev/null +++ b/src/iface/rpl/lollipop.rs @@ -0,0 +1,189 @@ +//! Implementation of sequence counters defined in [RFC 6550 § 7.2]. Values from 128 and greater +//! are used as a linear sequence to indicate a restart and bootstrap the counter. Values less than +//! or equal to 127 are used as a circular sequence number space of size 128. When operating in the +//! circular region, if sequence numbers are detected to be too far apart, then they are not +//! comparable. +//! +//! [RFC 6550 § 7.2]: https://datatracker.ietf.org/doc/html/rfc6550#section-7.2 + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct SequenceCounter(u8); + +impl Default for SequenceCounter { + fn default() -> Self { + // RFC6550 7.2 recommends 240 (256 - SEQUENCE_WINDOW) as the initialization value of the + // counter. + Self(240) + } +} + +impl SequenceCounter { + /// Create a new sequence counter. + /// + /// Use `Self::default()` when a new sequence counter needs to be created with a value that is + /// recommended in RFC6550 7.2, being 240. + pub fn new(value: u8) -> Self { + Self(value) + } + + /// Return the value of the sequence counter. + pub fn value(&self) -> u8 { + self.0 + } + + /// Increment the sequence counter. + /// + /// When the sequence counter is greater than or equal to 128, the maximum value is 255. + /// When the sequence counter is less than 128, the maximum value is 127. + /// + /// When an increment of the sequence counter would cause the counter to increment beyond its + /// maximum value, the counter MUST wrap back to zero. + pub fn increment(&mut self) { + let max = if self.0 >= 128 { 255 } else { 127 }; + + self.0 = match self.0.checked_add(1) { + Some(val) if val <= max => val, + _ => 0, + }; + } +} + +impl PartialEq for SequenceCounter { + fn eq(&self, other: &Self) -> bool { + let a = self.value() as usize; + let b = other.value() as usize; + + if ((128..=255).contains(&a) && (0..=127).contains(&b)) + || ((128..=255).contains(&b) && (0..=127).contains(&a)) + { + false + } else { + let result = if a > b { a - b } else { b - a }; + + if result <= super::consts::SEQUENCE_WINDOW as usize { + // RFC1982 + a == b + } else { + // This case is actually not comparable. + false + } + } + } +} + +impl PartialOrd for SequenceCounter { + fn partial_cmp(&self, other: &Self) -> Option { + use super::consts::SEQUENCE_WINDOW; + use core::cmp::Ordering; + + let a = self.value() as usize; + let b = other.value() as usize; + + if (128..256).contains(&a) && (0..128).contains(&b) { + if 256 + b - a <= SEQUENCE_WINDOW as usize { + Some(Ordering::Less) + } else { + Some(Ordering::Greater) + } + } else if (128..256).contains(&b) && (0..128).contains(&a) { + if 256 + a - b <= SEQUENCE_WINDOW as usize { + Some(Ordering::Greater) + } else { + Some(Ordering::Less) + } + } else if ((0..128).contains(&a) && (0..128).contains(&b)) + || ((128..256).contains(&a) && (128..256).contains(&b)) + { + let result = if a > b { a - b } else { b - a }; + + if result <= SEQUENCE_WINDOW as usize { + // RFC1982 + a.partial_cmp(&b) + } else { + // This case is not comparable. + None + } + } else { + unreachable!(); + } + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn sequence_counter_increment() { + let mut seq = SequenceCounter::new(253); + seq.increment(); + assert_eq!(seq.value(), 254); + seq.increment(); + assert_eq!(seq.value(), 255); + seq.increment(); + assert_eq!(seq.value(), 0); + + let mut seq = SequenceCounter::new(126); + seq.increment(); + assert_eq!(seq.value(), 127); + seq.increment(); + assert_eq!(seq.value(), 0); + } + + #[test] + fn sequence_counter_comparison() { + use core::cmp::Ordering; + + assert!(SequenceCounter::new(240) != SequenceCounter::new(1)); + assert!(SequenceCounter::new(1) != SequenceCounter::new(240)); + assert!(SequenceCounter::new(1) != SequenceCounter::new(240)); + assert!(SequenceCounter::new(240) == SequenceCounter::new(240)); + assert!(SequenceCounter::new(240 - 17) != SequenceCounter::new(240)); + + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(5)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(250).partial_cmp(&SequenceCounter::new(5)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(5).partial_cmp(&SequenceCounter::new(250)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(127).partial_cmp(&SequenceCounter::new(129)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(120).partial_cmp(&SequenceCounter::new(121)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(121).partial_cmp(&SequenceCounter::new(120)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(241)), + Some(Ordering::Less) + ); + assert_eq!( + SequenceCounter::new(241).partial_cmp(&SequenceCounter::new(240)), + Some(Ordering::Greater) + ); + assert_eq!( + SequenceCounter::new(120).partial_cmp(&SequenceCounter::new(120)), + Some(Ordering::Equal) + ); + assert_eq!( + SequenceCounter::new(240).partial_cmp(&SequenceCounter::new(240)), + Some(Ordering::Equal) + ); + assert_eq!( + SequenceCounter::new(130).partial_cmp(&SequenceCounter::new(241)), + None + ); + } +} diff --git a/src/iface/rpl/mod.rs b/src/iface/rpl/mod.rs new file mode 100644 index 000000000..69aa9ae77 --- /dev/null +++ b/src/iface/rpl/mod.rs @@ -0,0 +1,9 @@ +#![allow(unused)] + +mod consts; +mod lollipop; +mod of0; +mod parents; +mod rank; +mod relations; +mod trickle; diff --git a/src/iface/rpl/of0.rs b/src/iface/rpl/of0.rs new file mode 100644 index 000000000..99e4d1f36 --- /dev/null +++ b/src/iface/rpl/of0.rs @@ -0,0 +1,129 @@ +use super::parents::*; +use super::rank::Rank; + +pub struct ObjectiveFunction0; + +pub(crate) trait ObjectiveFunction { + const OCP: u16; + + /// Return the new calculated Rank, based on information from the parent. + fn rank(current_rank: Rank, parent_rank: Rank) -> Rank; + + /// Return the preferred parent from a given parent set. + fn preferred_parent(parent_set: &ParentSet) -> Option<&Parent>; +} + +impl ObjectiveFunction0 { + const OCP: u16 = 0; + + const RANK_STRETCH: u16 = 0; + const RANK_FACTOR: u16 = 1; + const RANK_STEP: u16 = 3; + + fn rank_increase(parent_rank: Rank) -> u16 { + (Self::RANK_FACTOR * Self::RANK_STEP + Self::RANK_STRETCH) + * parent_rank.min_hop_rank_increase + } +} + +impl ObjectiveFunction for ObjectiveFunction0 { + const OCP: u16 = 0; + + fn rank(_: Rank, parent_rank: Rank) -> Rank { + assert_ne!(parent_rank, Rank::INFINITE); + + Rank::new( + parent_rank.value + Self::rank_increase(parent_rank), + parent_rank.min_hop_rank_increase, + ) + } + + fn preferred_parent(parent_set: &ParentSet) -> Option<&Parent> { + let mut pref_parent: Option<&Parent> = None; + + for (_, parent) in parent_set.parents() { + if pref_parent.is_none() || parent.rank() < pref_parent.unwrap().rank() { + pref_parent = Some(parent); + } + } + + pref_parent + } +} + +#[cfg(test)] +mod tests { + use crate::iface::rpl::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + + use super::*; + + #[test] + fn rank_increase() { + // 256 (root) + 3 * 256 + assert_eq!( + ObjectiveFunction0::rank(Rank::INFINITE, Rank::ROOT), + Rank::new(256 + 3 * 256, DEFAULT_MIN_HOP_RANK_INCREASE) + ); + + // 1024 + 3 * 256 + assert_eq!( + ObjectiveFunction0::rank( + Rank::INFINITE, + Rank::new(1024, DEFAULT_MIN_HOP_RANK_INCREASE) + ), + Rank::new(1024 + 3 * 256, DEFAULT_MIN_HOP_RANK_INCREASE) + ); + } + + #[test] + #[should_panic] + fn rank_increase_infinite() { + assert_eq!( + ObjectiveFunction0::rank(Rank::INFINITE, Rank::INFINITE), + Rank::INFINITE + ); + } + + #[test] + fn empty_set() { + assert_eq!( + ObjectiveFunction0::preferred_parent(&ParentSet::default()), + None + ); + } + + #[test] + fn non_empty_set() { + use crate::wire::Ipv6Address; + + let mut parents = ParentSet::default(); + + parents.add( + Ipv6Address::default(), + Parent::new(0, Rank::ROOT, Default::default(), Ipv6Address::default()), + ); + + let mut address = Ipv6Address::default(); + address.0[15] = 1; + + parents.add( + address, + Parent::new( + 0, + Rank::new(1024, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + Ipv6Address::default(), + ), + ); + + assert_eq!( + ObjectiveFunction0::preferred_parent(&parents), + Some(&Parent::new( + 0, + Rank::ROOT, + Default::default(), + Ipv6Address::default(), + )) + ); + } +} diff --git a/src/iface/rpl/parents.rs b/src/iface/rpl/parents.rs new file mode 100644 index 000000000..70d5a5e88 --- /dev/null +++ b/src/iface/rpl/parents.rs @@ -0,0 +1,176 @@ +use crate::wire::Ipv6Address; + +use super::{lollipop::SequenceCounter, rank::Rank}; +use crate::config::RPL_PARENTS_BUFFER_COUNT; + +#[derive(Debug, Clone, Copy, PartialEq)] +pub(crate) struct Parent { + rank: Rank, + preference: u8, + version_number: SequenceCounter, + dodag_id: Ipv6Address, +} + +impl Parent { + /// Create a new parent. + pub(crate) fn new( + preference: u8, + rank: Rank, + version_number: SequenceCounter, + dodag_id: Ipv6Address, + ) -> Self { + Self { + rank, + preference, + version_number, + dodag_id, + } + } + + /// Return the Rank of the parent. + pub(crate) fn rank(&self) -> &Rank { + &self.rank + } +} + +#[derive(Debug, Default)] +pub(crate) struct ParentSet { + parents: heapless::LinearMap, +} + +impl ParentSet { + /// Add a new parent to the parent set. The Rank of the new parent should be lower than the + /// Rank of the node that holds this parent set. + pub(crate) fn add(&mut self, address: Ipv6Address, parent: Parent) { + if let Some(p) = self.parents.get_mut(&address) { + *p = parent; + } else if let Err(p) = self.parents.insert(address, parent) { + if let Some((w_a, w_p)) = self.worst_parent() { + if w_p.rank.dag_rank() > parent.rank.dag_rank() { + self.parents.remove(&w_a.clone()).unwrap(); + self.parents.insert(address, parent).unwrap(); + } else { + net_debug!("could not add {} to parent set, buffer is full", address); + } + } else { + unreachable!() + } + } + } + + /// Find a parent based on its address. + pub(crate) fn find(&self, address: &Ipv6Address) -> Option<&Parent> { + self.parents.get(address) + } + + /// Find a mutable parent based on its address. + pub(crate) fn find_mut(&mut self, address: &Ipv6Address) -> Option<&mut Parent> { + self.parents.get_mut(address) + } + + /// Return a slice to the parent set. + pub(crate) fn parents(&self) -> impl Iterator { + self.parents.iter() + } + + /// Find the worst parent that is currently in the parent set. + fn worst_parent(&self) -> Option<(&Ipv6Address, &Parent)> { + self.parents.iter().max_by_key(|(k, v)| v.rank.dag_rank()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn add_parent() { + let mut set = ParentSet::default(); + set.add( + Default::default(), + Parent::new(0, Rank::ROOT, Default::default(), Default::default()), + ); + + assert_eq!( + set.find(&Default::default()), + Some(&Parent::new( + 0, + Rank::ROOT, + Default::default(), + Default::default() + )) + ); + } + + #[test] + fn add_more_parents() { + use super::super::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + let mut set = ParentSet::default(); + + let mut last_address = Default::default(); + for i in 0..RPL_PARENTS_BUFFER_COUNT { + let i = i as u16; + let mut address = Ipv6Address::default(); + address.0[15] = i as u8; + last_address = address; + + set.add( + address, + Parent::new( + 0, + Rank::new(256 * i, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + + assert_eq!( + set.find(&address), + Some(&Parent::new( + 0, + Rank::new(256 * i, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + )) + ); + } + + // This one is not added to the set, because its Rank is worse than any other parent in the + // set. + let mut address = Ipv6Address::default(); + address.0[15] = 8; + set.add( + address, + Parent::new( + 0, + Rank::new(256 * 8, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + assert_eq!(set.find(&address), None); + + /// This Parent has a better rank than the last one in the set. + let mut address = Ipv6Address::default(); + address.0[15] = 9; + set.add( + address, + Parent::new( + 0, + Rank::new(0, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address, + ), + ); + assert_eq!( + set.find(&address), + Some(&Parent::new( + 0, + Rank::new(0, DEFAULT_MIN_HOP_RANK_INCREASE), + Default::default(), + address + )) + ); + assert_eq!(set.find(&last_address), None); + } +} diff --git a/src/iface/rpl/rank.rs b/src/iface/rpl/rank.rs new file mode 100644 index 000000000..02a5ecf59 --- /dev/null +++ b/src/iface/rpl/rank.rs @@ -0,0 +1,104 @@ +//! Implementation of the Rank comparison in RPL. +//! +//! A Rank can be thought of as a fixed-point number, where the position of the radix point between +//! the integer part and the fractional part is determined by `MinHopRankIncrease`. +//! `MinHopRankIncrease` is the minimum increase in Rank between a node and any of its DODAG +//! parents. +//! This value is provisined by the DODAG root. +//! +//! When Rank is compared, the integer portion of the Rank is to be used. +//! +//! Meaning of the comparison: +//! - **Rank M is less than Rank N**: the position of M is closer to the DODAG root than the position +//! of N. Node M may safely be a DODAG parent for node N. +//! - **Ranks are equal**: the positions of both nodes within the DODAG and with respect to the DODAG +//! are similar or identical. Routing through a node with equal Rank may cause a routing loop. +//! - **Rank M is greater than Rank N**: the position of node M is farther from the DODAG root +//! than the position of N. Node M may in fact be in the sub-DODAG of node N. If node N selects +//! node M as a DODAG parent, there is a risk of creating a loop. + +use super::consts::DEFAULT_MIN_HOP_RANK_INCREASE; + +/// The Rank is the expression of the relative position within a DODAG Version with regard to +/// neighbors, and it is not necessarily a good indication or a proper expression of a distance or +/// a path cost to the root. +#[derive(Debug, Clone, Copy, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Rank { + pub(super) value: u16, + pub(super) min_hop_rank_increase: u16, +} + +impl core::fmt::Display for Rank { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!(f, "Rank({})", self.dag_rank()) + } +} + +impl Rank { + pub const INFINITE: Self = Rank::new(0xffff, DEFAULT_MIN_HOP_RANK_INCREASE); + + /// The ROOT_RANK is the smallest rank possible. + /// DAG_RANK(ROOT_RANK) should be 1. See RFC6550 § 17. + pub const ROOT: Self = Rank::new(DEFAULT_MIN_HOP_RANK_INCREASE, DEFAULT_MIN_HOP_RANK_INCREASE); + + /// Create a new Rank from some value and a `MinHopRankIncrease`. + /// The `MinHopRankIncrease` is used for calculating the integer part for comparing to other + /// Ranks. + pub const fn new(value: u16, min_hop_rank_increase: u16) -> Self { + assert!(min_hop_rank_increase > 0); + + Self { + value, + min_hop_rank_increase, + } + } + + /// Return the integer part of the Rank. + pub fn dag_rank(&self) -> u16 { + self.value / self.min_hop_rank_increase + } + + /// Return the raw Rank value. + pub fn raw_value(&self) -> u16 { + self.value + } +} + +impl PartialEq for Rank { + fn eq(&self, other: &Self) -> bool { + self.dag_rank() == other.dag_rank() + } +} + +impl PartialOrd for Rank { + fn partial_cmp(&self, other: &Self) -> Option { + self.dag_rank().partial_cmp(&other.dag_rank()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn calculate_rank() { + let r = Rank::new(27, 16); + assert_eq!(r.dag_rank(), 1) + } + + #[test] + fn comparison() { + let r1 = Rank::ROOT; + let r2 = Rank::new(16, 16); + assert!(r1 == r2); + + let r1 = Rank::new(16, 16); + let r2 = Rank::new(32, 16); + assert!(r1 < r2); + + let r1 = Rank::ROOT; + let r2 = Rank::INFINITE; + assert!(r1 < r2); + } +} diff --git a/src/iface/rpl/relations.rs b/src/iface/rpl/relations.rs new file mode 100644 index 000000000..da02a3cf9 --- /dev/null +++ b/src/iface/rpl/relations.rs @@ -0,0 +1,162 @@ +use crate::time::Instant; +use crate::wire::Ipv6Address; + +use crate::config::RPL_RELATIONS_BUFFER_COUNT; + +#[derive(Debug)] +pub struct Relation { + destination: Ipv6Address, + next_hop: Ipv6Address, + expiration: Instant, +} + +#[derive(Default, Debug)] +pub struct Relations { + relations: heapless::Vec, +} + +impl Relations { + /// Add a new relation to the buffer. If there was already a relation in the buffer, then + /// update it. + pub fn add_relation( + &mut self, + destination: Ipv6Address, + next_hop: Ipv6Address, + expiration: Instant, + ) { + if let Some(r) = self + .relations + .iter_mut() + .find(|r| r.destination == destination) + { + r.next_hop = next_hop; + r.expiration = expiration; + } else { + let relation = Relation { + destination, + next_hop, + expiration, + }; + + if let Err(e) = self.relations.push(relation) { + net_debug!("Unable to add relation, buffer is full"); + } + } + } + + /// Remove all relation entries for a specific destination. + pub fn remove_relation(&mut self, destination: Ipv6Address) { + self.relations.retain(|r| r.destination != destination) + } + + /// Return the next hop for a specific IPv6 address, if there is one. + pub fn find_next_hop(&mut self, destination: Ipv6Address) -> Option { + self.relations.iter().find_map(|r| { + if r.destination == destination { + Some(r.next_hop) + } else { + None + } + }) + } + + /// Purge expired relations. + pub fn purge(&mut self, now: Instant) { + self.relations.retain(|r| r.expiration > now) + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::time::Duration; + + fn addresses(count: usize) -> Vec { + (0..count) + .map(|i| { + let mut ip = Ipv6Address::default(); + ip.0[0] = i as u8; + ip + }) + .collect() + } + + #[test] + fn add_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + } + + #[test] + fn add_relations_full_buffer() { + let addrs = addresses(crate::config::RPL_RELATIONS_BUFFER_COUNT + 1); + + // Try to add RPL_RELATIONS_BUFFER_COUNT + 1 to the buffer. + // The size of the buffer should still be RPL_RELATIONS_BUFFER_COUNT. + let mut relations = Relations::default(); + for a in addrs { + relations.add_relation(a, a, Instant::now()); + } + + assert_eq!(relations.relations.len(), RPL_RELATIONS_BUFFER_COUNT); + } + + #[test] + fn update_relation() { + let addrs = addresses(3); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + relations.add_relation(addrs[0], addrs[2], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[2])); + } + + #[test] + fn find_next_hop() { + let addrs = addresses(3); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[1])); + + relations.add_relation(addrs[0], addrs[2], Instant::now()); + assert_eq!(relations.relations.len(), 1); + assert_eq!(relations.find_next_hop(addrs[0]), Some(addrs[2])); + + // Find the next hop of a destination not in the buffer. + assert_eq!(relations.find_next_hop(addrs[1]), None); + } + + #[test] + fn remove_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now()); + assert_eq!(relations.relations.len(), 1); + + relations.remove_relation(addrs[0]); + assert!(relations.relations.is_empty()); + } + + #[test] + fn purge_relation() { + let addrs = addresses(2); + + let mut relations = Relations::default(); + relations.add_relation(addrs[0], addrs[1], Instant::now() - Duration::from_secs(1)); + + assert_eq!(relations.relations.len(), 1); + + relations.purge(Instant::now()); + assert!(relations.relations.is_empty()); + } +} diff --git a/src/iface/rpl/trickle.rs b/src/iface/rpl/trickle.rs new file mode 100644 index 000000000..e60cad109 --- /dev/null +++ b/src/iface/rpl/trickle.rs @@ -0,0 +1,266 @@ +//! Implementation of the Trickle timer defined in [RFC 6206]. The algorithm allows node in a lossy +//! shared medium to exchange information in a highly robust, energy efficient, simple, and +//! scalable manner. Dynamicaly adjusting transmission windows allows Trickle to spread new +//! information fast while sending only a few messages per hour when information does not change. +//! +//! **NOTE**: the constants used for the default Trickle timer are the ones from the [Enhanced +//! Trickle]. +//! +//! [RFC 6206]: https://datatracker.ietf.org/doc/html/rfc6206 +//! [Enhanced Trickle]: https://d1wqtxts1xzle7.cloudfront.net/71402623/E-Trickle_Enhanced_Trickle_Algorithm_for20211005-2078-1ckh34a.pdf?1633439582=&response-content-disposition=inline%3B+filename%3DE_Trickle_Enhanced_Trickle_Algorithm_for.pdf&Expires=1681472005&Signature=cC7l-Pyr5r64XBNCDeSJ2ha6oqWUtO6A-KlDOyC0UVaHxDV3h3FuVHRtcNp3O9BUfRK8jeuWCYGBkCZgQT4Zgb6XwgVB-3z4TF9o3qBRMteRyYO5vjVkpPBeN7mz4Tl746SsSCHDm2NMtr7UVtLYamriU3D0rryoqLqJXmnkNoJpn~~wJe2H5PmPgIwixTwSvDkfFLSVoESaYS9ZWHZwbW-7G7OxIw8oSYhx9xMBnzkpdmT7sJNmvDzTUhoOjYrHTRM23cLVS9~oOSpT7hKtKD4h5CSmrNK4st07KnT9~tUqEcvGO3aXdd4quRZeKUcCkCbTLvhOEYg9~QqgD8xwhA__&Key-Pair-Id=APKAJLOHF5GGSLRBV4ZA + +use crate::{ + rand::Rand, + time::{Duration, Instant}, +}; + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct TrickleTimer { + i_min: u32, + i_max: u32, + k: usize, + + i: Duration, + t: Duration, + t_exp: Instant, + i_exp: Instant, + counter: usize, +} + +impl TrickleTimer { + /// Creat a new Trickle timer using the default values. + /// + /// **NOTE**: the standard defines I as a random value between [Imin, Imax]. However, this + /// could result in a t value that is very close to Imax. Therefore, sending DIO messages will + /// be sporadic, which is not ideal when a network is started. It might take a long time before + /// the network is actually stable. Therefore, we don't draw a random numberm but just use Imin + /// for I. This only affects the start of the RPL tree and speeds up building it. Also, we + /// don't use the default values from the standard, but the values from the _Enhanced Trickle + /// Algorithm for Low-Power and Lossy Networks_ from Baraq Ghaleb et al. This is also what the + /// Contiki Trickle timer does. + pub(crate) fn default(now: Instant, rand: &mut Rand) -> Self { + use super::consts::{ + DEFAULT_DIO_INTERVAL_DOUBLINGS, DEFAULT_DIO_INTERVAL_MIN, + DEFAULT_DIO_REDUNDANCY_CONSTANT, + }; + + Self::new( + DEFAULT_DIO_INTERVAL_MIN, + DEFAULT_DIO_INTERVAL_MIN + DEFAULT_DIO_INTERVAL_DOUBLINGS, + DEFAULT_DIO_REDUNDANCY_CONSTANT, + now, + rand, + ) + } + + /// Create a new Trickle timer. + pub(crate) fn new(i_min: u32, i_max: u32, k: usize, now: Instant, rand: &mut Rand) -> Self { + let mut timer = Self { + i_min, + i_max, + k, + i: Duration::ZERO, + t: Duration::ZERO, + t_exp: Instant::ZERO, + i_exp: Instant::ZERO, + counter: 0, + }; + + timer.i = Duration::from_millis(2u32.pow(timer.i_min) as u64); + timer.i_exp = now + timer.i; + timer.counter = 0; + + timer.set_t(now, rand); + + timer + } + + /// Poll the Trickle timer. Returns `true` when the Trickle timer singals that a message can be + /// transmitted. This happens when the Trickle timer expires. + pub(crate) fn poll(&mut self, now: Instant, rand: &mut Rand) -> bool { + let can_transmit = self.can_transmit() && self.t_expired(now); + + if can_transmit { + self.set_t(now, rand); + } + + if self.i_expired(now) { + self.expire(now, rand); + } + + can_transmit + } + + /// Returns the Instant at which the Trickle timer should be polled again. Polling the Trickle + /// timer before this Instant is not harmfull, however, polling after it is not correct. + pub(crate) fn poll_at(&self) -> Instant { + self.t_exp.min(self.i_exp) + } + + /// Signal the Trickle timer that a consistency has been heard, and thus increasing it's + /// counter. + pub(crate) fn hear_consistent(&mut self) { + self.counter += 1; + } + + /// Signal the Trickle timer that an inconsistency has been heard. This resets the Trickle + /// timer when the current interval is not the smallest possible. + pub(crate) fn hear_inconsistency(&mut self, now: Instant, rand: &mut Rand) { + let i = Duration::from_millis(2u32.pow(self.i_min) as u64); + if self.i > i { + self.reset(i, now, rand); + } + } + + /// Check if the Trickle timer can transmit or not. Returns `false` when the consistency + /// counter is bigger or equal to the default consistency constant. + pub(crate) fn can_transmit(&self) -> bool { + self.k != 0 && self.counter < self.k + } + + /// Reset the Trickle timer when the interval has expired. + fn expire(&mut self, now: Instant, rand: &mut Rand) { + let max_interval = Duration::from_millis(2u32.pow(self.i_max) as u64); + let i = if self.i >= max_interval { + max_interval + } else { + self.i + self.i + }; + + self.reset(i, now, rand); + } + + pub(crate) fn reset(&mut self, i: Duration, now: Instant, rand: &mut Rand) { + self.i = i; + self.i_exp = now + self.i; + self.counter = 0; + self.set_t(now, rand); + } + + pub(crate) const fn max_expiration(&self) -> Duration { + Duration::from_millis(2u32.pow(self.i_max) as u64) + } + + pub(crate) const fn min_expiration(&self) -> Duration { + Duration::from_millis(2u32.pow(self.i_min) as u64) + } + + fn set_t(&mut self, now: Instant, rand: &mut Rand) { + let t = Duration::from_micros( + self.i.total_micros() / 2 + + (rand.rand_u32() as u64 + % (self.i.total_micros() - self.i.total_micros() / 2 + 1)), + ); + + self.t = t; + self.t_exp = now + t; + } + + fn t_expired(&self, now: Instant) -> bool { + now >= self.t_exp + } + + fn i_expired(&self, now: Instant) -> bool { + now >= self.i_exp + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn trickle_timer_intervals() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + let mut previous_i = trickle.i; + + while now <= Instant::from_secs(100_000) { + trickle.poll(now, &mut rand); + + if now < Instant::ZERO + trickle.max_expiration() { + // t should always be inbetween I/2 and I. + assert!(trickle.i / 2 < trickle.t); + assert!(trickle.i > trickle.t); + } + + if previous_i != trickle.i { + // When a new Interval is selected, this should be double the previous one. + assert_eq!(previous_i * 2, trickle.i); + assert_eq!(trickle.counter, 0); + previous_i = trickle.i; + } + + now += Duration::from_millis(100); + } + } + + #[test] + fn trickle_timer_hear_inconsistency() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + trickle.counter = 1; + + while now <= Instant::from_secs(10_000) { + trickle.poll(now, &mut rand); + + if now < trickle.i_exp && now < Instant::ZERO + trickle.min_expiration() { + assert_eq!(trickle.counter, 1); + } else { + // The first interval expired, so the conter is reset. + assert_eq!(trickle.counter, 0); + } + + if now == Instant::from_secs(10) { + // We set the counter to 1 such that we can test the `hear_inconsistency`. + trickle.counter = 1; + + assert_eq!(trickle.counter, 1); + + trickle.hear_inconsistency(now, &mut rand); + + assert_eq!(trickle.counter, 0); + assert_eq!(trickle.i, trickle.min_expiration()); + } + + now += Duration::from_millis(100); + } + } + + #[test] + fn trickle_timer_hear_consistency() { + let mut rand = Rand::new(1234); + let mut now = Instant::ZERO; + let mut trickle = TrickleTimer::default(now, &mut rand); + + trickle.counter = 1; + + let mut transmit_counter = 0; + + while now <= Instant::from_secs(10_000) { + trickle.hear_consistent(); + + if trickle.poll(now, &mut rand) { + transmit_counter += 1; + } + + if now == Instant::from_secs(10_000) { + use super::super::consts::{ + DEFAULT_DIO_INTERVAL_DOUBLINGS, DEFAULT_DIO_REDUNDANCY_CONSTANT, + }; + assert!(!trickle.poll(now, &mut rand)); + assert!(trickle.counter > DEFAULT_DIO_REDUNDANCY_CONSTANT); + // We should never have transmitted since the counter was higher than the default + // redundancy constant. + assert_eq!(transmit_counter, 0); + } + + now += Duration::from_millis(100); + } + } +} diff --git a/src/iface/socket_meta.rs b/src/iface/socket_meta.rs new file mode 100644 index 000000000..82c99087b --- /dev/null +++ b/src/iface/socket_meta.rs @@ -0,0 +1,103 @@ +use super::SocketHandle; +use crate::{ + socket::PollAt, + time::{Duration, Instant}, + wire::IpAddress, +}; + +/// Neighbor dependency. +/// +/// This enum tracks whether the socket should be polled based on the neighbor +/// it is going to send packets to. +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum NeighborState { + /// Socket can be polled immediately. + #[default] + Active, + /// Socket should not be polled until either `silent_until` passes or + /// `neighbor` appears in the neighbor cache. + Waiting { + neighbor: IpAddress, + silent_until: Instant, + }, +} + +/// Network socket metadata. +/// +/// This includes things that only external (to the socket, that is) code +/// is interested in, but which are more conveniently stored inside the socket +/// itself. +#[derive(Debug, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub(crate) struct Meta { + /// Handle of this socket within its enclosing `SocketSet`. + /// Mainly useful for debug output. + pub(crate) handle: SocketHandle, + /// See [NeighborState](struct.NeighborState.html). + neighbor_state: NeighborState, +} + +impl Meta { + /// Minimum delay between neighbor discovery requests for this particular + /// socket, in milliseconds. + /// + /// See also `iface::NeighborCache::SILENT_TIME`. + pub(crate) const DISCOVERY_SILENT_TIME: Duration = Duration::from_millis(1_000); + + pub(crate) fn poll_at(&self, socket_poll_at: PollAt, has_neighbor: F) -> PollAt + where + F: Fn(IpAddress) -> bool, + { + match self.neighbor_state { + NeighborState::Active => socket_poll_at, + NeighborState::Waiting { neighbor, .. } if has_neighbor(neighbor) => socket_poll_at, + NeighborState::Waiting { silent_until, .. } => PollAt::Time(silent_until), + } + } + + pub(crate) fn egress_permitted(&mut self, timestamp: Instant, has_neighbor: F) -> bool + where + F: Fn(IpAddress) -> bool, + { + match self.neighbor_state { + NeighborState::Active => true, + NeighborState::Waiting { + neighbor, + silent_until, + } => { + if has_neighbor(neighbor) { + net_trace!( + "{}: neighbor {} discovered, unsilencing", + self.handle, + neighbor + ); + self.neighbor_state = NeighborState::Active; + true + } else if timestamp >= silent_until { + net_trace!( + "{}: neighbor {} silence timer expired, rediscovering", + self.handle, + neighbor + ); + true + } else { + false + } + } + } + } + + pub(crate) fn neighbor_missing(&mut self, timestamp: Instant, neighbor: IpAddress) { + net_trace!( + "{}: neighbor {} missing, silencing until t+{}", + self.handle, + neighbor, + Self::DISCOVERY_SILENT_TIME + ); + self.neighbor_state = NeighborState::Waiting { + neighbor, + silent_until: timestamp + Self::DISCOVERY_SILENT_TIME, + }; + } +} diff --git a/src/iface/socket_set.rs b/src/iface/socket_set.rs new file mode 100644 index 000000000..fe9bef755 --- /dev/null +++ b/src/iface/socket_set.rs @@ -0,0 +1,149 @@ +use core::fmt; +use managed::ManagedSlice; + +use super::socket_meta::Meta; +use crate::socket::{AnySocket, Socket}; + +/// Opaque struct with space for storing one socket. +/// +/// This is public so you can use it to allocate space for storing +/// sockets when creating an Interface. +#[derive(Debug, Default)] +pub struct SocketStorage<'a> { + inner: Option>, +} + +impl<'a> SocketStorage<'a> { + pub const EMPTY: Self = Self { inner: None }; +} + +/// An item of a socket set. +#[derive(Debug)] +pub(crate) struct Item<'a> { + pub(crate) meta: Meta, + pub(crate) socket: Socket<'a>, +} + +/// A handle, identifying a socket in an Interface. +#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct SocketHandle(usize); + +impl fmt::Display for SocketHandle { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "#{}", self.0) + } +} + +/// An extensible set of sockets. +/// +/// The lifetime `'a` is used when storing a `Socket<'a>`. +#[derive(Debug)] +pub struct SocketSet<'a> { + sockets: ManagedSlice<'a, SocketStorage<'a>>, +} + +impl<'a> SocketSet<'a> { + /// Create a socket set using the provided storage. + pub fn new(sockets: SocketsT) -> SocketSet<'a> + where + SocketsT: Into>>, + { + let sockets = sockets.into(); + SocketSet { sockets } + } + + /// Add a socket to the set, and return its handle. + /// + /// # Panics + /// This function panics if the storage is fixed-size (not a `Vec`) and is full. + pub fn add>(&mut self, socket: T) -> SocketHandle { + fn put<'a>(index: usize, slot: &mut SocketStorage<'a>, socket: Socket<'a>) -> SocketHandle { + net_trace!("[{}]: adding", index); + let handle = SocketHandle(index); + let mut meta = Meta::default(); + meta.handle = handle; + *slot = SocketStorage { + inner: Some(Item { meta, socket }), + }; + handle + } + + let socket = socket.upcast(); + + for (index, slot) in self.sockets.iter_mut().enumerate() { + if slot.inner.is_none() { + return put(index, slot, socket); + } + } + + match &mut self.sockets { + ManagedSlice::Borrowed(_) => panic!("adding a socket to a full SocketSet"), + #[cfg(feature = "alloc")] + ManagedSlice::Owned(sockets) => { + sockets.push(SocketStorage { inner: None }); + let index = sockets.len() - 1; + put(index, &mut sockets[index], socket) + } + } + } + + /// Get a socket from the set by its handle, as mutable. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set + /// or the socket has the wrong type. + pub fn get>(&self, handle: SocketHandle) -> &T { + match self.sockets[handle.0].inner.as_ref() { + Some(item) => { + T::downcast(&item.socket).expect("handle refers to a socket of a wrong type") + } + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Get a mutable socket from the set by its handle, as mutable. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set + /// or the socket has the wrong type. + pub fn get_mut>(&mut self, handle: SocketHandle) -> &mut T { + match self.sockets[handle.0].inner.as_mut() { + Some(item) => T::downcast_mut(&mut item.socket) + .expect("handle refers to a socket of a wrong type"), + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Remove a socket from the set, without changing its state. + /// + /// # Panics + /// This function may panic if the handle does not belong to this socket set. + pub fn remove(&mut self, handle: SocketHandle) -> Socket<'a> { + net_trace!("[{}]: removing", handle.0); + match self.sockets[handle.0].inner.take() { + Some(item) => item.socket, + None => panic!("handle does not refer to a valid socket"), + } + } + + /// Get an iterator to the inner sockets. + pub fn iter(&self) -> impl Iterator)> { + self.items().map(|i| (i.meta.handle, &i.socket)) + } + + /// Get a mutable iterator to the inner sockets. + pub fn iter_mut(&mut self) -> impl Iterator)> { + self.items_mut().map(|i| (i.meta.handle, &mut i.socket)) + } + + /// Iterate every socket in this set. + pub(crate) fn items(&self) -> impl Iterator> + '_ { + self.sockets.iter().filter_map(|x| x.inner.as_ref()) + } + + /// Iterate every socket in this set. + pub(crate) fn items_mut(&mut self) -> impl Iterator> + '_ { + self.sockets.iter_mut().filter_map(|x| x.inner.as_mut()) + } +} diff --git a/src/lib.rs b/src/lib.rs index 4af96235b..6128c2314 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,7 +1,5 @@ -#![cfg_attr(feature = "alloc", feature(alloc))] -#![no_std] +#![cfg_attr(not(any(test, feature = "std")), no_std)] #![deny(unsafe_code)] -#![cfg_attr(all(any(feature = "proto-ipv4", feature = "proto-ipv6"), feature = "ethernet"), deny(unused))] //! The _smoltcp_ library is built in a layered structure, with the layers corresponding //! to the levels of API abstraction. Only the highest layers would be used by a typical @@ -42,7 +40,7 @@ //! Unlike the higher layers, the wire layer APIs will not be used by a typical application. //! They however are the bedrock of _smoltcp_, and everything else is built on top of them. //! -//! The wire layer APIs are designed by the principle "make illegal states irrepresentable". +//! The wire layer APIs are designed by the principle "make illegal states ir-representable". //! If a wire layer object can be constructed, then it can also be parsed from or emitted to //! a lower level. //! @@ -65,18 +63,13 @@ //! feature ever defined, to ensure that, when the representation layer is unable to make sense //! of a packet, it is still logged correctly and in full. //! -//! ## Packet and representation layer support -//! | Protocol | Packet | Representation | -//! |----------|--------|----------------| -//! | Ethernet | Yes | Yes | -//! | ARP | Yes | Yes | -//! | IPv4 | Yes | Yes | -//! | ICMPv4 | Yes | Yes | -//! | IGMPv1/2 | Yes | Yes | -//! | IPv6 | Yes | Yes | -//! | ICMPv6 | Yes | Yes | -//! | TCP | Yes | Yes | -//! | UDP | Yes | Yes | +//! # Minimum Supported Rust Version (MSRV) +//! +//! This crate is guaranteed to compile on stable Rust 1.65 and up with any valid set of features. +//! It *might* compile on older versions but that may change in any new patch release. +//! +//! The exception is when using the `defmt` feature, in which case `defmt`'s MSRV applies, which +//! is higher. //! //! [wire]: wire/index.html //! [osi]: https://en.wikipedia.org/wiki/OSI_model @@ -88,96 +81,89 @@ feature = "socket-tcp")))] compile_error!("at least one socket needs to be enabled"); */ -// FIXME(dlrobertson): clippy fails with this lint -#![cfg_attr(feature = "cargo-clippy", allow(if_same_then_else))] +#![allow(clippy::match_like_matches_macro)] +#![allow(clippy::redundant_field_names)] +#![allow(clippy::identity_op)] +#![allow(clippy::option_map_unit_fn)] +#![allow(clippy::unit_arg)] +#![allow(clippy::new_without_default)] -#[cfg(all(feature = "proto-ipv6", feature = "ethernet"))] -#[macro_use] -extern crate bitflags; -extern crate byteorder; -extern crate managed; -#[cfg(any(test, feature = "std"))] -#[macro_use] -extern crate std; -#[cfg(any(feature = "phy-raw_socket", feature = "phy-tap_interface"))] -extern crate libc; #[cfg(feature = "alloc")] extern crate alloc; -#[cfg(feature = "log")] -#[macro_use(trace, debug)] -extern crate log; -use core::fmt; +#[cfg(not(any( + feature = "proto-ipv4", + feature = "proto-ipv6", + feature = "proto-sixlowpan" +)))] +compile_error!("You must enable at least one of the following features: proto-ipv4, proto-ipv6, proto-sixlowpan"); -#[macro_use] -mod macros; -mod parsers; +#[cfg(all( + feature = "socket", + not(any( + feature = "socket-raw", + feature = "socket-udp", + feature = "socket-tcp", + feature = "socket-icmp", + feature = "socket-dhcpv4", + feature = "socket-dns", + )) +))] +compile_error!("If you enable the socket feature, you must enable at least one of the following features: socket-raw, socket-udp, socket-tcp, socket-icmp, socket-dhcpv4, socket-dns"); -pub mod storage; -pub mod phy; -pub mod wire; -pub mod iface; -pub mod socket; -pub mod time; -#[cfg(feature = "proto-dhcpv4")] -pub mod dhcp; +#[cfg(all( + feature = "socket", + not(any( + feature = "medium-ethernet", + feature = "medium-ip", + feature = "medium-ieee802154", + )) +))] +compile_error!("If you enable the socket feature, you must enable at least one of the following features: medium-ip, medium-ethernet, medium-ieee802154"); -/// The error type for the networking stack. -#[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub enum Error { - /// An operation cannot proceed because a buffer is empty or full. - Exhausted, - /// An operation is not permitted in the current state. - Illegal, - /// An endpoint or address of a remote host could not be translated to a lower level address. - /// E.g. there was no an Ethernet address corresponding to an IPv4 address in the ARP cache, - /// or a TCP connection attempt was made to an unspecified endpoint. - Unaddressable, +#[cfg(all(feature = "defmt", feature = "log"))] +compile_error!("You must enable at most one of the following features: defmt, log"); - /// The operation is finished. - /// E.g. when reading from a TCP socket, there's no more data to read because the remote - /// has closed the connection. - Finished, +#[macro_use] +mod macros; +mod parsers; +mod rand; - /// An incoming packet could not be parsed because some of its fields were out of bounds - /// of the received data. - Truncated, - /// An incoming packet had an incorrect checksum and was dropped. - Checksum, - /// An incoming packet could not be recognized and was dropped. - /// E.g. an Ethernet packet with an unknown EtherType. - Unrecognized, - /// An incoming IP packet has been split into several IP fragments and was dropped, - /// since IP reassembly is not supported. - Fragmented, - /// An incoming packet was recognized but was self-contradictory. - /// E.g. a TCP packet with both SYN and FIN flags set. - Malformed, - /// An incoming packet was recognized but contradicted internal state. - /// E.g. a TCP packet addressed to a socket that doesn't exist. - Dropped, +#[cfg(test)] +mod config { + #![allow(unused)] + pub const ASSEMBLER_MAX_SEGMENT_COUNT: usize = 4; + pub const DNS_MAX_NAME_SIZE: usize = 255; + pub const DNS_MAX_RESULT_COUNT: usize = 1; + pub const DNS_MAX_SERVER_COUNT: usize = 1; + pub const FRAGMENTATION_BUFFER_SIZE: usize = 1500; + pub const IFACE_MAX_ADDR_COUNT: usize = 8; + pub const IFACE_MAX_MULTICAST_GROUP_COUNT: usize = 4; + pub const IFACE_MAX_ROUTE_COUNT: usize = 4; + pub const IFACE_MAX_SIXLOWPAN_ADDRESS_CONTEXT_COUNT: usize = 4; + pub const IFACE_NEIGHBOR_CACHE_COUNT: usize = 3; + pub const REASSEMBLY_BUFFER_COUNT: usize = 4; + pub const REASSEMBLY_BUFFER_SIZE: usize = 1500; + pub const RPL_RELATIONS_BUFFER_COUNT: usize = 16; + pub const RPL_PARENTS_BUFFER_COUNT: usize = 8; +} - #[doc(hidden)] - __Nonexhaustive +#[cfg(not(test))] +mod config { + #![allow(unused)] + include!(concat!(env!("OUT_DIR"), "/config.rs")); } -/// The result type for the networking stack. -pub type Result = core::result::Result; +#[cfg(any( + feature = "medium-ethernet", + feature = "medium-ip", + feature = "medium-ieee802154" +))] +pub mod iface; -impl fmt::Display for Error { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Error::Exhausted => write!(f, "buffer space exhausted"), - &Error::Illegal => write!(f, "illegal operation"), - &Error::Unaddressable => write!(f, "unaddressable destination"), - &Error::Finished => write!(f, "operation finished"), - &Error::Truncated => write!(f, "truncated packet"), - &Error::Checksum => write!(f, "checksum error"), - &Error::Unrecognized => write!(f, "unrecognized packet"), - &Error::Fragmented => write!(f, "fragmented packet"), - &Error::Malformed => write!(f, "malformed packet"), - &Error::Dropped => write!(f, "dropped by socket"), - &Error::__Nonexhaustive => unreachable!() - } - } -} +pub mod phy; +#[cfg(feature = "socket")] +pub mod socket; +pub mod storage; +pub mod time; +pub mod wire; diff --git a/src/macros.rs b/src/macros.rs index eb51d03a3..e899d24ec 100644 --- a/src/macros.rs +++ b/src/macros.rs @@ -1,18 +1,26 @@ +#[cfg(not(test))] #[cfg(feature = "log")] -#[macro_use] -mod log { - macro_rules! net_log { - (trace, $($arg:expr),*) => { trace!($($arg),*); }; - (debug, $($arg:expr),*) => { debug!($($arg),*); }; - } +macro_rules! net_log { + (trace, $($arg:expr),*) => { log::trace!($($arg),*) }; + (debug, $($arg:expr),*) => { log::debug!($($arg),*) }; } -#[cfg(not(feature = "log"))] -#[macro_use] -mod log { - macro_rules! net_log { - ($level:ident, $($arg:expr),*) => { $( let _ = $arg; )* } - } +#[cfg(test)] +#[cfg(feature = "log")] +macro_rules! net_log { + (trace, $($arg:expr),*) => { println!($($arg),*) }; + (debug, $($arg:expr),*) => { println!($($arg),*) }; +} + +#[cfg(feature = "defmt")] +macro_rules! net_log { + (trace, $($arg:expr),*) => { defmt::trace!($($arg),*) }; + (debug, $($arg:expr),*) => { defmt::debug!($($arg),*) }; +} + +#[cfg(not(any(feature = "log", feature = "defmt")))] +macro_rules! net_log { + ($level:ident, $($arg:expr),*) => {{ $( let _ = $arg; )* }} } macro_rules! net_trace { @@ -27,26 +35,14 @@ macro_rules! enum_with_unknown { ( $( #[$enum_attr:meta] )* pub enum $name:ident($ty:ty) { - $( $variant:ident = $value:expr ),+ $(,)* - } - ) => { - enum_with_unknown! { - $( #[$enum_attr] )* - pub doc enum $name($ty) { - $( #[doc(shown)] $variant = $value ),+ - } - } - }; - ( - $( #[$enum_attr:meta] )* - pub doc enum $name:ident($ty:ty) { $( - $( #[$variant_attr:meta] )+ - $variant:ident = $value:expr $(,)* - ),+ + $( #[$variant_attr:meta] )* + $variant:ident = $value:expr + ),+ $(,)? } ) => { - #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] + #[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Hash)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] $( #[$enum_attr] )* pub enum $name { $( @@ -75,3 +71,99 @@ macro_rules! enum_with_unknown { } } } + +#[cfg(feature = "proto-rpl")] +macro_rules! get { + ($buffer:expr, into: $into:ty, fun: $fun:ident, field: $field:expr $(,)?) => { + { + <$into>::$fun(&$buffer.as_ref()[$field]) + } + }; + + ($buffer:expr, into: $into:ty, field: $field:expr $(,)?) => { + get!($buffer, into: $into, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, into: $into:ty, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + get!($buffer, into: $into, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, into: $into:ty, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => { + { + <$into>::from((&$buffer.as_ref()[$field] >> $bit_shift) & $bit_mask) + } + }; + + ($buffer:expr, field: $field:expr $(,)?) => { + get!($buffer, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + get!($buffer, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) + => + { + { + (&$buffer.as_ref()[$field] >> $bit_shift) & $bit_mask + } + }; + + ($buffer:expr, u16, field: $field:expr $(,)?) => { + { + NetworkEndian::read_u16(&$buffer.as_ref()[$field]) + } + }; + + ($buffer:expr, bool, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => { + { + (($buffer.as_ref()[$field] >> $bit_shift) & $bit_mask) == 0b1 + } + }; + + ($buffer:expr, u32, field: $field:expr $(,)?) => { + { + NetworkEndian::read_u32(&$buffer.as_ref()[$field]) + } + }; +} + +#[cfg(feature = "proto-rpl")] +macro_rules! set { + ($buffer:expr, address: $address:ident, field: $field:expr $(,)?) => {{ + $buffer.as_mut()[$field].copy_from_slice($address.as_bytes()); + }}; + + ($buffer:expr, $value:ident, field: $field:expr $(,)?) => { + set!($buffer, $value, field: $field, shift: 0, mask: 0b1111_1111) + }; + + ($buffer:expr, $value:ident, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + set!($buffer, $value, field: $field, shift: 0, mask: $bit_mask) + }; + + ($buffer:expr, $value:ident, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => {{ + let raw = + ($buffer.as_ref()[$field] & !($bit_mask << $bit_shift)) | ($value << $bit_shift); + $buffer.as_mut()[$field] = raw; + }}; + + ($buffer:expr, $value:ident, bool, field: $field:expr, mask: $bit_mask:expr $(,)?) => { + set!($buffer, $value, bool, field: $field, shift: 0, mask: $bit_mask); + }; + + ($buffer:expr, $value:ident, bool, field: $field:expr, shift: $bit_shift:expr, mask: $bit_mask:expr $(,)?) => {{ + let raw = ($buffer.as_ref()[$field] & !($bit_mask << $bit_shift)) + | (if $value { 0b1 } else { 0b0 } << $bit_shift); + $buffer.as_mut()[$field] = raw; + }}; + + ($buffer:expr, $value:ident, u16, field: $field:expr $(,)?) => {{ + NetworkEndian::write_u16(&mut $buffer.as_mut()[$field], $value); + }}; + + ($buffer:expr, $value:ident, u32, field: $field:expr $(,)?) => {{ + NetworkEndian::write_u32(&mut $buffer.as_mut()[$field], $value); + }}; +} diff --git a/src/parsers.rs b/src/parsers.rs index e6452633b..16419ab5c 100644 --- a/src/parsers.rs +++ b/src/parsers.rs @@ -1,28 +1,31 @@ -#![cfg_attr(not(all(feature = "proto-ipv6", feature = "proto-ipv4")), allow(dead_code))] +#![cfg_attr( + not(all(feature = "proto-ipv6", feature = "proto-ipv4")), + allow(dead_code) +)] -use core::str::FromStr; use core::result; +use core::str::FromStr; -#[cfg(feature = "ethernet")] -use wire::EthernetAddress; -use wire::{IpAddress, IpCidr, IpEndpoint}; +#[cfg(feature = "medium-ethernet")] +use crate::wire::EthernetAddress; +use crate::wire::{IpAddress, IpCidr, IpEndpoint}; #[cfg(feature = "proto-ipv4")] -use wire::{Ipv4Address, Ipv4Cidr}; +use crate::wire::{Ipv4Address, Ipv4Cidr}; #[cfg(feature = "proto-ipv6")] -use wire::{Ipv6Address, Ipv6Cidr}; +use crate::wire::{Ipv6Address, Ipv6Cidr}; type Result = result::Result; struct Parser<'a> { data: &'a [u8], - pos: usize + pos: usize, } impl<'a> Parser<'a> { fn new(data: &'a str) -> Parser<'a> { Parser { data: data.as_bytes(), - pos: 0 + pos: 0, } } @@ -40,12 +43,14 @@ impl<'a> Parser<'a> { self.pos += 1; Ok(chr) } - None => Err(()) + None => Err(()), } } - fn try(&mut self, f: F) -> Option - where F: FnOnce(&mut Parser<'a>) -> Result { + fn try_do(&mut self, f: F) -> Option + where + F: FnOnce(&mut Parser<'a>) -> Result, + { let pos = self.pos; match f(self) { Ok(res) => Some(res), @@ -65,7 +70,9 @@ impl<'a> Parser<'a> { } fn until_eof(&mut self, f: F) -> Result - where F: FnOnce(&mut Parser<'a>) -> Result { + where + F: FnOnce(&mut Parser<'a>) -> Result, + { let res = f(self)?; self.accept_eof()?; Ok(res) @@ -88,27 +95,26 @@ impl<'a> Parser<'a> { fn accept_digit(&mut self, hex: bool) -> Result { let digit = self.advance()?; - if digit >= b'0' && digit <= b'9' { + if digit.is_ascii_digit() { Ok(digit - b'0') - } else if hex && digit >= b'a' && digit <= b'f' { + } else if hex && (b'a'..=b'f').contains(&digit) { Ok(digit - b'a' + 10) - } else if hex && digit >= b'A' && digit <= b'F' { + } else if hex && (b'A'..=b'F').contains(&digit) { Ok(digit - b'A' + 10) } else { Err(()) } } - fn accept_number(&mut self, max_digits: usize, max_value: u32, - hex: bool) -> Result { + fn accept_number(&mut self, max_digits: usize, max_value: u32, hex: bool) -> Result { let mut value = self.accept_digit(hex)? as u32; for _ in 1..max_digits { - match self.try(|p| p.accept_digit(hex)) { + match self.try_do(|p| p.accept_digit(hex)) { Some(digit) => { value *= if hex { 16 } else { 10 }; value += digit as u32; } - None => break + None => break, } } if value < max_value { @@ -118,11 +124,11 @@ impl<'a> Parser<'a> { } } - #[cfg(feature = "ethernet")] + #[cfg(feature = "medium-ethernet")] fn accept_mac_joined_with(&mut self, separator: u8) -> Result { let mut octets = [0u8; 6]; - for n in 0..6 { - octets[n] = self.accept_number(2, 0x100, true)? as u8; + for (n, octet) in octets.iter_mut().enumerate() { + *octet = self.accept_number(2, 0x100, true)? as u8; if n != 5 { self.accept_char(separator)?; } @@ -130,13 +136,13 @@ impl<'a> Parser<'a> { Ok(EthernetAddress(octets)) } - #[cfg(feature = "ethernet")] + #[cfg(feature = "medium-ethernet")] fn accept_mac(&mut self) -> Result { - if let Some(mac) = self.try(|p| p.accept_mac_joined_with(b'-')) { - return Ok(mac) + if let Some(mac) = self.try_do(|p| p.accept_mac_joined_with(b'-')) { + return Ok(mac); } - if let Some(mac) = self.try(|p| p.accept_mac_joined_with(b':')) { - return Ok(mac) + if let Some(mac) = self.try_do(|p| p.accept_mac_joined_with(b':')) { + return Ok(mac); } Err(()) } @@ -154,17 +160,20 @@ impl<'a> Parser<'a> { } #[cfg(feature = "proto-ipv6")] - fn accept_ipv6_part(&mut self, (head, tail): (&mut [u16; 8], &mut [u16; 6]), - (head_idx, tail_idx): (&mut usize, &mut usize), - mut use_tail: bool, is_cidr: bool) -> Result<()> { - let double_colon = match self.try(|p| p.accept_str(b"::")) { + fn accept_ipv6_part( + &mut self, + (head, tail): (&mut [u16; 8], &mut [u16; 6]), + (head_idx, tail_idx): (&mut usize, &mut usize), + mut use_tail: bool, + ) -> Result<()> { + let double_colon = match self.try_do(|p| p.accept_str(b"::")) { Some(_) if !use_tail && *head_idx < 7 => { // Found a double colon. Start filling out the // tail and set the double colon flag in case // this is the last character we can parse. use_tail = true; true - }, + } Some(_) => { // This is a bad address. Only one double colon is // allowed and an address is only 128 bits. @@ -180,39 +189,38 @@ impl<'a> Parser<'a> { } }; - match self.try(|p| p.accept_number(4, 0x10000, true)) { + match self.try_do(|p| p.accept_number(4, 0x10000, true)) { Some(part) if !use_tail && *head_idx < 8 => { // Valid u16 to be added to the address head[*head_idx] = part as u16; *head_idx += 1; if *head_idx == 6 && head[0..*head_idx] == [0, 0, 0, 0, 0, 0xffff] { - self.try(|p| { + self.try_do(|p| { p.accept_char(b':')?; p.accept_ipv4_mapped_ipv6_part(head, head_idx) }); } Ok(()) - }, + } Some(part) if *tail_idx < 6 => { // Valid u16 to be added to the address tail[*tail_idx] = part as u16; *tail_idx += 1; - if *tail_idx == 1 && tail[0] == 0xffff - && head[0..8] == [0, 0, 0, 0, 0, 0, 0, 0] { - self.try(|p| { + if *tail_idx == 1 && tail[0] == 0xffff && head[0..8] == [0, 0, 0, 0, 0, 0, 0, 0] { + self.try_do(|p| { p.accept_char(b':')?; p.accept_ipv4_mapped_ipv6_part(tail, tail_idx) }); } Ok(()) - }, + } Some(_) => { // Tail or head section is too long Err(()) } - None if double_colon && (is_cidr || self.pos == self.data.len()) => { + None if double_colon => { // The address ends with "::". E.g. 1234:: or :: Ok(()) } @@ -233,12 +241,12 @@ impl<'a> Parser<'a> { Ok(()) } else { // Continue recursing - self.accept_ipv6_part((head, tail), (head_idx, tail_idx), use_tail, is_cidr) + self.accept_ipv6_part((head, tail), (head_idx, tail_idx), use_tail) } } #[cfg(feature = "proto-ipv6")] - fn accept_ipv6(&mut self, is_cidr: bool) -> Result { + fn accept_ipv6(&mut self) -> Result { // IPv6 addresses may contain a "::" to indicate a series of // 16 bit sections that evaluate to 0. E.g. // @@ -259,7 +267,11 @@ impl<'a> Parser<'a> { let (mut addr, mut tail) = ([0u16; 8], [0u16; 6]); let (mut head_idx, mut tail_idx) = (0, 0); - self.accept_ipv6_part((&mut addr, &mut tail), (&mut head_idx, &mut tail_idx), false, is_cidr)?; + self.accept_ipv6_part( + (&mut addr, &mut tail), + (&mut head_idx, &mut tail_idx), + false, + )?; // We need to copy the tail portion (the portion following the "::") to the // end of the address. @@ -270,8 +282,8 @@ impl<'a> Parser<'a> { fn accept_ipv4_octets(&mut self) -> Result<[u8; 4]> { let mut octets = [0u8; 4]; - for n in 0..4 { - octets[n] = self.accept_number(3, 0x100, false)? as u8; + for (n, octet) in octets.iter_mut().enumerate() { + *octet = self.accept_number(3, 0x100, false)? as u8; if n != 3 { self.accept_char(b'.')?; } @@ -287,15 +299,17 @@ impl<'a> Parser<'a> { fn accept_ip(&mut self) -> Result { #[cfg(feature = "proto-ipv4")] - match self.try(|p| p.accept_ipv4()) { + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv4()) { Some(ipv4) => return Ok(IpAddress::Ipv4(ipv4)), - None => () + None => (), } #[cfg(feature = "proto-ipv6")] - match self.try(|p| p.accept_ipv6(false)) { + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv6()) { Some(ipv6) => return Ok(IpAddress::Ipv6(ipv6)), - None => () + None => (), } Err(()) @@ -312,43 +326,54 @@ impl<'a> Parser<'a> { self.accept_number(5, 65535, false)? }; - Ok(IpEndpoint { addr: IpAddress::Ipv4(ip), port: port as u16 }) + Ok(IpEndpoint { + addr: IpAddress::Ipv4(ip), + port: port as u16, + }) } #[cfg(feature = "proto-ipv6")] fn accept_ipv6_endpoint(&mut self) -> Result { if self.lookahead_char(b'[') { self.accept_char(b'[')?; - let ip = self.accept_ipv6(false)?; + let ip = self.accept_ipv6()?; self.accept_char(b']')?; self.accept_char(b':')?; let port = self.accept_number(5, 65535, false)?; - Ok(IpEndpoint { addr: IpAddress::Ipv6(ip), port: port as u16 }) + Ok(IpEndpoint { + addr: IpAddress::Ipv6(ip), + port: port as u16, + }) } else { - let ip = self.accept_ipv6(false)?; - Ok(IpEndpoint { addr: IpAddress::Ipv6(ip), port: 0 }) + let ip = self.accept_ipv6()?; + Ok(IpEndpoint { + addr: IpAddress::Ipv6(ip), + port: 0, + }) } } fn accept_ip_endpoint(&mut self) -> Result { #[cfg(feature = "proto-ipv4")] - match self.try(|p| p.accept_ipv4_endpoint()) { + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv4_endpoint()) { Some(ipv4) => return Ok(ipv4), - None => () + None => (), } #[cfg(feature = "proto-ipv6")] - match self.try(|p| p.accept_ipv6_endpoint()) { + #[allow(clippy::single_match)] + match self.try_do(|p| p.accept_ipv6_endpoint()) { Some(ipv6) => return Ok(ipv6), - None => () + None => (), } Err(()) } } -#[cfg(feature = "ethernet")] +#[cfg(feature = "medium-ethernet")] impl FromStr for EthernetAddress { type Err = (); @@ -374,7 +399,7 @@ impl FromStr for Ipv6Address { /// Parse a string representation of an IPv6 address. fn from_str(s: &str) -> Result { - Parser::new(s).until_eof(|p| p.accept_ipv6(false)) + Parser::new(s).until_eof(|p| p.accept_ipv6()) } } @@ -410,7 +435,7 @@ impl FromStr for Ipv6Cidr { fn from_str(s: &str) -> Result { // https://tools.ietf.org/html/rfc4291#section-2.3 Parser::new(s).until_eof(|p| { - let ip = p.accept_ipv6(true)?; + let ip = p.accept_ipv6()?; p.accept_char(b'/')?; let prefix_len = p.accept_number(3, 129, false)? as u8; Ok(Ipv6Cidr::new(ip, prefix_len)) @@ -424,15 +449,17 @@ impl FromStr for IpCidr { /// Parse a string representation of an IP CIDR. fn from_str(s: &str) -> Result { #[cfg(feature = "proto-ipv4")] + #[allow(clippy::single_match)] match Ipv4Cidr::from_str(s) { Ok(cidr) => return Ok(IpCidr::Ipv4(cidr)), - Err(_) => () + Err(_) => (), } #[cfg(feature = "proto-ipv6")] + #[allow(clippy::single_match)] match Ipv6Cidr::from_str(s) { Ok(cidr) => return Ok(IpCidr::Ipv6(cidr)), - Err(_) => () + Err(_) => (), } Err(()) @@ -443,7 +470,7 @@ impl FromStr for IpEndpoint { type Err = (); fn from_str(s: &str) -> Result { - Parser::new(s).until_eof(|p| Ok(p.accept_ip_endpoint()?)) + Parser::new(s).until_eof(|p| p.accept_ip_endpoint()) } } @@ -459,29 +486,40 @@ mod test { if let Ok(cidr) = cidr { assert_eq!($from_str(&format!("{}", cidr)), Ok(cidr)); - assert_eq!(IpCidr::from_str(&format!("{}", cidr)), - Ok($variant(cidr))); + assert_eq!(IpCidr::from_str(&format!("{}", cidr)), Ok($variant(cidr))); } } - } + }; } #[test] - #[cfg(all(feature = "proto-ipv4", feature = "ethernet"))] + #[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] fn test_mac() { assert_eq!(EthernetAddress::from_str(""), Err(())); - assert_eq!(EthernetAddress::from_str("02:00:00:00:00:00"), - Ok(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x00]))); - assert_eq!(EthernetAddress::from_str("01:23:45:67:89:ab"), - Ok(EthernetAddress([0x01, 0x23, 0x45, 0x67, 0x89, 0xab]))); - assert_eq!(EthernetAddress::from_str("cd:ef:10:00:00:00"), - Ok(EthernetAddress([0xcd, 0xef, 0x10, 0x00, 0x00, 0x00]))); - assert_eq!(EthernetAddress::from_str("00:00:00:ab:cd:ef"), - Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef]))); - assert_eq!(EthernetAddress::from_str("00-00-00-ab-cd-ef"), - Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef]))); - assert_eq!(EthernetAddress::from_str("AB-CD-EF-00-00-00"), - Ok(EthernetAddress([0xab, 0xcd, 0xef, 0x00, 0x00, 0x00]))); + assert_eq!( + EthernetAddress::from_str("02:00:00:00:00:00"), + Ok(EthernetAddress([0x02, 0x00, 0x00, 0x00, 0x00, 0x00])) + ); + assert_eq!( + EthernetAddress::from_str("01:23:45:67:89:ab"), + Ok(EthernetAddress([0x01, 0x23, 0x45, 0x67, 0x89, 0xab])) + ); + assert_eq!( + EthernetAddress::from_str("cd:ef:10:00:00:00"), + Ok(EthernetAddress([0xcd, 0xef, 0x10, 0x00, 0x00, 0x00])) + ); + assert_eq!( + EthernetAddress::from_str("00:00:00:ab:cd:ef"), + Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef])) + ); + assert_eq!( + EthernetAddress::from_str("00-00-00-ab-cd-ef"), + Ok(EthernetAddress([0x00, 0x00, 0x00, 0xab, 0xcd, 0xef])) + ); + assert_eq!( + EthernetAddress::from_str("AB-CD-EF-00-00-00"), + Ok(EthernetAddress([0xab, 0xcd, 0xef, 0x00, 0x00, 0x00])) + ); assert_eq!(EthernetAddress::from_str("100:00:00:00:00:00"), Err(())); assert_eq!(EthernetAddress::from_str("002:00:00:00:00:00"), Err(())); assert_eq!(EthernetAddress::from_str("02:00:00:00:00:000"), Err(())); @@ -492,10 +530,14 @@ mod test { #[cfg(feature = "proto-ipv4")] fn test_ipv4() { assert_eq!(Ipv4Address::from_str(""), Err(())); - assert_eq!(Ipv4Address::from_str("1.2.3.4"), - Ok(Ipv4Address([1, 2, 3, 4]))); - assert_eq!(Ipv4Address::from_str("001.2.3.4"), - Ok(Ipv4Address([1, 2, 3, 4]))); + assert_eq!( + Ipv4Address::from_str("1.2.3.4"), + Ok(Ipv4Address([1, 2, 3, 4])) + ); + assert_eq!( + Ipv4Address::from_str("001.2.3.4"), + Ok(Ipv4Address([1, 2, 3, 4])) + ); assert_eq!(Ipv4Address::from_str("0001.2.3.4"), Err(())); assert_eq!(Ipv4Address::from_str("999.2.3.4"), Err(())); assert_eq!(Ipv4Address::from_str("1.2.3.4.5"), Err(())); @@ -509,73 +551,87 @@ mod test { fn test_ipv6() { // Obviously not valid assert_eq!(Ipv6Address::from_str(""), Err(())); - assert_eq!(Ipv6Address::from_str("fe80:0:0:0:0:0:0:1"), - Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))); - assert_eq!(Ipv6Address::from_str("::1"), - Ok(Ipv6Address::LOOPBACK)); - assert_eq!(Ipv6Address::from_str("::"), - Ok(Ipv6Address::UNSPECIFIED)); - assert_eq!(Ipv6Address::from_str("fe80::1"), - Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1))); - assert_eq!(Ipv6Address::from_str("1234:5678::"), - Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0, 0))); - assert_eq!(Ipv6Address::from_str("1234:5678::8765:4321"), - Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0x8765, 0x4321))); + assert_eq!( + Ipv6Address::from_str("fe80:0:0:0:0:0:0:1"), + Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)) + ); + assert_eq!(Ipv6Address::from_str("::1"), Ok(Ipv6Address::LOOPBACK)); + assert_eq!(Ipv6Address::from_str("::"), Ok(Ipv6Address::UNSPECIFIED)); + assert_eq!( + Ipv6Address::from_str("fe80::1"), + Ok(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)) + ); + assert_eq!( + Ipv6Address::from_str("1234:5678::"), + Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0, 0)) + ); + assert_eq!( + Ipv6Address::from_str("1234:5678::8765:4321"), + Ok(Ipv6Address::new(0x1234, 0x5678, 0, 0, 0, 0, 0x8765, 0x4321)) + ); // Two double colons in address - assert_eq!(Ipv6Address::from_str("1234:5678::1::1"), - Err(())); - assert_eq!(Ipv6Address::from_str("4444:333:22:1::4"), - Ok(Ipv6Address::new(0x4444, 0x0333, 0x0022, 0x0001, 0, 0, 0, 4))); - assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1::"), - Ok(Ipv6Address::new(1, 1, 1, 1, 1, 1, 0, 0))); - assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1"), - Ok(Ipv6Address::new(0, 0, 1, 1, 1, 1, 1, 1))); - assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), - Err(())); + assert_eq!(Ipv6Address::from_str("1234:5678::1::1"), Err(())); + assert_eq!( + Ipv6Address::from_str("4444:333:22:1::4"), + Ok(Ipv6Address::new(0x4444, 0x0333, 0x0022, 0x0001, 0, 0, 0, 4)) + ); + assert_eq!( + Ipv6Address::from_str("1:1:1:1:1:1::"), + Ok(Ipv6Address::new(1, 1, 1, 1, 1, 1, 0, 0)) + ); + assert_eq!( + Ipv6Address::from_str("::1:1:1:1:1:1"), + Ok(Ipv6Address::new(0, 0, 1, 1, 1, 1, 1, 1)) + ); + assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), Err(())); // Double colon appears too late indicating an address that is too long - assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1::"), - Err(())); + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1::"), Err(())); // Section after double colon is too long for a valid address - assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), - Err(())); + assert_eq!(Ipv6Address::from_str("::1:1:1:1:1:1:1"), Err(())); // Obviously too long - assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1:1:1"), - Err(())); + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1:1:1"), Err(())); // Address is too short - assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1"), - Err(())); + assert_eq!(Ipv6Address::from_str("1:1:1:1:1:1:1"), Err(())); // Long number - assert_eq!(Ipv6Address::from_str("::000001"), - Err(())); + assert_eq!(Ipv6Address::from_str("::000001"), Err(())); // IPv4-Mapped address - assert_eq!(Ipv6Address::from_str("::ffff:192.168.1.1"), - Ok(Ipv6Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]))); - assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:192.168.1.1"), - Ok(Ipv6Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]))); - assert_eq!(Ipv6Address::from_str("0::ffff:192.168.1.1"), - Ok(Ipv6Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]))); + assert_eq!( + Ipv6Address::from_str("::ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); + assert_eq!( + Ipv6Address::from_str("0:0:0:0:0:ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); + assert_eq!( + Ipv6Address::from_str("0::ffff:192.168.1.1"), + Ok(Ipv6Address([ + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1 + ])) + ); // Only ffff is allowed in position 6 when IPv4 mapped - assert_eq!(Ipv6Address::from_str("0:0:0:0:0:eeee:192.168.1.1"), - Err(())); + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:eeee:192.168.1.1"), Err(())); // Positions 1-5 must be 0 when IPv4 mapped - assert_eq!(Ipv6Address::from_str("0:0:0:0:1:ffff:192.168.1.1"), - Err(())); - assert_eq!(Ipv6Address::from_str("1::ffff:192.168.1.1"), - Err(())); + assert_eq!(Ipv6Address::from_str("0:0:0:0:1:ffff:192.168.1.1"), Err(())); + assert_eq!(Ipv6Address::from_str("1::ffff:192.168.1.1"), Err(())); // Out of range ipv4 octet - assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:256.168.1.1"), - Err(())); + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:256.168.1.1"), Err(())); // Invalid hex in ipv4 octet - assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:c0.168.1.1"), - Err(())); + assert_eq!(Ipv6Address::from_str("0:0:0:0:0:ffff:c0.168.1.1"), Err(())); } #[test] #[cfg(feature = "proto-ipv4")] fn test_ip_ipv4() { assert_eq!(IpAddress::from_str(""), Err(())); - assert_eq!(IpAddress::from_str("1.2.3.4"), - Ok(IpAddress::Ipv4(Ipv4Address([1, 2, 3, 4])))); + assert_eq!( + IpAddress::from_str("1.2.3.4"), + Ok(IpAddress::Ipv4(Ipv4Address([1, 2, 3, 4]))) + ); assert_eq!(IpAddress::from_str("x"), Err(())); } @@ -583,8 +639,12 @@ mod test { #[cfg(feature = "proto-ipv6")] fn test_ip_ipv6() { assert_eq!(IpAddress::from_str(""), Err(())); - assert_eq!(IpAddress::from_str("fe80::1"), - Ok(IpAddress::Ipv6(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)))); + assert_eq!( + IpAddress::from_str("fe80::1"), + Ok(IpAddress::Ipv6(Ipv6Address::new( + 0xfe80, 0, 0, 0, 0, 0, 0, 1 + ))) + ); assert_eq!(IpAddress::from_str("x"), Err(())); } @@ -592,14 +652,22 @@ mod test { #[cfg(feature = "proto-ipv4")] fn test_cidr_ipv4() { let tests = [ - ("127.0.0.1/8", - Ok(Ipv4Cidr::new(Ipv4Address([127, 0, 0, 1]), 8u8))), - ("192.168.1.1/24", - Ok(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 1]), 24u8))), - ("8.8.8.8/32", - Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 32u8))), - ("8.8.8.8/0", - Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 0u8))), + ( + "127.0.0.1/8", + Ok(Ipv4Cidr::new(Ipv4Address([127, 0, 0, 1]), 8u8)), + ), + ( + "192.168.1.1/24", + Ok(Ipv4Cidr::new(Ipv4Address([192, 168, 1, 1]), 24u8)), + ), + ( + "8.8.8.8/32", + Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 32u8)), + ), + ( + "8.8.8.8/0", + Ok(Ipv4Cidr::new(Ipv4Address([8, 8, 8, 8]), 0u8)), + ), ("", Err(())), ("1", Err(())), ("127.0.0.1", Err(())), @@ -616,22 +684,32 @@ mod test { #[cfg(feature = "proto-ipv6")] fn test_cidr_ipv6() { let tests = [ - ("fe80::1/64", - Ok(Ipv6Cidr::new(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64u8))), - ("fe80::/64", - Ok(Ipv6Cidr::new(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0), 64u8))), - ("::1/128", - Ok(Ipv6Cidr::new(Ipv6Address::LOOPBACK, 128u8))), - ("::/128", - Ok(Ipv6Cidr::new(Ipv6Address::UNSPECIFIED, 128u8))), - ("fe80:0:0:0:0:0:0:1/64", - Ok(Ipv6Cidr::new(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), 64u8))), - ("fe80:0:0:0:0:0:0:1|64", - Err(())), - ("fe80::|64", - Err(())), - ("fe80::1::/64", - Err(())) + ( + "fe80::1/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + 64u8, + )), + ), + ( + "fe80::/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 0), + 64u8, + )), + ), + ("::1/128", Ok(Ipv6Cidr::new(Ipv6Address::LOOPBACK, 128u8))), + ("::/128", Ok(Ipv6Cidr::new(Ipv6Address::UNSPECIFIED, 128u8))), + ( + "fe80:0:0:0:0:0:0:1/64", + Ok(Ipv6Cidr::new( + Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), + 64u8, + )), + ), + ("fe80:0:0:0:0:0:0:1|64", Err(())), + ("fe80::|64", Err(())), + ("fe80::1::/64", Err(())), ]; check_cidr_test_array!(tests, Ipv6Cidr::from_str, IpCidr::Ipv6); } @@ -643,11 +721,17 @@ mod test { assert_eq!(IpEndpoint::from_str("x"), Err(())); assert_eq!( IpEndpoint::from_str("127.0.0.1"), - Ok(IpEndpoint { addr: IpAddress::v4(127, 0, 0, 1), port: 0 }) + Ok(IpEndpoint { + addr: IpAddress::v4(127, 0, 0, 1), + port: 0 + }) ); assert_eq!( IpEndpoint::from_str("127.0.0.1:12345"), - Ok(IpEndpoint { addr: IpAddress::v4(127, 0, 0, 1), port: 12345 }) + Ok(IpEndpoint { + addr: IpAddress::v4(127, 0, 0, 1), + port: 12345 + }) ); } @@ -658,11 +742,24 @@ mod test { assert_eq!(IpEndpoint::from_str("x"), Err(())); assert_eq!( IpEndpoint::from_str("fe80::1"), - Ok(IpEndpoint { addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), port: 0 }) + Ok(IpEndpoint { + addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), + port: 0 + }) ); assert_eq!( IpEndpoint::from_str("[fe80::1]:12345"), - Ok(IpEndpoint { addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), port: 12345 }) + Ok(IpEndpoint { + addr: IpAddress::v6(0xfe80, 0, 0, 0, 0, 0, 0, 1), + port: 12345 + }) + ); + assert_eq!( + IpEndpoint::from_str("[::]:12345"), + Ok(IpEndpoint { + addr: IpAddress::v6(0, 0, 0, 0, 0, 0, 0, 0), + port: 12345 + }) ); } } diff --git a/src/phy/fault_injector.rs b/src/phy/fault_injector.rs index c5e2c4b96..fffe11a26 100644 --- a/src/phy/fault_injector.rs +++ b/src/phy/fault_injector.rs @@ -1,8 +1,7 @@ -use core::cell::RefCell; +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::{Duration, Instant}; -use {Error, Result}; -use phy::{self, DeviceCapabilities, Device}; -use time::{Duration, Instant}; +use super::PacketMeta; // We use our own RNG to stay compatible with #![no_std]. // The use of the RNG below has a slight bias, but it doesn't matter. @@ -19,22 +18,23 @@ fn xorshift32(state: &mut u32) -> u32 { const MTU: usize = 1536; #[derive(Debug, Default, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] struct Config { corrupt_pct: u8, - drop_pct: u8, - reorder_pct: u8, - max_size: usize, + drop_pct: u8, + max_size: usize, max_tx_rate: u64, max_rx_rate: u64, - interval: Duration, + interval: Duration, } #[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] struct State { - rng_seed: u32, + rng_seed: u32, refilled_at: Instant, - tx_bucket: u64, - rx_bucket: u64, + tx_bucket: u64, + rx_bucket: u64, } impl State { @@ -46,7 +46,7 @@ impl State { let buffer = buffer.as_mut(); // We introduce a single bitflip, as the most likely, and the hardest to detect, error. let index = (xorshift32(&mut self.rng_seed) as usize) % buffer.len(); - let bit = 1 << (xorshift32(&mut self.rng_seed) % 8) as u8; + let bit = 1 << (xorshift32(&mut self.rng_seed) % 8) as u8; buffer[index] ^= bit; } @@ -59,7 +59,9 @@ impl State { } fn maybe_transmit(&mut self, config: &Config, timestamp: Instant) -> bool { - if config.max_tx_rate == 0 { return true } + if config.max_tx_rate == 0 { + return true; + } self.refill(config, timestamp); if self.tx_bucket > 0 { @@ -71,7 +73,9 @@ impl State { } fn maybe_receive(&mut self, config: &Config, timestamp: Instant) -> bool { - if config.max_rx_rate == 0 { return true } + if config.max_rx_rate == 0 { + return true; + } self.refill(config, timestamp); if self.rx_bucket > 0 { @@ -89,25 +93,26 @@ impl State { /// adverse network conditions (such as random packet loss or corruption), or software /// or hardware limitations (such as a limited number or size of usable network buffers). #[derive(Debug)] -pub struct FaultInjector Device<'a>> { - inner: D, - state: RefCell, - config: Config, +pub struct FaultInjector { + inner: D, + state: State, + config: Config, + rx_buf: [u8; MTU], } -impl Device<'a>> FaultInjector { +impl FaultInjector { /// Create a fault injector device, using the given random number generator seed. pub fn new(inner: D, seed: u32) -> FaultInjector { - let state = State { - rng_seed: seed, - refilled_at: Instant::from_millis(0), - tx_bucket: 0, - rx_bucket: 0, - }; FaultInjector { - inner: inner, - state: RefCell::new(state), + inner, + state: State { + rng_seed: seed, + refilled_at: Instant::from_millis(0), + tx_bucket: 0, + rx_bucket: 0, + }, config: Config::default(), + rx_buf: [0u8; MTU], } } @@ -133,12 +138,12 @@ impl Device<'a>> FaultInjector { /// Return the maximum packet transmission rate, in packets per second. pub fn max_tx_rate(&self) -> u64 { - self.config.max_rx_rate + self.config.max_tx_rate } /// Return the maximum packet reception rate, in packets per second. pub fn max_rx_rate(&self) -> u64 { - self.config.max_tx_rate + self.config.max_rx_rate } /// Return the interval for packet rate limiting, in milliseconds. @@ -151,7 +156,9 @@ impl Device<'a>> FaultInjector { /// # Panics /// This function panics if the probability is not between 0% and 100%. pub fn set_corrupt_chance(&mut self, pct: u8) { - if pct > 100 { panic!("percentage out of range") } + if pct > 100 { + panic!("percentage out of range") + } self.config.corrupt_pct = pct } @@ -160,7 +167,9 @@ impl Device<'a>> FaultInjector { /// # Panics /// This function panics if the probability is not between 0% and 100%. pub fn set_drop_chance(&mut self, pct: u8) { - if pct > 100 { panic!("percentage out of range") } + if pct > 100 { + panic!("percentage out of range") + } self.config.drop_pct = pct } @@ -181,16 +190,18 @@ impl Device<'a>> FaultInjector { /// Set the interval for packet rate limiting, in milliseconds. pub fn set_bucket_interval(&mut self, interval: Duration) { - self.state.borrow_mut().refilled_at = Instant::from_millis(0); + self.state.refilled_at = Instant::from_millis(0); self.config.interval = interval } } -impl<'a, D> Device<'a> for FaultInjector - where D: for<'b> Device<'b>, -{ - type RxToken = RxToken<'a, >::RxToken>; - type TxToken = TxToken<'a, >::TxToken>; +impl Device for FaultInjector { + type RxToken<'a> = RxToken<'a> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>> + where + Self: 'a; fn capabilities(&self) -> DeviceCapabilities { let mut caps = self.inner.capabilities(); @@ -200,94 +211,100 @@ impl<'a, D> Device<'a> for FaultInjector caps } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - let &mut Self { ref mut inner, ref state, config } = self; - inner.receive().map(|(rx_token, tx_token)| { - let rx = RxToken { - state: &state, - config: config, - token: rx_token, - corrupt: [0; MTU], - }; - let tx = TxToken { - state: &state, - config: config, - token: tx_token, - junk: [0; MTU], - }; - (rx, tx) - }) + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let (rx_token, tx_token) = self.inner.receive(timestamp)?; + let rx_meta = as phy::RxToken>::meta(&rx_token); + + let len = super::RxToken::consume(rx_token, |buffer| { + if (self.config.max_size > 0 && buffer.len() > self.config.max_size) + || buffer.len() > self.rx_buf.len() + { + net_trace!("rx: dropping a packet that is too large"); + return None; + } + self.rx_buf[..buffer.len()].copy_from_slice(buffer); + Some(buffer.len()) + })?; + + let buf = &mut self.rx_buf[..len]; + + if self.state.maybe(self.config.drop_pct) { + net_trace!("rx: randomly dropping a packet"); + return None; + } + + if !self.state.maybe_receive(&self.config, timestamp) { + net_trace!("rx: dropping a packet because of rate limiting"); + return None; + } + + if self.state.maybe(self.config.corrupt_pct) { + net_trace!("rx: randomly corrupting a packet"); + self.state.corrupt(&mut buf[..]); + } + + let rx = RxToken { buf, meta: rx_meta }; + let tx = TxToken { + state: &mut self.state, + config: self.config, + token: tx_token, + junk: [0; MTU], + timestamp, + }; + Some((rx, tx)) } - fn transmit(&'a mut self) -> Option { - let &mut Self { ref mut inner, ref state, config } = self; - inner.transmit().map(|token| TxToken { - state: &state, - config: config, - token: token, - junk: [0; MTU], + fn transmit(&mut self, timestamp: Instant) -> Option> { + self.inner.transmit(timestamp).map(|token| TxToken { + state: &mut self.state, + config: self.config, + token, + junk: [0; MTU], + timestamp, }) } } #[doc(hidden)] -pub struct RxToken<'a, Rx: phy::RxToken> { - state: &'a RefCell, - config: Config, - token: Rx, - corrupt: [u8; MTU], +pub struct RxToken<'a> { + buf: &'a mut [u8], + meta: PacketMeta, } -impl<'a, Rx: phy::RxToken> phy::RxToken for RxToken<'a, Rx> { - fn consume(self, timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result +impl<'a> phy::RxToken for RxToken<'a> { + fn consume(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - if self.state.borrow_mut().maybe(self.config.drop_pct) { - net_trace!("rx: randomly dropping a packet"); - return Err(Error::Exhausted) - } - if !self.state.borrow_mut().maybe_receive(&self.config, timestamp) { - net_trace!("rx: dropping a packet because of rate limiting"); - return Err(Error::Exhausted) - } - let Self { token, config, state, mut corrupt } = self; - token.consume(timestamp, |buffer| { - if config.max_size > 0 && buffer.as_ref().len() > config.max_size { - net_trace!("rx: dropping a packet that is too large"); - return Err(Error::Exhausted) - } - if state.borrow_mut().maybe(config.corrupt_pct) { - net_trace!("rx: randomly corrupting a packet"); - let mut corrupt = &mut corrupt[..buffer.len()]; - corrupt.copy_from_slice(buffer); - state.borrow_mut().corrupt(&mut corrupt); - f(&mut corrupt) - } else { - f(buffer) - } - }) + f(self.buf) + } + + fn meta(&self) -> phy::PacketMeta { + self.meta } } #[doc(hidden)] pub struct TxToken<'a, Tx: phy::TxToken> { - state: &'a RefCell, + state: &'a mut State, config: Config, - token: Tx, - junk: [u8; MTU], + token: Tx, + junk: [u8; MTU], + timestamp: Instant, } impl<'a, Tx: phy::TxToken> phy::TxToken for TxToken<'a, Tx> { - fn consume(mut self, timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(mut self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let drop = if self.state.borrow_mut().maybe(self.config.drop_pct) { + let drop = if self.state.maybe(self.config.drop_pct) { net_trace!("tx: randomly dropping a packet"); true } else if self.config.max_size > 0 && len > self.config.max_size { net_trace!("tx: dropping a packet that is too large"); true - } else if !self.state.borrow_mut().maybe_transmit(&self.config, timestamp) { + } else if !self.state.maybe_transmit(&self.config, self.timestamp) { net_trace!("tx: dropping a packet because of rate limiting"); true } else { @@ -295,16 +312,19 @@ impl<'a, Tx: phy::TxToken> phy::TxToken for TxToken<'a, Tx> { }; if drop { - return f(&mut self.junk); + return f(&mut self.junk[..len]); } - let Self { token, state, config, .. } = self; - token.consume(timestamp, len, |mut buf| { - if state.borrow_mut().maybe(config.corrupt_pct) { + self.token.consume(len, |mut buf| { + if self.state.maybe(self.config.corrupt_pct) { net_trace!("tx: corrupting a packet"); - state.borrow_mut().corrupt(&mut buf) + self.state.corrupt(&mut buf) } f(buf) }) } + + fn set_meta(&mut self, meta: PacketMeta) { + self.token.set_meta(meta); + } } diff --git a/src/phy/fuzz_injector.rs b/src/phy/fuzz_injector.rs index 83a6e4251..6769d8ec0 100644 --- a/src/phy/fuzz_injector.rs +++ b/src/phy/fuzz_injector.rs @@ -1,6 +1,5 @@ -use Result; -use phy::{self, DeviceCapabilities, Device}; -use time::Instant; +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::Instant; // This could be fixed once associated consts are stable. const MTU: usize = 1536; @@ -18,17 +17,22 @@ pub trait Fuzzer { /// smoltcp, and is not for production use. #[allow(unused)] #[derive(Debug)] -pub struct FuzzInjector Device<'a>, FTx: Fuzzer, FRx: Fuzzer> { - inner: D, +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct FuzzInjector { + inner: D, fuzz_tx: FTx, fuzz_rx: FRx, } #[allow(unused)] -impl Device<'a>, FTx: Fuzzer, FRx: Fuzzer> FuzzInjector { +impl FuzzInjector { /// Create a fuzz injector device. pub fn new(inner: D, fuzz_tx: FTx, fuzz_rx: FRx) -> FuzzInjector { - FuzzInjector { inner, fuzz_tx, fuzz_rx } + FuzzInjector { + inner, + fuzz_tx, + fuzz_rx, + } } /// Return the underlying device, consuming the fuzz injector. @@ -37,13 +41,17 @@ impl Device<'a>, FTx: Fuzzer, FRx: Fuzzer> FuzzInjector } } -impl<'a, D, FTx, FRx> Device<'a> for FuzzInjector - where D: for<'b> Device<'b>, - FTx: Fuzzer + 'a, - FRx: Fuzzer + 'a +impl Device for FuzzInjector +where + FTx: Fuzzer, + FRx: Fuzzer, { - type RxToken = RxToken<'a, >::RxToken, FRx>; - type TxToken = TxToken<'a, >::TxToken, FTx>; + type RxToken<'a> = RxToken<'a, D::RxToken<'a>, FRx> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>, FTx> + where + Self: 'a; fn capabilities(&self) -> DeviceCapabilities { let mut caps = self.inner.capabilities(); @@ -53,62 +61,69 @@ impl<'a, D, FTx, FRx> Device<'a> for FuzzInjector caps } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - let &mut Self { ref mut inner, ref fuzz_rx, ref fuzz_tx } = self; - inner.receive().map(|(rx_token, tx_token)| { + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + self.inner.receive(timestamp).map(|(rx_token, tx_token)| { let rx = RxToken { - fuzzer: fuzz_rx, - token: rx_token, + fuzzer: &mut self.fuzz_rx, + token: rx_token, }; let tx = TxToken { - fuzzer: fuzz_tx, - token: tx_token, + fuzzer: &mut self.fuzz_tx, + token: tx_token, }; (rx, tx) }) } - fn transmit(&'a mut self) -> Option { - let &mut Self { ref mut inner, fuzz_rx: _, ref fuzz_tx } = self; - inner.transmit().map(|token| TxToken { - fuzzer: fuzz_tx, - token: token, + fn transmit(&mut self, timestamp: Instant) -> Option> { + self.inner.transmit(timestamp).map(|token| TxToken { + fuzzer: &mut self.fuzz_tx, + token: token, }) } } #[doc(hidden)] -pub struct RxToken<'a, Rx: phy::RxToken, F: Fuzzer + 'a>{ +pub struct RxToken<'a, Rx: phy::RxToken, F: Fuzzer + 'a> { fuzzer: &'a F, - token: Rx, + token: Rx, } impl<'a, Rx: phy::RxToken, FRx: Fuzzer> phy::RxToken for RxToken<'a, Rx, FRx> { - fn consume(self, timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let Self { fuzzer, token } = self; - token.consume(timestamp, |buffer| { - fuzzer.fuzz_packet(buffer); + self.token.consume(|buffer| { + self.fuzzer.fuzz_packet(buffer); f(buffer) }) } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } } #[doc(hidden)] pub struct TxToken<'a, Tx: phy::TxToken, F: Fuzzer + 'a> { fuzzer: &'a F, - token: Tx, + token: Tx, } impl<'a, Tx: phy::TxToken, FTx: Fuzzer> phy::TxToken for TxToken<'a, Tx, FTx> { - fn consume(self, timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let Self { fuzzer, token } = self; - token.consume(timestamp, len, |mut buf| { - fuzzer.fuzz_packet(&mut buf); - f(buf) + self.token.consume(len, |buf| { + let result = f(buf); + self.fuzzer.fuzz_packet(buf); + result }) } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } } diff --git a/src/phy/loopback.rs b/src/phy/loopback.rs index 7db99aaa3..1f57c0ca4 100644 --- a/src/phy/loopback.rs +++ b/src/phy/loopback.rs @@ -1,56 +1,53 @@ -#[cfg(feature = "std")] -use std::vec::Vec; -#[cfg(feature = "std")] -use std::collections::VecDeque; -#[cfg(feature = "alloc")] -use alloc::vec::Vec; -#[cfg(all(feature = "alloc", not(feature = "rust-1_28")))] use alloc::collections::VecDeque; -#[cfg(all(feature = "alloc", feature = "rust-1_28"))] -use alloc::VecDeque; +use alloc::vec::Vec; -use Result; -use phy::{self, Device, DeviceCapabilities}; -use time::Instant; +use crate::phy::{self, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; /// A loopback device. #[derive(Debug)] pub struct Loopback { - queue: VecDeque>, + pub(crate) queue: VecDeque>, + medium: Medium, } +#[allow(clippy::new_without_default)] impl Loopback { /// Creates a loopback device. /// /// Every packet transmitted through this device will be received through it /// in FIFO order. - pub fn new() -> Loopback { + pub fn new(medium: Medium) -> Loopback { Loopback { queue: VecDeque::new(), + medium, } } } -impl<'a> Device<'a> for Loopback { - type RxToken = RxToken; - type TxToken = TxToken<'a>; +impl Device for Loopback { + type RxToken<'a> = RxToken; + type TxToken<'a> = TxToken<'a>; fn capabilities(&self) -> DeviceCapabilities { DeviceCapabilities { max_transmission_unit: 65535, + medium: self.medium, ..DeviceCapabilities::default() } } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { self.queue.pop_front().map(move |buffer| { - let rx = RxToken { buffer: buffer }; - let tx = TxToken { queue: &mut self.queue }; + let rx = RxToken { buffer }; + let tx = TxToken { + queue: &mut self.queue, + }; (rx, tx) }) } - fn transmit(&'a mut self) -> Option { + fn transmit(&mut self, _timestamp: Instant) -> Option> { Some(TxToken { queue: &mut self.queue, }) @@ -63,21 +60,24 @@ pub struct RxToken { } impl phy::RxToken for RxToken { - fn consume(mut self, _timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { f(&mut self.buffer) } } #[doc(hidden)] +#[derive(Debug)] pub struct TxToken<'a> { queue: &'a mut VecDeque>, } impl<'a> phy::TxToken for TxToken<'a> { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { let mut buffer = Vec::new(); buffer.resize(len, 0); diff --git a/src/phy/mod.rs b/src/phy/mod.rs index 47146b432..677e78b47 100644 --- a/src/phy/mod.rs +++ b/src/phy/mod.rs @@ -8,17 +8,19 @@ and implementations of it: * _middleware_ [Tracer](struct.Tracer.html) and [FaultInjector](struct.FaultInjector.html), to facilitate debugging; * _adapters_ [RawSocket](struct.RawSocket.html) and - [TapInterface](struct.TapInterface.html), to transmit and receive frames + [TunTapInterface](struct.TunTapInterface.html), to transmit and receive frames on the host OS. - +*/ +#![cfg_attr( + feature = "medium-ethernet", + doc = r##" # Examples An implementation of the [Device](trait.Device.html) trait for a simple hardware Ethernet controller could look as follows: ```rust -use smoltcp::Result; -use smoltcp::phy::{self, DeviceCapabilities, Device}; +use smoltcp::phy::{self, DeviceCapabilities, Device, Medium}; use smoltcp::time::Instant; struct StmPhy { @@ -35,16 +37,16 @@ impl<'a> StmPhy { } } -impl<'a> phy::Device<'a> for StmPhy { - type RxToken = StmPhyRxToken<'a>; - type TxToken = StmPhyTxToken<'a>; +impl phy::Device for StmPhy { + type RxToken<'a> = StmPhyRxToken<'a> where Self: 'a; + type TxToken<'a> = StmPhyTxToken<'a> where Self: 'a; - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { Some((StmPhyRxToken(&mut self.rx_buffer[..]), StmPhyTxToken(&mut self.tx_buffer[..]))) } - fn transmit(&'a mut self) -> Option { + fn transmit(&mut self, _timestamp: Instant) -> Option> { Some(StmPhyTxToken(&mut self.tx_buffer[..])) } @@ -52,6 +54,7 @@ impl<'a> phy::Device<'a> for StmPhy { let mut caps = DeviceCapabilities::default(); caps.max_transmission_unit = 1536; caps.max_burst_size = Some(1); + caps.medium = Medium::Ethernet; caps } } @@ -59,8 +62,8 @@ impl<'a> phy::Device<'a> for StmPhy { struct StmPhyRxToken<'a>(&'a mut [u8]); impl<'a> phy::RxToken for StmPhyRxToken<'a> { - fn consume(mut self, _timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(mut self, f: F) -> R + where F: FnOnce(&mut [u8]) -> R { // TODO: receive packet into buffer let result = f(&mut self.0); @@ -72,8 +75,8 @@ impl<'a> phy::RxToken for StmPhyRxToken<'a> { struct StmPhyTxToken<'a>(&'a mut [u8]); impl<'a> phy::TxToken for StmPhyTxToken<'a> { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(self, len: usize, f: F) -> R + where F: FnOnce(&mut [u8]) -> R { let result = f(&mut self.0[..len]); println!("tx called {}", len); @@ -82,47 +85,88 @@ impl<'a> phy::TxToken for StmPhyTxToken<'a> { } } ``` -*/ +"## +)] -use Result; -use time::Instant; +use crate::time::Instant; -#[cfg(all(any(feature = "phy-raw_socket", feature = "phy-tap_interface"), unix))] +#[cfg(all( + any(feature = "phy-raw_socket", feature = "phy-tuntap_interface"), + unix +))] mod sys; -mod tracer; mod fault_injector; mod fuzz_injector; -mod pcap_writer; -#[cfg(any(feature = "std", feature = "alloc"))] +#[cfg(feature = "alloc")] mod loopback; +mod pcap_writer; #[cfg(all(feature = "phy-raw_socket", unix))] mod raw_socket; -#[cfg(all(feature = "phy-tap_interface", target_os = "linux"))] -mod tap_interface; - -#[cfg(all(any(feature = "phy-raw_socket", feature = "phy-tap_interface"), unix))] +mod tracer; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +mod tuntap_interface; + +#[cfg(all( + any(feature = "phy-raw_socket", feature = "phy-tuntap_interface"), + unix +))] pub use self::sys::wait; -pub use self::tracer::Tracer; pub use self::fault_injector::FaultInjector; -pub use self::fuzz_injector::{Fuzzer, FuzzInjector}; -pub use self::pcap_writer::{PcapLinkType, PcapMode, PcapSink, PcapWriter}; -#[cfg(any(feature = "std", feature = "alloc"))] +pub use self::fuzz_injector::{FuzzInjector, Fuzzer}; +#[cfg(feature = "alloc")] pub use self::loopback::Loopback; +pub use self::pcap_writer::{PcapLinkType, PcapMode, PcapSink, PcapWriter}; #[cfg(all(feature = "phy-raw_socket", unix))] pub use self::raw_socket::RawSocket; -#[cfg(all(feature = "phy-tap_interface", target_os = "linux"))] -pub use self::tap_interface::TapInterface; +pub use self::tracer::Tracer; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub use self::tuntap_interface::TunTapInterface; -#[cfg(feature = "ethernet")] -/// A tracer device for Ethernet frames. -pub type EthernetTracer = Tracer>; +/// Metadata associated to a packet. +/// +/// The packet metadata is a set of attributes associated to network packets +/// as they travel up or down the stack. The metadata is get/set by the +/// [`Device`] implementations or by the user when sending/receiving packets from a +/// socket. +/// +/// Metadata fields are enabled via Cargo features. If no field is enabled, this +/// struct becomes zero-sized, which allows the compiler to optimize it out as if +/// the packet metadata mechanism didn't exist at all. +/// +/// Currently only UDP sockets allow setting/retrieving packet metadata. The metadata +/// for packets emitted with other sockets will be all default values. +/// +/// This struct is marked as `#[non_exhaustive]`. This means it is not possible to +/// create it directly by specifying all fields. You have to instead create it with +/// default values and then set the fields you want. This makes adding metadata +/// fields a non-breaking change. +/// +/// ```rust,ignore +/// let mut meta = PacketMeta::new(); +/// meta.id = 15; +/// ``` +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Hash, Clone, Copy, Default)] +#[non_exhaustive] +pub struct PacketMeta { + #[cfg(feature = "packetmeta-id")] + pub id: u32, +} /// A description of checksum behavior for a particular protocol. -#[derive(Debug, Clone, Copy)] +#[derive(Debug, Clone, Copy, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Checksum { /// Verify checksum when receiving and compute checksum when sending. + #[default] Both, /// Verify checksum when receiving. Rx, @@ -132,18 +176,12 @@ pub enum Checksum { None, } -impl Default for Checksum { - fn default() -> Checksum { - Checksum::Both - } -} - impl Checksum { /// Returns whether checksum should be verified when receiving. pub fn rx(&self) -> bool { match *self { Checksum::Both | Checksum::Rx => true, - _ => false + _ => false, } } @@ -151,13 +189,15 @@ impl Checksum { pub fn tx(&self) -> bool { match *self { Checksum::Both | Checksum::Tx => true, - _ => false + _ => false, } } } /// A description of checksum behavior for every supported protocol. #[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub struct ChecksumCapabilities { pub ipv4: Checksum, pub udp: Checksum, @@ -166,7 +206,6 @@ pub struct ChecksumCapabilities { pub icmpv4: Checksum, #[cfg(feature = "proto-ipv6")] pub icmpv6: Checksum, - dummy: (), } impl ChecksumCapabilities { @@ -181,7 +220,6 @@ impl ChecksumCapabilities { icmpv4: Checksum::None, #[cfg(feature = "proto-ipv6")] icmpv6: Checksum::None, - ..Self::default() } } } @@ -191,13 +229,28 @@ impl ChecksumCapabilities { /// Higher-level protocols may achieve higher throughput or lower latency if they consider /// the bandwidth or packet size limitations. #[derive(Debug, Clone, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub struct DeviceCapabilities { + /// Medium of the device. + /// + /// This indicates what kind of packet the sent/received bytes are, and determines + /// some behaviors of Interface. For example, ARP/NDISC address resolution is only done + /// for Ethernet mediums. + pub medium: Medium, + /// Maximum transmission unit. /// /// The network device is unable to send or receive frames larger than the value returned /// by this function. /// - /// For Ethernet, MTU will fall between 576 (for IPv4) or 1280 (for IPv6) and 9216 octets. + /// For Ethernet devices, this is the maximum Ethernet frame size, including the Ethernet header (14 octets), but + /// *not* including the Ethernet FCS (4 octets). Therefore, Ethernet MTU = IP MTU + 14. + /// + /// Note that in Linux and other OSes, "MTU" is the IP MTU, not the Ethernet MTU, even for Ethernet + /// devices. This is a common source of confusion. + /// + /// Most common IP MTU is 1500. Minimum is 576 (for IPv4) or 1280 (for IPv6). Maximum is 9216 octets. pub max_transmission_unit: usize, /// Maximum burst size, in terms of MTU. @@ -214,10 +267,64 @@ pub struct DeviceCapabilities { /// If the network device is capable of verifying or computing checksums for some protocols, /// it can request that the stack not do so in software to improve performance. pub checksum: ChecksumCapabilities, +} - /// Only present to prevent people from trying to initialize every field of DeviceLimits, - /// which would not let us add new fields in the future. - dummy: () +impl DeviceCapabilities { + pub fn ip_mtu(&self) -> usize { + match self.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + self.max_transmission_unit - crate::wire::EthernetFrame::<&[u8]>::header_len() + } + #[cfg(feature = "medium-ip")] + Medium::Ip => self.max_transmission_unit, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => self.max_transmission_unit, // TODO(thvdveld): what is the MTU for Medium::IEEE802 + } + } +} + +/// Type of medium of a device. +#[derive(Debug, Eq, PartialEq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Medium { + /// Ethernet medium. Devices of this type send and receive Ethernet frames, + /// and interfaces using it must do neighbor discovery via ARP or NDISC. + /// + /// Examples of devices of this type are Ethernet, WiFi (802.11), Linux `tap`, and VPNs in tap (layer 2) mode. + #[cfg(feature = "medium-ethernet")] + Ethernet, + + /// IP medium. Devices of this type send and receive IP frames, without an + /// Ethernet header. MAC addresses are not used, and no neighbor discovery (ARP, NDISC) is done. + /// + /// Examples of devices of this type are the Linux `tun`, PPP interfaces, VPNs in tun (layer 3) mode. + #[cfg(feature = "medium-ip")] + Ip, + + #[cfg(feature = "medium-ieee802154")] + Ieee802154, +} + +impl Default for Medium { + fn default() -> Medium { + #[cfg(feature = "medium-ethernet")] + return Medium::Ethernet; + #[cfg(all(feature = "medium-ip", not(feature = "medium-ethernet")))] + return Medium::Ip; + #[cfg(all( + feature = "medium-ieee802154", + not(feature = "medium-ip"), + not(feature = "medium-ethernet") + ))] + return Medium::Ieee802154; + #[cfg(all( + not(feature = "medium-ip"), + not(feature = "medium-ethernet"), + not(feature = "medium-ieee802154") + ))] + return panic!("No medium enabled"); + } } /// An interface for sending and receiving raw network frames. @@ -225,9 +332,13 @@ pub struct DeviceCapabilities { /// The interface is based on _tokens_, which are types that allow to receive/transmit a /// single packet. The `receive` and `transmit` functions only construct such tokens, the /// real sending/receiving operation are performed when the tokens are consumed. -pub trait Device<'a> { - type RxToken: RxToken + 'a; - type TxToken: TxToken + 'a; +pub trait Device { + type RxToken<'a>: RxToken + where + Self: 'a; + type TxToken<'a>: TxToken + where + Self: 'a; /// Construct a token pair consisting of one receive token and one transmit token. /// @@ -235,10 +346,16 @@ pub trait Device<'a> { /// on the contents of the received packet. For example, this makes it possible to /// handle arbitrarily large ICMP echo ("ping") requests, where the all received bytes /// need to be sent back, without heap allocation. - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)>; + /// + /// The timestamp must be a number of milliseconds, monotonically increasing since an + /// arbitrary moment in time, such as system startup. + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)>; /// Construct a transmit token. - fn transmit(&'a mut self) -> Option; + /// + /// The timestamp must be a number of milliseconds, monotonically increasing since an + /// arbitrary moment in time, such as system startup. + fn transmit(&mut self, timestamp: Instant) -> Option>; /// Get a description of device capabilities. fn capabilities(&self) -> DeviceCapabilities; @@ -250,11 +367,14 @@ pub trait RxToken { /// /// This method receives a packet and then calls the given closure `f` with the raw /// packet bytes as argument. - /// - /// The timestamp must be a number of milliseconds, monotonically increasing since an - /// arbitrary moment in time, such as system startup. - fn consume(self, timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result; + fn consume(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R; + + /// The Packet ID associated with the frame received by this [`RxToken`] + fn meta(&self) -> PacketMeta { + PacketMeta::default() + } } /// A token to transmit a single network packet. @@ -265,9 +385,11 @@ pub trait TxToken { /// closure `f` with a mutable reference to that buffer. The closure should construct /// a valid network packet (e.g. an ethernet packet) in the buffer. When the closure /// returns, the transmit buffer is sent out. - /// - /// The timestamp must be a number of milliseconds, monotonically increasing since an - /// arbitrary moment in time, such as system startup. - fn consume(self, timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result; + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R; + + /// The Packet ID to be associated with the frame to be transmitted by this [`TxToken`]. + #[allow(unused_variables)] + fn set_meta(&mut self, meta: PacketMeta) {} } diff --git a/src/phy/pcap_writer.rs b/src/phy/pcap_writer.rs index 53d7559c5..fc6c3b236 100644 --- a/src/phy/pcap_writer.rs +++ b/src/phy/pcap_writer.rs @@ -1,48 +1,53 @@ -#[cfg(feature = "std")] -use std::cell::RefCell; +use byteorder::{ByteOrder, NativeEndian}; +use core::cell::RefCell; +use phy::Medium; #[cfg(feature = "std")] use std::io::Write; -use byteorder::{ByteOrder, NativeEndian}; -use Result; -use phy::{self, DeviceCapabilities, Device}; -use time::Instant; +use crate::phy::{self, Device, DeviceCapabilities}; +use crate::time::Instant; enum_with_unknown! { /// Captured packet header type. - pub doc enum PcapLinkType(u32) { + pub enum PcapLinkType(u32) { /// Ethernet frames Ethernet = 1, /// IPv4 or IPv6 packets (depending on the version field) - Ip = 101 + Ip = 101, + /// IEEE 802.15.4 packets with FCS included. + Ieee802154WithFcs = 195, } } /// Packet capture mode. #[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum PcapMode { /// Capture both received and transmitted packets. Both, /// Capture only received packets. RxOnly, /// Capture only transmitted packets. - TxOnly + TxOnly, } /// A packet capture sink. pub trait PcapSink { /// Write data into the sink. - fn write(&self, data: &[u8]); + fn write(&mut self, data: &[u8]); + + /// Flush data written into the sync. + fn flush(&mut self) {} /// Write an `u16` into the sink, in native byte order. - fn write_u16(&self, value: u16) { + fn write_u16(&mut self, value: u16) { let mut bytes = [0u8; 2]; NativeEndian::write_u16(&mut bytes, value); self.write(&bytes[..]) } /// Write an `u32` into the sink, in native byte order. - fn write_u32(&self, value: u32) { + fn write_u32(&mut self, value: u32) { let mut bytes = [0u8; 4]; NativeEndian::write_u32(&mut bytes, value); self.write(&bytes[..]) @@ -51,13 +56,13 @@ pub trait PcapSink { /// Write the libpcap global header into the sink. /// /// This method may be overridden e.g. if special synchronization is necessary. - fn global_header(&self, link_type: PcapLinkType) { - self.write_u32(0xa1b2c3d4); // magic number - self.write_u16(2); // major version - self.write_u16(4); // minor version - self.write_u32(0); // timezone (= UTC) - self.write_u32(0); // accuracy (not used) - self.write_u32(65535); // maximum packet length + fn global_header(&mut self, link_type: PcapLinkType) { + self.write_u32(0xa1b2c3d4); // magic number + self.write_u16(2); // major version + self.write_u16(4); // minor version + self.write_u32(0); // timezone (= UTC) + self.write_u32(0); // accuracy (not used) + self.write_u32(65535); // maximum packet length self.write_u32(link_type.into()); // link-layer header type } @@ -67,40 +72,33 @@ pub trait PcapSink { /// /// # Panics /// This function panics if `length` is greater than 65535. - fn packet_header(&self, timestamp: Instant, length: usize) { + fn packet_header(&mut self, timestamp: Instant, length: usize) { assert!(length <= 65535); - self.write_u32(timestamp.secs() as u32); // timestamp seconds - self.write_u32(timestamp.millis() as u32); // timestamp microseconds - self.write_u32(length as u32); // captured length - self.write_u32(length as u32); // original length + self.write_u32(timestamp.secs() as u32); // timestamp seconds + self.write_u32(timestamp.micros() as u32); // timestamp microseconds + self.write_u32(length as u32); // captured length + self.write_u32(length as u32); // original length } /// Write the libpcap packet header followed by packet data into the sink. /// /// See also the note for [global_header](#method.global_header). - fn packet(&self, timestamp: Instant, packet: &[u8]) { + fn packet(&mut self, timestamp: Instant, packet: &[u8]) { self.packet_header(timestamp, packet.len()); - self.write(packet) - } -} - -impl> PcapSink for T { - fn write(&self, data: &[u8]) { - self.as_ref().write(data) + self.write(packet); + self.flush(); } } #[cfg(feature = "std")] -impl PcapSink for RefCell { - fn write(&self, data: &[u8]) { - self.borrow_mut().write_all(data).expect("cannot write") +impl PcapSink for T { + fn write(&mut self, data: &[u8]) { + T::write_all(self, data).expect("cannot write") } - fn packet(&self, timestamp: Instant, packet: &[u8]) { - self.packet_header(timestamp, packet.len()); - PcapSink::write(self, packet); - self.borrow_mut().flush().expect("cannot flush") + fn flush(&mut self) { + T::flush(self).expect("cannot flush") } } @@ -118,89 +116,153 @@ impl PcapSink for RefCell { /// [sink]: trait.PcapSink.html #[derive(Debug)] pub struct PcapWriter - where D: for<'a> Device<'a>, - S: PcapSink + Clone, +where + D: Device, + S: PcapSink, { lower: D, - sink: S, - mode: PcapMode, + sink: RefCell, + mode: PcapMode, } -impl Device<'a>, S: PcapSink + Clone> PcapWriter { +impl PcapWriter { /// Creates a packet capture writer. - pub fn new(lower: D, sink: S, mode: PcapMode, link_type: PcapLinkType) -> PcapWriter { + pub fn new(lower: D, mut sink: S, mode: PcapMode) -> PcapWriter { + let medium = lower.capabilities().medium; + let link_type = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => PcapLinkType::Ip, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => PcapLinkType::Ethernet, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => PcapLinkType::Ieee802154WithFcs, + }; sink.global_header(link_type); - PcapWriter { lower, sink, mode } + PcapWriter { + lower, + sink: RefCell::new(sink), + mode, + } + } + + /// Get a reference to the underlying device. + /// + /// Even if the device offers reading through a standard reference, it is inadvisable to + /// directly read from the device as doing so will circumvent the packet capture. + pub fn get_ref(&self) -> &D { + &self.lower + } + + /// Get a mutable reference to the underlying device. + /// + /// It is inadvisable to directly read from the device as doing so will circumvent the packet capture. + pub fn get_mut(&mut self) -> &mut D { + &mut self.lower } } -impl<'a, D, S> Device<'a> for PcapWriter - where D: for<'b> Device<'b>, - S: PcapSink + Clone + 'a, +impl Device for PcapWriter +where + S: PcapSink, { - type RxToken = RxToken<>::RxToken, S>; - type TxToken = TxToken<>::TxToken, S>; + type RxToken<'a> = RxToken<'a, D::RxToken<'a>, S> + where + Self: 'a; + type TxToken<'a> = TxToken<'a, D::TxToken<'a>, S> + where + Self: 'a; - fn capabilities(&self) -> DeviceCapabilities { self.lower.capabilities() } + fn capabilities(&self) -> DeviceCapabilities { + self.lower.capabilities() + } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - let &mut Self { ref mut lower, ref sink, mode, .. } = self; - lower.receive().map(|(rx_token, tx_token)| { - let rx = RxToken { token: rx_token, sink: sink.clone(), mode: mode }; - let tx = TxToken { token: tx_token, sink: sink.clone(), mode: mode }; - (rx, tx) - }) + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let sink = &self.sink; + let mode = self.mode; + self.lower + .receive(timestamp) + .map(move |(rx_token, tx_token)| { + let rx = RxToken { + token: rx_token, + sink, + mode, + timestamp, + }; + let tx = TxToken { + token: tx_token, + sink, + mode, + timestamp, + }; + (rx, tx) + }) } - fn transmit(&'a mut self) -> Option { - let &mut Self { ref mut lower, ref sink, mode } = self; - lower.transmit().map(|token| { - TxToken { token, sink: sink.clone(), mode: mode } + fn transmit(&mut self, timestamp: Instant) -> Option> { + let sink = &self.sink; + let mode = self.mode; + self.lower.transmit(timestamp).map(move |token| TxToken { + token, + sink, + mode, + timestamp, }) } } #[doc(hidden)] -pub struct RxToken { +pub struct RxToken<'a, Rx: phy::RxToken, S: PcapSink> { token: Rx, - sink: S, - mode: PcapMode, + sink: &'a RefCell, + mode: PcapMode, + timestamp: Instant, } -impl phy::RxToken for RxToken { - fn consume Result>(self, timestamp: Instant, f: F) -> Result { - let Self { token, sink, mode } = self; - token.consume(timestamp, |buffer| { - match mode { - PcapMode::Both | PcapMode::RxOnly => - sink.packet(timestamp, buffer.as_ref()), - PcapMode::TxOnly => () +impl<'a, Rx: phy::RxToken, S: PcapSink> phy::RxToken for RxToken<'a, Rx, S> { + fn consume R>(self, f: F) -> R { + self.token.consume(|buffer| { + match self.mode { + PcapMode::Both | PcapMode::RxOnly => self + .sink + .borrow_mut() + .packet(self.timestamp, buffer.as_ref()), + PcapMode::TxOnly => (), } f(buffer) }) } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } } #[doc(hidden)] -pub struct TxToken { +pub struct TxToken<'a, Tx: phy::TxToken, S: PcapSink> { token: Tx, - sink: S, - mode: PcapMode + sink: &'a RefCell, + mode: PcapMode, + timestamp: Instant, } -impl phy::TxToken for TxToken { - fn consume(self, timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result +impl<'a, Tx: phy::TxToken, S: PcapSink> phy::TxToken for TxToken<'a, Tx, S> { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let Self { token, sink, mode } = self; - token.consume(timestamp, len, |buffer| { + self.token.consume(len, |buffer| { let result = f(buffer); - match mode { - PcapMode::Both | PcapMode::TxOnly => - sink.packet(timestamp, &buffer), - PcapMode::RxOnly => () + match self.mode { + PcapMode::Both | PcapMode::TxOnly => { + self.sink.borrow_mut().packet(self.timestamp, buffer) + } + PcapMode::RxOnly => (), }; result }) } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } } diff --git a/src/phy/raw_socket.rs b/src/phy/raw_socket.rs index 231d3d171..0a4cc2990 100644 --- a/src/phy/raw_socket.rs +++ b/src/phy/raw_socket.rs @@ -1,18 +1,18 @@ use std::cell::RefCell; -use std::vec::Vec; -use std::rc::Rc; use std::io; -use std::os::unix::io::{RawFd, AsRawFd}; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::rc::Rc; +use std::vec::Vec; -use Result; -use phy::{self, sys, DeviceCapabilities, Device}; -use time::Instant; +use crate::phy::{self, sys, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; /// A socket that captures or transmits the complete frame. #[derive(Debug)] pub struct RawSocket { - lower: Rc>, - mtu: usize + medium: Medium, + lower: Rc>, + mtu: usize, } impl AsRawFd for RawSocket { @@ -26,46 +26,67 @@ impl RawSocket { /// /// This requires superuser privileges or a corresponding capability bit /// set on the executable. - pub fn new(name: &str) -> io::Result { - let mut lower = sys::RawSocketDesc::new(name)?; + pub fn new(name: &str, medium: Medium) -> io::Result { + let mut lower = sys::RawSocketDesc::new(name, medium)?; lower.bind_interface()?; - let mtu = lower.interface_mtu()?; + + let mut mtu = lower.interface_mtu()?; + + // FIXME(thvdveld): this is a workaround for https://github.com/smoltcp-rs/smoltcp/issues/622 + #[cfg(feature = "medium-ieee802154")] + if medium == Medium::Ieee802154 { + mtu += 2; + } + + #[cfg(feature = "medium-ethernet")] + if medium == Medium::Ethernet { + // SIOCGIFMTU returns the IP MTU (typically 1500 bytes.) + // smoltcp counts the entire Ethernet packet in the MTU, so add the Ethernet header size to it. + mtu += crate::wire::EthernetFrame::<&[u8]>::header_len() + } + Ok(RawSocket { + medium, lower: Rc::new(RefCell::new(lower)), - mtu: mtu + mtu, }) } } -impl<'a> Device<'a> for RawSocket { - type RxToken = RxToken; - type TxToken = TxToken; +impl Device for RawSocket { + type RxToken<'a> = RxToken + where + Self: 'a; + type TxToken<'a> = TxToken + where + Self: 'a; fn capabilities(&self) -> DeviceCapabilities { DeviceCapabilities { max_transmission_unit: self.mtu, + medium: self.medium, ..DeviceCapabilities::default() } } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { let mut lower = self.lower.borrow_mut(); let mut buffer = vec![0; self.mtu]; match lower.recv(&mut buffer[..]) { Ok(size) => { buffer.resize(size, 0); let rx = RxToken { buffer }; - let tx = TxToken { lower: self.lower.clone() }; + let tx = TxToken { + lower: self.lower.clone(), + }; Some((rx, tx)) } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - None - } - Err(err) => panic!("{}", err) + Err(err) if err.kind() == io::ErrorKind::WouldBlock => None, + Err(err) => panic!("{}", err), } } - fn transmit(&'a mut self) -> Option { + fn transmit(&mut self, _timestamp: Instant) -> Option> { Some(TxToken { lower: self.lower.clone(), }) @@ -74,12 +95,13 @@ impl<'a> Device<'a> for RawSocket { #[doc(hidden)] pub struct RxToken { - buffer: Vec + buffer: Vec, } impl phy::RxToken for RxToken { - fn consume(mut self, _timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { f(&mut self.buffer[..]) } @@ -87,17 +109,24 @@ impl phy::RxToken for RxToken { #[doc(hidden)] pub struct TxToken { - lower: Rc>, + lower: Rc>, } impl phy::TxToken for TxToken { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { let mut lower = self.lower.borrow_mut(); let mut buffer = vec![0; len]; let result = f(&mut buffer); - lower.send(&buffer[..]).unwrap(); + match lower.send(&buffer[..]) { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + net_debug!("phy: tx failed due to WouldBlock") + } + Err(err) => panic!("{}", err), + } result } } diff --git a/src/phy/sys/bpf.rs b/src/phy/sys/bpf.rs index 566941474..1935ac3fc 100644 --- a/src/phy/sys/bpf.rs +++ b/src/phy/sys/bpf.rs @@ -1,22 +1,35 @@ use std::io; +use std::mem; use std::os::unix::io::{AsRawFd, RawFd}; use libc; use super::{ifreq, ifreq_for}; +use crate::phy::Medium; +use crate::wire::ETHERNET_HEADER_LEN; /// set interface -#[cfg(target_os = "macos")] +#[cfg(any(target_os = "macos", target_os = "openbsd"))] const BIOCSETIF: libc::c_ulong = 0x8020426c; /// get buffer length -#[cfg(target_os = "macos")] +#[cfg(any(target_os = "macos", target_os = "openbsd"))] const BIOCGBLEN: libc::c_ulong = 0x40044266; /// set immediate/nonblocking read -#[cfg(target_os = "macos")] +#[cfg(any(target_os = "macos", target_os = "openbsd"))] const BIOCIMMEDIATE: libc::c_ulong = 0x80044270; -// TODO: check if this is same for OSes other than macos +/// set bpf_hdr struct size #[cfg(target_os = "macos")] -const BPF_HDRLEN: usize = 18; +const SIZEOF_BPF_HDR: usize = 18; +/// set bpf_hdr struct size +#[cfg(target_os = "openbsd")] +const SIZEOF_BPF_HDR: usize = 24; +/// The actual header length may be larger than the bpf_hdr struct due to aligning +/// see https://github.com/openbsd/src/blob/37ecb4d066e5566411cc16b362d3960c93b1d0be/sys/net/bpf.c#L1649 +/// and https://github.com/apple/darwin-xnu/blob/8f02f2a044b9bb1ad951987ef5bab20ec9486310/bsd/net/bpf.c#L3580 +#[cfg(any(target_os = "macos", target_os = "openbsd"))] +const BPF_HDRLEN: usize = (((SIZEOF_BPF_HDR + ETHERNET_HEADER_LEN) + mem::align_of::() - 1) + & !(mem::align_of::() - 1)) + - ETHERNET_HEADER_LEN; macro_rules! try_ioctl { ($fd:expr,$cmd:expr,$req:expr) => { @@ -43,8 +56,8 @@ impl AsRawFd for BpfDevice { fn open_device() -> io::Result { unsafe { for i in 0..256 { - let dev = format!("/dev/bpf{}", i).as_ptr() as *const libc::c_char; - match libc::open(dev, libc::O_RDWR) { + let dev = format!("/dev/bpf{}\0", i); + match libc::open(dev.as_ptr() as *const libc::c_char, libc::O_RDWR) { -1 => continue, fd => return Ok(fd), }; @@ -55,7 +68,7 @@ fn open_device() -> io::Result { } impl BpfDevice { - pub fn new(name: &str) -> io::Result { + pub fn new(name: &str, _medium: Medium) -> io::Result { Ok(BpfDevice { fd: open_device()?, ifreq: ifreq_for(name), @@ -145,3 +158,20 @@ impl Drop for BpfDevice { } } } + +#[cfg(test)] +mod test { + use super::*; + + #[test] + #[cfg(target_os = "macos")] + fn test_aligned_bpf_hdr_len() { + assert_eq!(18, BPF_HDRLEN); + } + + #[test] + #[cfg(target_os = "openbsd")] + fn test_aligned_bpf_hdr_len() { + assert_eq!(26, BPF_HDRLEN); + } +} diff --git a/src/phy/sys/linux.rs b/src/phy/sys/linux.rs index fdf52baf3..f83ab5866 100644 --- a/src/phy/sys/linux.rs +++ b/src/phy/sys/linux.rs @@ -1,17 +1,11 @@ -use libc; +#![allow(unused)] -#[cfg(any(feature = "phy-raw_socket", - feature = "phy-tap_interface"))] -pub const SIOCGIFMTU: libc::c_ulong = 0x8921; -#[cfg(any(feature = "phy-raw_socket"))] +pub const SIOCGIFMTU: libc::c_ulong = 0x8921; pub const SIOCGIFINDEX: libc::c_ulong = 0x8933; -#[cfg(any(feature = "phy-raw_socket"))] -pub const ETH_P_ALL: libc::c_short = 0x0003; - -#[cfg(feature = "phy-tap_interface")] -pub const TUNSETIFF: libc::c_ulong = 0x400454CA; -#[cfg(feature = "phy-tap_interface")] -pub const IFF_TAP: libc::c_int = 0x0002; -#[cfg(feature = "phy-tap_interface")] -pub const IFF_NO_PI: libc::c_int = 0x1000; +pub const ETH_P_ALL: libc::c_short = 0x0003; +pub const ETH_P_IEEE802154: libc::c_short = 0x00F6; +pub const TUNSETIFF: libc::c_ulong = 0x400454CA; +pub const IFF_TUN: libc::c_int = 0x0001; +pub const IFF_TAP: libc::c_int = 0x0002; +pub const IFF_NO_PI: libc::c_int = 0x1000; diff --git a/src/phy/sys/mod.rs b/src/phy/sys/mod.rs index 508eb1abf..3f42301c5 100644 --- a/src/phy/sys/mod.rs +++ b/src/phy/sys/mod.rs @@ -1,27 +1,46 @@ #![allow(unsafe_code)] -use libc; -use std::{mem, ptr, io}; +use crate::time::Duration; use std::os::unix::io::RawFd; -use time::Duration; +use std::{io, mem, ptr}; -#[cfg(target_os = "linux")] +#[cfg(any(target_os = "linux", target_os = "android"))] #[path = "linux.rs"] mod imp; -#[cfg(all(feature = "phy-raw_socket", target_os = "linux"))] -pub mod raw_socket; -#[cfg(all(feature = "phy-raw_socket", not(target_os = "linux"), unix))] +#[cfg(all( + feature = "phy-raw_socket", + not(any(target_os = "linux", target_os = "android")), + unix +))] pub mod bpf; -#[cfg(all(feature = "phy-tap_interface", target_os = "linux"))] -pub mod tap_interface; +#[cfg(all( + feature = "phy-raw_socket", + any(target_os = "linux", target_os = "android") +))] +pub mod raw_socket; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub mod tuntap_interface; -#[cfg(all(feature = "phy-raw_socket", target_os = "linux"))] -pub use self::raw_socket::RawSocketDesc; -#[cfg(all(feature = "phy-raw_socket", not(target_os = "linux"), unix))] +#[cfg(all( + feature = "phy-raw_socket", + not(any(target_os = "linux", target_os = "android")), + unix +))] pub use self::bpf::BpfDevice as RawSocketDesc; -#[cfg(all(feature = "phy-tap_interface", target_os = "linux"))] -pub use self::tap_interface::TapInterfaceDesc; +#[cfg(all( + feature = "phy-raw_socket", + any(target_os = "linux", target_os = "android") +))] +pub use self::raw_socket::RawSocketDesc; +#[cfg(all( + feature = "phy-tuntap_interface", + any(target_os = "linux", target_os = "android") +))] +pub use self::tuntap_interface::TunTapInterfaceDesc; /// Wait until given file descriptor becomes readable, but no longer than given timeout. pub fn wait(fd: RawFd, duration: Option) -> io::Result<()> { @@ -45,34 +64,51 @@ pub fn wait(fd: RawFd, duration: Option) -> io::Result<()> { exceptfds.assume_init() }; - let mut timeout = libc::timeval { tv_sec: 0, tv_usec: 0 }; - let timeout_ptr = - if let Some(duration) = duration { - timeout.tv_usec = (duration.total_millis() * 1_000) as libc::suseconds_t; - &mut timeout as *mut _ - } else { - ptr::null_mut() - }; + let mut timeout = libc::timeval { + tv_sec: 0, + tv_usec: 0, + }; + let timeout_ptr = if let Some(duration) = duration { + timeout.tv_sec = duration.secs() as libc::time_t; + timeout.tv_usec = (duration.millis() * 1_000) as libc::suseconds_t; + &mut timeout as *mut _ + } else { + ptr::null_mut() + }; - let res = libc::select(fd + 1, &mut readfds, &mut writefds, &mut exceptfds, timeout_ptr); - if res == -1 { return Err(io::Error::last_os_error()) } + let res = libc::select( + fd + 1, + &mut readfds, + &mut writefds, + &mut exceptfds, + timeout_ptr, + ); + if res == -1 { + return Err(io::Error::last_os_error()); + } Ok(()) } } -#[cfg(all(any(feature = "phy-tap_interface", feature = "phy-raw_socket"), unix))] +#[cfg(all( + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket"), + unix +))] #[repr(C)] #[derive(Debug)] struct ifreq { ifr_name: [libc::c_char; libc::IF_NAMESIZE], - ifr_data: libc::c_int /* ifr_ifindex or ifr_mtu */ + ifr_data: libc::c_int, /* ifr_ifindex or ifr_mtu */ } -#[cfg(all(any(feature = "phy-tap_interface", feature = "phy-raw_socket"), unix))] +#[cfg(all( + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket"), + unix +))] fn ifreq_for(name: &str) -> ifreq { let mut ifreq = ifreq { ifr_name: [0; libc::IF_NAMESIZE], - ifr_data: 0 + ifr_data: 0, }; for (i, byte) in name.as_bytes().iter().enumerate() { ifreq.ifr_name[i] = *byte as libc::c_char @@ -80,12 +116,20 @@ fn ifreq_for(name: &str) -> ifreq { ifreq } -#[cfg(all(target_os = "linux", any(feature = "phy-tap_interface", feature = "phy-raw_socket")))] -fn ifreq_ioctl(lower: libc::c_int, ifreq: &mut ifreq, - cmd: libc::c_ulong) -> io::Result { +#[cfg(all( + any(target_os = "linux", target_os = "android"), + any(feature = "phy-tuntap_interface", feature = "phy-raw_socket") +))] +fn ifreq_ioctl( + lower: libc::c_int, + ifreq: &mut ifreq, + cmd: libc::c_ulong, +) -> io::Result { unsafe { let res = libc::ioctl(lower, cmd as _, ifreq as *mut ifreq); - if res == -1 { return Err(io::Error::last_os_error()) } + if res == -1 { + return Err(io::Error::last_os_error()); + } } Ok(ifreq.ifr_data) diff --git a/src/phy/sys/raw_socket.rs b/src/phy/sys/raw_socket.rs index 2b60bd7ce..f37fe960f 100644 --- a/src/phy/sys/raw_socket.rs +++ b/src/phy/sys/raw_socket.rs @@ -1,12 +1,13 @@ -use std::{mem, io}; -use std::os::unix::io::{RawFd, AsRawFd}; -use libc; use super::*; +use crate::phy::Medium; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::{io, mem}; #[derive(Debug)] pub struct RawSocketDesc { + protocol: libc::c_short, lower: libc::c_int, - ifreq: ifreq + ifreq: ifreq, } impl AsRawFd for RawSocketDesc { @@ -16,17 +17,32 @@ impl AsRawFd for RawSocketDesc { } impl RawSocketDesc { - pub fn new(name: &str) -> io::Result { + pub fn new(name: &str, medium: Medium) -> io::Result { + let protocol = match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => imp::ETH_P_ALL, + #[cfg(feature = "medium-ip")] + Medium::Ip => imp::ETH_P_ALL, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => imp::ETH_P_IEEE802154, + }; + let lower = unsafe { - let lower = libc::socket(libc::AF_PACKET, libc::SOCK_RAW | libc::SOCK_NONBLOCK, - imp::ETH_P_ALL.to_be() as i32); - if lower == -1 { return Err(io::Error::last_os_error()) } + let lower = libc::socket( + libc::AF_PACKET, + libc::SOCK_RAW | libc::SOCK_NONBLOCK, + protocol.to_be() as i32, + ); + if lower == -1 { + return Err(io::Error::last_os_error()); + } lower }; Ok(RawSocketDesc { - lower: lower, - ifreq: ifreq_for(name) + protocol, + lower, + ifreq: ifreq_for(name), }) } @@ -36,20 +52,24 @@ impl RawSocketDesc { pub fn bind_interface(&mut self) -> io::Result<()> { let sockaddr = libc::sockaddr_ll { - sll_family: libc::AF_PACKET as u16, - sll_protocol: imp::ETH_P_ALL.to_be() as u16, - sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, imp::SIOCGIFINDEX)?, - sll_hatype: 1, - sll_pkttype: 0, - sll_halen: 6, - sll_addr: [0; 8] + sll_family: libc::AF_PACKET as u16, + sll_protocol: self.protocol.to_be() as u16, + sll_ifindex: ifreq_ioctl(self.lower, &mut self.ifreq, imp::SIOCGIFINDEX)?, + sll_hatype: 1, + sll_pkttype: 0, + sll_halen: 6, + sll_addr: [0; 8], }; unsafe { - let res = libc::bind(self.lower, - &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr, - mem::size_of::() as u32); - if res == -1 { return Err(io::Error::last_os_error()) } + let res = libc::bind( + self.lower, + &sockaddr as *const libc::sockaddr_ll as *const libc::sockaddr, + mem::size_of::() as libc::socklen_t, + ); + if res == -1 { + return Err(io::Error::last_os_error()); + } } Ok(()) @@ -57,18 +77,30 @@ impl RawSocketDesc { pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result { unsafe { - let len = libc::recv(self.lower, buffer.as_mut_ptr() as *mut libc::c_void, - buffer.len(), 0); - if len == -1 { return Err(io::Error::last_os_error()) } + let len = libc::recv( + self.lower, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } Ok(len as usize) } } pub fn send(&mut self, buffer: &[u8]) -> io::Result { unsafe { - let len = libc::send(self.lower, buffer.as_ptr() as *const libc::c_void, - buffer.len(), 0); - if len == -1 { Err(io::Error::last_os_error()).unwrap() } + let len = libc::send( + self.lower, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + 0, + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } Ok(len as usize) } } @@ -76,6 +108,8 @@ impl RawSocketDesc { impl Drop for RawSocketDesc { fn drop(&mut self) { - unsafe { libc::close(self.lower); } + unsafe { + libc::close(self.lower); + } } } diff --git a/src/phy/sys/tap_interface.rs b/src/phy/sys/tap_interface.rs deleted file mode 100644 index f89597ff8..000000000 --- a/src/phy/sys/tap_interface.rs +++ /dev/null @@ -1,75 +0,0 @@ -use std::io; -use std::os::unix::io::{RawFd, AsRawFd}; -use libc; -use super::*; - -#[derive(Debug)] -pub struct TapInterfaceDesc { - lower: libc::c_int, - ifreq: ifreq -} - -impl AsRawFd for TapInterfaceDesc { - fn as_raw_fd(&self) -> RawFd { - self.lower - } -} - -impl TapInterfaceDesc { - pub fn new(name: &str) -> io::Result { - let lower = unsafe { - let lower = libc::open("/dev/net/tun\0".as_ptr() as *const libc::c_char, - libc::O_RDWR | libc::O_NONBLOCK); - if lower == -1 { return Err(io::Error::last_os_error()) } - lower - }; - - Ok(TapInterfaceDesc { - lower: lower, - ifreq: ifreq_for(name) - }) - } - - pub fn attach_interface(&mut self) -> io::Result<()> { - self.ifreq.ifr_data = imp::IFF_TAP | imp::IFF_NO_PI; - ifreq_ioctl(self.lower, &mut self.ifreq, imp::TUNSETIFF).map(|_| ()) - } - - pub fn interface_mtu(&mut self) -> io::Result { - let lower = unsafe { - let lower = libc::socket(libc::AF_INET, libc::SOCK_DGRAM, libc::IPPROTO_IP); - if lower == -1 { return Err(io::Error::last_os_error()) } - lower - }; - - let mtu = ifreq_ioctl(lower, &mut self.ifreq, imp::SIOCGIFMTU).map(|mtu| mtu as usize); - - unsafe { libc::close(lower); } - - mtu - } - - pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result { - unsafe { - let len = libc::read(self.lower, buffer.as_mut_ptr() as *mut libc::c_void, - buffer.len()); - if len == -1 { return Err(io::Error::last_os_error()) } - Ok(len as usize) - } - } - - pub fn send(&mut self, buffer: &[u8]) -> io::Result { - unsafe { - let len = libc::write(self.lower, buffer.as_ptr() as *const libc::c_void, - buffer.len()); - if len == -1 { Err(io::Error::last_os_error()).unwrap() } - Ok(len as usize) - } - } -} - -impl Drop for TapInterfaceDesc { - fn drop(&mut self) { - unsafe { libc::close(self.lower); } - } -} diff --git a/src/phy/sys/tuntap_interface.rs b/src/phy/sys/tuntap_interface.rs new file mode 100644 index 000000000..3019cadea --- /dev/null +++ b/src/phy/sys/tuntap_interface.rs @@ -0,0 +1,130 @@ +use super::*; +use crate::{phy::Medium, wire::EthernetFrame}; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; + +#[derive(Debug)] +pub struct TunTapInterfaceDesc { + lower: libc::c_int, + mtu: usize, +} + +impl AsRawFd for TunTapInterfaceDesc { + fn as_raw_fd(&self) -> RawFd { + self.lower + } +} + +impl TunTapInterfaceDesc { + pub fn new(name: &str, medium: Medium) -> io::Result { + let lower = unsafe { + let lower = libc::open( + "/dev/net/tun\0".as_ptr() as *const libc::c_char, + libc::O_RDWR | libc::O_NONBLOCK, + ); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + let mut ifreq = ifreq_for(name); + Self::attach_interface_ifreq(lower, medium, &mut ifreq)?; + let mtu = Self::mtu_ifreq(medium, &mut ifreq)?; + + Ok(TunTapInterfaceDesc { lower, mtu }) + } + + pub fn from_fd(fd: RawFd, mtu: usize) -> io::Result { + Ok(TunTapInterfaceDesc { lower: fd, mtu }) + } + + fn attach_interface_ifreq( + lower: libc::c_int, + medium: Medium, + ifr: &mut ifreq, + ) -> io::Result<()> { + let mode = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => imp::IFF_TUN, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => imp::IFF_TAP, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + ifr.ifr_data = mode | imp::IFF_NO_PI; + ifreq_ioctl(lower, ifr, imp::TUNSETIFF).map(|_| ()) + } + + fn mtu_ifreq(medium: Medium, ifr: &mut ifreq) -> io::Result { + let lower = unsafe { + let lower = libc::socket(libc::AF_INET, libc::SOCK_DGRAM, libc::IPPROTO_IP); + if lower == -1 { + return Err(io::Error::last_os_error()); + } + lower + }; + + let ip_mtu = ifreq_ioctl(lower, ifr, imp::SIOCGIFMTU).map(|mtu| mtu as usize); + + unsafe { + libc::close(lower); + } + + // Propagate error after close, to ensure we always close. + let ip_mtu = ip_mtu?; + + // SIOCGIFMTU returns the IP MTU (typically 1500 bytes.) + // smoltcp counts the entire Ethernet packet in the MTU, so add the Ethernet header size to it. + let mtu = match medium { + #[cfg(feature = "medium-ip")] + Medium::Ip => ip_mtu, + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => ip_mtu + EthernetFrame::<&[u8]>::header_len(), + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => todo!(), + }; + + Ok(mtu) + } + + pub fn interface_mtu(&self) -> io::Result { + Ok(self.mtu) + } + + pub fn recv(&mut self, buffer: &mut [u8]) -> io::Result { + unsafe { + let len = libc::read( + self.lower, + buffer.as_mut_ptr() as *mut libc::c_void, + buffer.len(), + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } + + pub fn send(&mut self, buffer: &[u8]) -> io::Result { + unsafe { + let len = libc::write( + self.lower, + buffer.as_ptr() as *const libc::c_void, + buffer.len(), + ); + if len == -1 { + return Err(io::Error::last_os_error()); + } + Ok(len as usize) + } + } +} + +impl Drop for TunTapInterfaceDesc { + fn drop(&mut self) { + unsafe { + libc::close(self.lower); + } + } +} diff --git a/src/phy/tap_interface.rs b/src/phy/tap_interface.rs deleted file mode 100644 index dfdfb27b4..000000000 --- a/src/phy/tap_interface.rs +++ /dev/null @@ -1,104 +0,0 @@ -use std::cell::RefCell; -use std::vec::Vec; -use std::rc::Rc; -use std::io; -use std::os::unix::io::{RawFd, AsRawFd}; - -use Result; -use phy::{self, sys, DeviceCapabilities, Device}; -use time::Instant; - -/// A virtual Ethernet interface. -#[derive(Debug)] -pub struct TapInterface { - lower: Rc>, - mtu: usize -} - -impl AsRawFd for TapInterface { - fn as_raw_fd(&self) -> RawFd { - self.lower.borrow().as_raw_fd() - } -} - -impl TapInterface { - /// Attaches to a TAP interface called `name`, or creates it if it does not exist. - /// - /// If `name` is a persistent interface configured with UID of the current user, - /// no special privileges are needed. Otherwise, this requires superuser privileges - /// or a corresponding capability set on the executable. - pub fn new(name: &str) -> io::Result { - let mut lower = sys::TapInterfaceDesc::new(name)?; - lower.attach_interface()?; - let mtu = lower.interface_mtu()?; - Ok(TapInterface { - lower: Rc::new(RefCell::new(lower)), - mtu: mtu - }) - } -} - -impl<'a> Device<'a> for TapInterface { - type RxToken = RxToken; - type TxToken = TxToken; - - fn capabilities(&self) -> DeviceCapabilities { - DeviceCapabilities { - max_transmission_unit: self.mtu, - ..DeviceCapabilities::default() - } - } - - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - let mut lower = self.lower.borrow_mut(); - let mut buffer = vec![0; self.mtu]; - match lower.recv(&mut buffer[..]) { - Ok(size) => { - buffer.resize(size, 0); - let rx = RxToken { buffer }; - let tx = TxToken { lower: self.lower.clone() }; - Some((rx, tx)) - } - Err(ref err) if err.kind() == io::ErrorKind::WouldBlock => { - None - } - Err(err) => panic!("{}", err) - } - } - - fn transmit(&'a mut self) -> Option { - Some(TxToken { - lower: self.lower.clone(), - }) - } -} - -#[doc(hidden)] -pub struct RxToken { - buffer: Vec -} - -impl phy::RxToken for RxToken { - fn consume(mut self, _timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result - { - f(&mut self.buffer[..]) - } -} - -#[doc(hidden)] -pub struct TxToken { - lower: Rc>, -} - -impl phy::TxToken for TxToken { - fn consume(self, _timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result - { - let mut lower = self.lower.borrow_mut(); - let mut buffer = vec![0; len]; - let result = f(&mut buffer); - lower.send(&buffer[..]).unwrap(); - result - } -} diff --git a/src/phy/tracer.rs b/src/phy/tracer.rs index b1598adbd..48e60ec2b 100644 --- a/src/phy/tracer.rs +++ b/src/phy/tracer.rs @@ -1,21 +1,22 @@ -use Result; -use wire::pretty_print::{PrettyPrint, PrettyPrinter}; -use phy::{self, DeviceCapabilities, Device}; -use time::Instant; +use core::fmt; + +use crate::phy::{self, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; /// A tracer device. /// /// A tracer is a device that pretty prints all packets traversing it /// using the provided writer function, and then passes them to another /// device. -pub struct Tracer Device<'a>, P: PrettyPrint> { - inner: D, - writer: fn(Instant, PrettyPrinter

), +pub struct Tracer { + inner: D, + writer: fn(Instant, Packet), } -impl Device<'a>, P: PrettyPrint> Tracer { +impl Tracer { /// Create a tracer device. - pub fn new(inner: D, writer: fn(timestamp: Instant, printer: PrettyPrinter

)) -> Tracer { + pub fn new(inner: D, writer: fn(timestamp: Instant, packet: Packet)) -> Tracer { Tracer { inner, writer } } @@ -40,65 +41,149 @@ impl Device<'a>, P: PrettyPrint> Tracer { } } -impl<'a, D, P> Device<'a> for Tracer - where D: for<'b> Device<'b>, - P: PrettyPrint + 'a, -{ - type RxToken = RxToken<>::RxToken, P>; - type TxToken = TxToken<>::TxToken, P>; +impl Device for Tracer { + type RxToken<'a> = RxToken> + where + Self: 'a; + type TxToken<'a> = TxToken> + where + Self: 'a; - fn capabilities(&self) -> DeviceCapabilities { self.inner.capabilities() } + fn capabilities(&self) -> DeviceCapabilities { + self.inner.capabilities() + } - fn receive(&'a mut self) -> Option<(Self::RxToken, Self::TxToken)> { - let &mut Self { ref mut inner, writer, .. } = self; - inner.receive().map(|(rx_token, tx_token)| { - let rx = RxToken { token: rx_token, writer: writer }; - let tx = TxToken { token: tx_token, writer: writer }; + fn receive(&mut self, timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let medium = self.inner.capabilities().medium; + self.inner.receive(timestamp).map(|(rx_token, tx_token)| { + let rx = RxToken { + token: rx_token, + writer: self.writer, + medium, + timestamp, + }; + let tx = TxToken { + token: tx_token, + writer: self.writer, + medium, + timestamp, + }; (rx, tx) }) } - fn transmit(&'a mut self) -> Option { - let &mut Self { ref mut inner, writer } = self; - inner.transmit().map(|tx_token| { - TxToken { token: tx_token, writer: writer } + fn transmit(&mut self, timestamp: Instant) -> Option> { + let medium = self.inner.capabilities().medium; + self.inner.transmit(timestamp).map(|tx_token| TxToken { + token: tx_token, + medium, + writer: self.writer, + timestamp, }) } } #[doc(hidden)] -pub struct RxToken { - token: Rx, - writer: fn(Instant, PrettyPrinter

) +pub struct RxToken { + token: Rx, + writer: fn(Instant, Packet), + medium: Medium, + timestamp: Instant, } -impl phy::RxToken for RxToken { - fn consume(self, timestamp: Instant, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result +impl phy::RxToken for RxToken { + fn consume(self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let Self { token, writer } = self; - token.consume(timestamp, |buffer| { - writer(timestamp, PrettyPrinter::

::new("<- ", &buffer)); + self.token.consume(|buffer| { + (self.writer)( + self.timestamp, + Packet { + buffer, + medium: self.medium, + prefix: "<- ", + }, + ); f(buffer) }) } + + fn meta(&self) -> phy::PacketMeta { + self.token.meta() + } } #[doc(hidden)] -pub struct TxToken { - token: Tx, - writer: fn(Instant, PrettyPrinter

) +pub struct TxToken { + token: Tx, + writer: fn(Instant, Packet), + medium: Medium, + timestamp: Instant, } -impl phy::TxToken for TxToken { - fn consume(self, timestamp: Instant, len: usize, f: F) -> Result - where F: FnOnce(&mut [u8]) -> Result +impl phy::TxToken for TxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, { - let Self { token, writer } = self; - token.consume(timestamp, len, |buffer| { + self.token.consume(len, |buffer| { let result = f(buffer); - writer(timestamp, PrettyPrinter::

::new("-> ", &buffer)); + (self.writer)( + self.timestamp, + Packet { + buffer, + medium: self.medium, + prefix: "-> ", + }, + ); result }) } + + fn set_meta(&mut self, meta: phy::PacketMeta) { + self.token.set_meta(meta) + } +} + +pub struct Packet<'a> { + buffer: &'a [u8], + medium: Medium, + prefix: &'static str, +} + +impl<'a> fmt::Display for Packet<'a> { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + let mut indent = PrettyIndent::new(self.prefix); + match self.medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => crate::wire::EthernetFrame::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ), + #[cfg(feature = "medium-ip")] + Medium::Ip => match crate::wire::IpVersion::of_packet(self.buffer) { + #[cfg(feature = "proto-ipv4")] + Ok(crate::wire::IpVersion::Ipv4) => { + crate::wire::Ipv4Packet::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ) + } + #[cfg(feature = "proto-ipv6")] + Ok(crate::wire::IpVersion::Ipv6) => { + crate::wire::Ipv6Packet::<&'static [u8]>::pretty_print( + &self.buffer, + f, + &mut indent, + ) + } + _ => f.write_str("unrecognized IP version"), + }, + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => Ok(()), // XXX + } + } } diff --git a/src/phy/tuntap_interface.rs b/src/phy/tuntap_interface.rs new file mode 100644 index 000000000..32a28dbb4 --- /dev/null +++ b/src/phy/tuntap_interface.rs @@ -0,0 +1,126 @@ +use std::cell::RefCell; +use std::io; +use std::os::unix::io::{AsRawFd, RawFd}; +use std::rc::Rc; +use std::vec::Vec; + +use crate::phy::{self, sys, Device, DeviceCapabilities, Medium}; +use crate::time::Instant; + +/// A virtual TUN (IP) or TAP (Ethernet) interface. +#[derive(Debug)] +pub struct TunTapInterface { + lower: Rc>, + mtu: usize, + medium: Medium, +} + +impl AsRawFd for TunTapInterface { + fn as_raw_fd(&self) -> RawFd { + self.lower.borrow().as_raw_fd() + } +} + +impl TunTapInterface { + /// Attaches to a TUN/TAP interface called `name`, or creates it if it does not exist. + /// + /// If `name` is a persistent interface configured with UID of the current user, + /// no special privileges are needed. Otherwise, this requires superuser privileges + /// or a corresponding capability set on the executable. + pub fn new(name: &str, medium: Medium) -> io::Result { + let lower = sys::TunTapInterfaceDesc::new(name, medium)?; + let mtu = lower.interface_mtu()?; + Ok(TunTapInterface { + lower: Rc::new(RefCell::new(lower)), + mtu, + medium, + }) + } + + /// Attaches to a TUN/TAP interface specified by file descriptor `fd`. + /// + /// On platforms like Android, a file descriptor to a tun interface is exposed. + /// On these platforms, a TunTapInterface cannot be instantiated with a name. + pub fn from_fd(fd: RawFd, medium: Medium, mtu: usize) -> io::Result { + let lower = sys::TunTapInterfaceDesc::from_fd(fd, mtu)?; + Ok(TunTapInterface { + lower: Rc::new(RefCell::new(lower)), + mtu, + medium, + }) + } +} + +impl Device for TunTapInterface { + type RxToken<'a> = RxToken; + type TxToken<'a> = TxToken; + + fn capabilities(&self) -> DeviceCapabilities { + DeviceCapabilities { + max_transmission_unit: self.mtu, + medium: self.medium, + ..DeviceCapabilities::default() + } + } + + fn receive(&mut self, _timestamp: Instant) -> Option<(Self::RxToken<'_>, Self::TxToken<'_>)> { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; self.mtu]; + match lower.recv(&mut buffer[..]) { + Ok(size) => { + buffer.resize(size, 0); + let rx = RxToken { buffer }; + let tx = TxToken { + lower: self.lower.clone(), + }; + Some((rx, tx)) + } + Err(err) if err.kind() == io::ErrorKind::WouldBlock => None, + Err(err) => panic!("{}", err), + } + } + + fn transmit(&mut self, _timestamp: Instant) -> Option> { + Some(TxToken { + lower: self.lower.clone(), + }) + } +} + +#[doc(hidden)] +pub struct RxToken { + buffer: Vec, +} + +impl phy::RxToken for RxToken { + fn consume(mut self, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + f(&mut self.buffer[..]) + } +} + +#[doc(hidden)] +pub struct TxToken { + lower: Rc>, +} + +impl phy::TxToken for TxToken { + fn consume(self, len: usize, f: F) -> R + where + F: FnOnce(&mut [u8]) -> R, + { + let mut lower = self.lower.borrow_mut(); + let mut buffer = vec![0; len]; + let result = f(&mut buffer); + match lower.send(&buffer[..]) { + Ok(_) => {} + Err(err) if err.kind() == io::ErrorKind::WouldBlock => { + net_debug!("phy: tx failed due to WouldBlock") + } + Err(err) => panic!("{}", err), + } + result + } +} diff --git a/src/rand.rs b/src/rand.rs new file mode 100644 index 000000000..15d88f77e --- /dev/null +++ b/src/rand.rs @@ -0,0 +1,40 @@ +#![allow(unsafe_code)] +#![allow(unused)] + +#[derive(Debug)] +pub(crate) struct Rand { + state: u64, +} + +impl Rand { + pub(crate) const fn new(seed: u64) -> Self { + Self { state: seed } + } + + pub(crate) fn rand_u32(&mut self) -> u32 { + // sPCG32 from https://www.pcg-random.org/paper.html + // see also https://nullprogram.com/blog/2017/09/21/ + const M: u64 = 0xbb2efcec3c39611d; + const A: u64 = 0x7590ef39; + + let s = self.state.wrapping_mul(M).wrapping_add(A); + self.state = s; + + let shift = 29 - (s >> 61); + (s >> shift) as u32 + } + + pub(crate) fn rand_u16(&mut self) -> u16 { + let n = self.rand_u32(); + (n ^ (n >> 16)) as u16 + } + + pub(crate) fn rand_source_port(&mut self) -> u16 { + loop { + let res = self.rand_u16(); + if res > 1024 { + return res; + } + } + } +} diff --git a/src/socket/dhcpv4.rs b/src/socket/dhcpv4.rs new file mode 100644 index 000000000..2609621eb --- /dev/null +++ b/src/socket/dhcpv4.rs @@ -0,0 +1,1345 @@ +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::time::{Duration, Instant}; +use crate::wire::dhcpv4::field as dhcpv4_field; +use crate::wire::{ + DhcpMessageType, DhcpPacket, DhcpRepr, IpAddress, IpProtocol, Ipv4Address, Ipv4Cidr, Ipv4Repr, + UdpRepr, DHCP_CLIENT_PORT, DHCP_MAX_DNS_SERVER_COUNT, DHCP_SERVER_PORT, UDP_HEADER_LEN, +}; +use crate::wire::{DhcpOption, HardwareAddress}; +use heapless::Vec; + +#[cfg(feature = "async")] +use super::WakerRegistration; + +use super::PollAt; + +const DEFAULT_LEASE_DURATION: Duration = Duration::from_secs(120); + +const DEFAULT_PARAMETER_REQUEST_LIST: &[u8] = &[ + dhcpv4_field::OPT_SUBNET_MASK, + dhcpv4_field::OPT_ROUTER, + dhcpv4_field::OPT_DOMAIN_NAME_SERVER, +]; + +/// IPv4 configuration data provided by the DHCP server. +#[derive(Debug, Eq, PartialEq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Config<'a> { + /// Information on how to reach the DHCP server that responded with DHCP + /// configuration. + pub server: ServerInfo, + /// IP address + pub address: Ipv4Cidr, + /// Router address, also known as default gateway. Does not necessarily + /// match the DHCP server's address. + pub router: Option, + /// DNS servers + pub dns_servers: Vec, + /// Received DHCP packet + pub packet: Option>, +} + +/// Information on how to reach a DHCP server. +#[derive(Debug, Clone, Copy, Eq, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct ServerInfo { + /// IP address to use as destination in outgoing packets + pub address: Ipv4Address, + /// Server identifier to use in outgoing packets. Usually equal to server_address, + /// but may differ in some situations (eg DHCP relays) + pub identifier: Ipv4Address, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct DiscoverState { + /// When to send next request + retry_at: Instant, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RequestState { + /// When to send next request + retry_at: Instant, + /// How many retries have been done + retry: u16, + /// Server we're trying to request from + server: ServerInfo, + /// IP address that we're trying to request. + requested_ip: Ipv4Address, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RenewState { + /// Active network config + config: Config<'static>, + + /// Renew timer. When reached, we will start attempting + /// to renew this lease with the DHCP server. + /// + /// Must be less or equal than `rebind_at`. + renew_at: Instant, + + /// Rebind timer. When reached, we will start broadcasting to renew + /// this lease with any DHCP server. + /// + /// Must be greater than or equal to `renew_at`, and less than or + /// equal to `expires_at`. + rebind_at: Instant, + + /// Whether the T2 time has elapsed + rebinding: bool, + + /// Expiration timer. When reached, this lease is no longer valid, so it must be + /// thrown away and the ethernet interface deconfigured. + expires_at: Instant, +} + +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +enum ClientState { + /// Discovering the DHCP server + Discovering(DiscoverState), + /// Requesting an address + Requesting(RequestState), + /// Having an address, refresh it periodically. + Renewing(RenewState), +} + +/// Timeout and retry configuration. +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RetryConfig { + pub discover_timeout: Duration, + /// The REQUEST timeout doubles every 2 tries. + pub initial_request_timeout: Duration, + pub request_retries: u16, + pub min_renew_timeout: Duration, +} + +impl Default for RetryConfig { + fn default() -> Self { + Self { + discover_timeout: Duration::from_secs(10), + initial_request_timeout: Duration::from_secs(5), + request_retries: 5, + min_renew_timeout: Duration::from_secs(60), + } + } +} + +/// Return value for the `Dhcpv4Socket::poll` function +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Event<'a> { + /// Configuration has been lost (for example, the lease has expired) + Deconfigured, + /// Configuration has been newly acquired, or modified. + Configured(Config<'a>), +} + +#[derive(Debug)] +pub struct Socket<'a> { + /// State of the DHCP client. + state: ClientState, + /// Set to true on config/state change, cleared back to false by the `config` function. + config_changed: bool, + /// xid of the last sent message. + transaction_id: u32, + + /// Max lease duration. If set, it sets a maximum cap to the server-provided lease duration. + /// Useful to react faster to IP configuration changes and to test whether renews work correctly. + max_lease_duration: Option, + + retry_config: RetryConfig, + + /// Ignore NAKs. + ignore_naks: bool, + + /// Server port config + pub(crate) server_port: u16, + + /// Client port config + pub(crate) client_port: u16, + + /// A buffer contains options additional to be added to outgoing DHCP + /// packets. + outgoing_options: &'a [DhcpOption<'a>], + /// A buffer containing all requested parameters. + parameter_request_list: Option<&'a [u8]>, + + /// Incoming DHCP packets are copied into this buffer, overwriting the previous. + receive_packet_buffer: Option<&'a mut [u8]>, + + /// Waker registration + #[cfg(feature = "async")] + waker: WakerRegistration, +} + +/// DHCP client socket. +/// +/// The socket acquires an IP address configuration through DHCP autonomously. +/// You must query the configuration with `.poll()` after every call to `Interface::poll()`, +/// and apply the configuration to the `Interface`. +impl<'a> Socket<'a> { + /// Create a DHCPv4 socket + #[allow(clippy::new_without_default)] + pub fn new() -> Self { + Socket { + state: ClientState::Discovering(DiscoverState { + retry_at: Instant::from_millis(0), + }), + config_changed: true, + transaction_id: 1, + max_lease_duration: None, + retry_config: RetryConfig::default(), + ignore_naks: false, + outgoing_options: &[], + parameter_request_list: None, + receive_packet_buffer: None, + #[cfg(feature = "async")] + waker: WakerRegistration::new(), + server_port: DHCP_SERVER_PORT, + client_port: DHCP_CLIENT_PORT, + } + } + + /// Set the retry/timeouts configuration. + pub fn set_retry_config(&mut self, config: RetryConfig) { + self.retry_config = config; + } + + /// Set the outgoing options. + pub fn set_outgoing_options(&mut self, options: &'a [DhcpOption<'a>]) { + self.outgoing_options = options; + } + + /// Set the buffer into which incoming DHCP packets are copied into. + pub fn set_receive_packet_buffer(&mut self, buffer: &'a mut [u8]) { + self.receive_packet_buffer = Some(buffer); + } + + /// Set the parameter request list. + /// + /// This should contain at least `OPT_SUBNET_MASK` (`1`), `OPT_ROUTER` + /// (`3`), and `OPT_DOMAIN_NAME_SERVER` (`6`). + pub fn set_parameter_request_list(&mut self, parameter_request_list: &'a [u8]) { + self.parameter_request_list = Some(parameter_request_list); + } + + /// Get the configured max lease duration. + /// + /// See also [`Self::set_max_lease_duration()`] + pub fn max_lease_duration(&self) -> Option { + self.max_lease_duration + } + + /// Set the max lease duration. + /// + /// When set, the lease duration will be capped at the configured duration if the + /// DHCP server gives us a longer lease. This is generally not recommended, but + /// can be useful for debugging or reacting faster to network configuration changes. + /// + /// If None, no max is applied (the lease duration from the DHCP server is used.) + pub fn set_max_lease_duration(&mut self, max_lease_duration: Option) { + self.max_lease_duration = max_lease_duration; + } + + /// Get whether to ignore NAKs. + /// + /// See also [`Self::set_ignore_naks()`] + pub fn ignore_naks(&self) -> bool { + self.ignore_naks + } + + /// Set whether to ignore NAKs. + /// + /// This is not compliant with the DHCP RFCs, since theoretically + /// we must stop using the assigned IP when receiving a NAK. This + /// can increase reliability on broken networks with buggy routers + /// or rogue DHCP servers, however. + pub fn set_ignore_naks(&mut self, ignore_naks: bool) { + self.ignore_naks = ignore_naks; + } + + /// Set the server/client port + /// + /// Allows you to specify the ports used by DHCP. + /// This is meant to support esoteric usecases allowed by the dhclient program. + pub fn set_ports(&mut self, server_port: u16, client_port: u16) { + self.server_port = server_port; + self.client_port = client_port; + } + + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { + let t = match &self.state { + ClientState::Discovering(state) => state.retry_at, + ClientState::Requesting(state) => state.retry_at, + ClientState::Renewing(state) => if state.rebinding { + state.rebind_at + } else { + state.renew_at.min(state.rebind_at) + } + .min(state.expires_at), + }; + PollAt::Time(t) + } + + pub(crate) fn process( + &mut self, + cx: &mut Context, + ip_repr: &Ipv4Repr, + repr: &UdpRepr, + payload: &[u8], + ) { + let src_ip = ip_repr.src_addr; + + // This is enforced in interface.rs. + assert!(repr.src_port == self.server_port && repr.dst_port == self.client_port); + + let dhcp_packet = match DhcpPacket::new_checked(payload) { + Ok(dhcp_packet) => dhcp_packet, + Err(e) => { + net_debug!("DHCP invalid pkt from {}: {:?}", src_ip, e); + return; + } + }; + let dhcp_repr = match DhcpRepr::parse(&dhcp_packet) { + Ok(dhcp_repr) => dhcp_repr, + Err(e) => { + net_debug!("DHCP error parsing pkt from {}: {:?}", src_ip, e); + return; + } + }; + + let HardwareAddress::Ethernet(ethernet_addr) = cx.hardware_addr() else { + panic!("using DHCPv4 socket with a non-ethernet hardware address."); + }; + + if dhcp_repr.client_hardware_address != ethernet_addr { + return; + } + if dhcp_repr.transaction_id != self.transaction_id { + return; + } + let server_identifier = match dhcp_repr.server_identifier { + Some(server_identifier) => server_identifier, + None => { + net_debug!( + "DHCP ignoring {:?} because missing server_identifier", + dhcp_repr.message_type + ); + return; + } + }; + + net_debug!( + "DHCP recv {:?} from {}: {:?}", + dhcp_repr.message_type, + src_ip, + dhcp_repr + ); + + // Copy over the payload into the receive packet buffer. + if let Some(buffer) = self.receive_packet_buffer.as_mut() { + if let Some(buffer) = buffer.get_mut(..payload.len()) { + buffer.copy_from_slice(payload); + } + } + + match (&mut self.state, dhcp_repr.message_type) { + (ClientState::Discovering(_state), DhcpMessageType::Offer) => { + if !dhcp_repr.your_ip.is_unicast() { + net_debug!("DHCP ignoring OFFER because your_ip is not unicast"); + return; + } + + self.state = ClientState::Requesting(RequestState { + retry_at: cx.now(), + retry: 0, + server: ServerInfo { + address: src_ip, + identifier: server_identifier, + }, + requested_ip: dhcp_repr.your_ip, // use the offered ip + }); + } + (ClientState::Requesting(state), DhcpMessageType::Ack) => { + if let Some((config, renew_at, rebind_at, expires_at)) = + Self::parse_ack(cx.now(), &dhcp_repr, self.max_lease_duration, state.server) + { + self.state = ClientState::Renewing(RenewState { + config, + renew_at, + rebind_at, + expires_at, + rebinding: false, + }); + self.config_changed(); + } + } + (ClientState::Requesting(_), DhcpMessageType::Nak) => { + if !self.ignore_naks { + self.reset(); + } + } + (ClientState::Renewing(state), DhcpMessageType::Ack) => { + if let Some((config, renew_at, rebind_at, expires_at)) = Self::parse_ack( + cx.now(), + &dhcp_repr, + self.max_lease_duration, + state.config.server, + ) { + state.renew_at = renew_at; + state.rebind_at = rebind_at; + state.rebinding = false; + state.expires_at = expires_at; + // The `receive_packet_buffer` field isn't populated until + // the client asks for the state, but receiving any packet + // will change it, so we indicate that the config has + // changed every time if the receive packet buffer is set, + // but we only write changes to the rest of the config now. + let config_changed = + state.config != config || self.receive_packet_buffer.is_some(); + if state.config != config { + state.config = config; + } + if config_changed { + self.config_changed(); + } + } + } + (ClientState::Renewing(_), DhcpMessageType::Nak) => { + if !self.ignore_naks { + self.reset(); + } + } + _ => { + net_debug!( + "DHCP ignoring {:?}: unexpected in current state", + dhcp_repr.message_type + ); + } + } + } + + fn parse_ack( + now: Instant, + dhcp_repr: &DhcpRepr, + max_lease_duration: Option, + server: ServerInfo, + ) -> Option<(Config<'static>, Instant, Instant, Instant)> { + let subnet_mask = match dhcp_repr.subnet_mask { + Some(subnet_mask) => subnet_mask, + None => { + net_debug!("DHCP ignoring ACK because missing subnet_mask"); + return None; + } + }; + + let prefix_len = match IpAddress::Ipv4(subnet_mask).prefix_len() { + Some(prefix_len) => prefix_len, + None => { + net_debug!("DHCP ignoring ACK because subnet_mask is not a valid mask"); + return None; + } + }; + + if !dhcp_repr.your_ip.is_unicast() { + net_debug!("DHCP ignoring ACK because your_ip is not unicast"); + return None; + } + + let mut lease_duration = dhcp_repr + .lease_duration + .map(|d| Duration::from_secs(d as _)) + .unwrap_or(DEFAULT_LEASE_DURATION); + if let Some(max_lease_duration) = max_lease_duration { + lease_duration = lease_duration.min(max_lease_duration); + } + + // Cleanup the DNS servers list, keeping only unicasts/ + // TP-Link TD-W8970 sends 0.0.0.0 as second DNS server if there's only one configured :( + let mut dns_servers = Vec::new(); + + dhcp_repr + .dns_servers + .iter() + .flatten() + .filter(|s| s.is_unicast()) + .for_each(|a| { + // This will never produce an error, as both the arrays and `dns_servers` + // have length DHCP_MAX_DNS_SERVER_COUNT + dns_servers.push(*a).ok(); + }); + + let config = Config { + server, + address: Ipv4Cidr::new(dhcp_repr.your_ip, prefix_len), + router: dhcp_repr.router, + dns_servers, + packet: None, + }; + + // Set renew and rebind times as per RFC 2131: + // Times T1 and T2 are configurable by the server through + // options. T1 defaults to (0.5 * duration_of_lease). T2 + // defaults to (0.875 * duration_of_lease). + let (renew_duration, rebind_duration) = match ( + dhcp_repr + .renew_duration + .map(|d| Duration::from_secs(d as u64)), + dhcp_repr + .rebind_duration + .map(|d| Duration::from_secs(d as u64)), + ) { + (Some(renew_duration), Some(rebind_duration)) => (renew_duration, rebind_duration), + (None, None) => (lease_duration / 2, lease_duration * 7 / 8), + // RFC 2131 does not say what to do if only one value is + // provided, so: + + // If only T1 is provided, set T2 to be 0.75 through the gap + // between T1 and the duration of the lease. If T1 is set to + // the default (0.5 * duration_of_lease), then T2 will also + // be set to the default (0.875 * duration_of_lease). + (Some(renew_duration), None) => ( + renew_duration, + renew_duration + (lease_duration - renew_duration) * 3 / 4, + ), + + // If only T2 is provided, then T1 will be set to be + // whichever is smaller of the default (0.5 * + // duration_of_lease) or T2. + (None, Some(rebind_duration)) => { + ((lease_duration / 2).min(rebind_duration), rebind_duration) + } + }; + let renew_at = now + renew_duration; + let rebind_at = now + rebind_duration; + let expires_at = now + lease_duration; + + Some((config, renew_at, rebind_at, expires_at)) + } + + #[cfg(not(test))] + fn random_transaction_id(cx: &mut Context) -> u32 { + cx.rand().rand_u32() + } + + #[cfg(test)] + fn random_transaction_id(_cx: &mut Context) -> u32 { + 0x12345678 + } + + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (Ipv4Repr, UdpRepr, DhcpRepr)) -> Result<(), E>, + { + // note: Dhcpv4Socket is only usable in ethernet mediums, so the + // unwrap can never fail. + let HardwareAddress::Ethernet(ethernet_addr) = cx.hardware_addr() else { + panic!("using DHCPv4 socket with a non-ethernet hardware address."); + }; + + // Worst case biggest IPv4 header length. + // 0x0f * 4 = 60 bytes. + const MAX_IPV4_HEADER_LEN: usize = 60; + + // We don't directly modify self.transaction_id because sending the packet + // may fail. We only want to update state after succesfully sending. + let next_transaction_id = Self::random_transaction_id(cx); + + let mut dhcp_repr = DhcpRepr { + message_type: DhcpMessageType::Discover, + transaction_id: next_transaction_id, + secs: 0, + client_hardware_address: ethernet_addr, + client_ip: Ipv4Address::UNSPECIFIED, + your_ip: Ipv4Address::UNSPECIFIED, + server_ip: Ipv4Address::UNSPECIFIED, + router: None, + subnet_mask: None, + relay_agent_ip: Ipv4Address::UNSPECIFIED, + broadcast: false, + requested_ip: None, + client_identifier: Some(ethernet_addr), + server_identifier: None, + parameter_request_list: Some( + self.parameter_request_list + .unwrap_or(DEFAULT_PARAMETER_REQUEST_LIST), + ), + max_size: Some((cx.ip_mtu() - MAX_IPV4_HEADER_LEN - UDP_HEADER_LEN) as u16), + lease_duration: None, + renew_duration: None, + rebind_duration: None, + dns_servers: None, + additional_options: self.outgoing_options, + }; + + let udp_repr = UdpRepr { + src_port: self.client_port, + dst_port: self.server_port, + }; + + let mut ipv4_repr = Ipv4Repr { + src_addr: Ipv4Address::UNSPECIFIED, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, // filled right before emit + hop_limit: 64, + }; + + match &mut self.state { + ClientState::Discovering(state) => { + if cx.now() < state.retry_at { + return Ok(()); + } + + // send packet + net_debug!( + "DHCP send DISCOVER to {}: {:?}", + ipv4_repr.dst_addr, + dhcp_repr + ); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // Update state AFTER the packet has been successfully sent. + state.retry_at = cx.now() + self.retry_config.discover_timeout; + self.transaction_id = next_transaction_id; + Ok(()) + } + ClientState::Requesting(state) => { + if cx.now() < state.retry_at { + return Ok(()); + } + + if state.retry >= self.retry_config.request_retries { + net_debug!("DHCP request retries exceeded, restarting discovery"); + self.reset(); + return Ok(()); + } + + dhcp_repr.message_type = DhcpMessageType::Request; + dhcp_repr.requested_ip = Some(state.requested_ip); + dhcp_repr.server_identifier = Some(state.server.identifier); + + net_debug!( + "DHCP send request to {}: {:?}", + ipv4_repr.dst_addr, + dhcp_repr + ); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // Exponential backoff: Double every 2 retries. + state.retry_at = cx.now() + + (self.retry_config.initial_request_timeout << (state.retry as u32 / 2)); + state.retry += 1; + + self.transaction_id = next_transaction_id; + Ok(()) + } + ClientState::Renewing(state) => { + let now = cx.now(); + if state.expires_at <= now { + net_debug!("DHCP lease expired"); + self.reset(); + // return Ok so we get polled again + return Ok(()); + } + + if now < state.renew_at || state.rebinding && now < state.rebind_at { + return Ok(()); + } + + state.rebinding |= now >= state.rebind_at; + + ipv4_repr.src_addr = state.config.address.address(); + // Renewing is unicast to the original server, rebinding is broadcast + if !state.rebinding { + ipv4_repr.dst_addr = state.config.server.address; + } + dhcp_repr.message_type = DhcpMessageType::Request; + dhcp_repr.client_ip = state.config.address.address(); + + net_debug!("DHCP send renew to {}: {:?}", ipv4_repr.dst_addr, dhcp_repr); + ipv4_repr.payload_len = udp_repr.header_len() + dhcp_repr.buffer_len(); + emit(cx, (ipv4_repr, udp_repr, dhcp_repr))?; + + // In both RENEWING and REBINDING states, if the client receives no + // response to its DHCPREQUEST message, the client SHOULD wait one-half + // of the remaining time until T2 (in RENEWING state) and one-half of + // the remaining lease time (in REBINDING state), down to a minimum of + // 60 seconds, before retransmitting the DHCPREQUEST message. + if state.rebinding { + state.rebind_at = now + + self + .retry_config + .min_renew_timeout + .max((state.expires_at - now) / 2); + } else { + state.renew_at = now + + self + .retry_config + .min_renew_timeout + .max((state.rebind_at - now) / 2) + .min(state.rebind_at - now); + } + + self.transaction_id = next_transaction_id; + Ok(()) + } + } + } + + /// Reset state and restart discovery phase. + /// + /// Use this to speed up acquisition of an address in a new + /// network if a link was down and it is now back up. + pub fn reset(&mut self) { + net_trace!("DHCP reset"); + if let ClientState::Renewing(_) = &self.state { + self.config_changed(); + } + self.state = ClientState::Discovering(DiscoverState { + retry_at: Instant::from_millis(0), + }); + } + + /// Query the socket for configuration changes. + /// + /// The socket has an internal "configuration changed" flag. If + /// set, this function returns the configuration and resets the flag. + pub fn poll(&mut self) -> Option { + if !self.config_changed { + None + } else if let ClientState::Renewing(state) = &self.state { + self.config_changed = false; + Some(Event::Configured(Config { + server: state.config.server, + address: state.config.address, + router: state.config.router, + dns_servers: state.config.dns_servers.clone(), + packet: self + .receive_packet_buffer + .as_deref() + .map(DhcpPacket::new_unchecked), + })) + } else { + self.config_changed = false; + Some(Event::Deconfigured) + } + } + + /// This function _must_ be called when the configuration provided to the + /// interface, by this DHCP socket, changes. It will update the `config_changed` field + /// so that a subsequent call to `poll` will yield an event, and wake a possible waker. + pub(crate) fn config_changed(&mut self) { + self.config_changed = true; + #[cfg(feature = "async")] + self.waker.wake(); + } + + /// Register a waker. + /// + /// The waker is woken on state changes that might affect the return value + /// of `poll` method calls, which indicates a new state in the DHCP configuration + /// provided by this DHCP socket. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + #[cfg(feature = "async")] + pub fn register_waker(&mut self, waker: &Waker) { + self.waker.register(waker) + } +} + +#[cfg(test)] +mod test { + + use std::ops::{Deref, DerefMut}; + + use super::*; + use crate::wire::EthernetAddress; + + // =========================================================================================// + // Helper functions + + struct TestSocket { + socket: Socket<'static>, + cx: Context, + } + + impl Deref for TestSocket { + type Target = Socket<'static>; + fn deref(&self) -> &Self::Target { + &self.socket + } + } + + impl DerefMut for TestSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.socket + } + } + + fn send( + s: &mut TestSocket, + timestamp: Instant, + (ip_repr, udp_repr, dhcp_repr): (Ipv4Repr, UdpRepr, DhcpRepr), + ) { + s.cx.set_now(timestamp); + + net_trace!("send: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + + let mut payload = vec![0; dhcp_repr.buffer_len()]; + dhcp_repr + .emit(&mut DhcpPacket::new_unchecked(&mut payload)) + .unwrap(); + + s.socket.process(&mut s.cx, &ip_repr, &udp_repr, &payload) + } + + fn recv(s: &mut TestSocket, timestamp: Instant, reprs: &[(Ipv4Repr, UdpRepr, DhcpRepr)]) { + s.cx.set_now(timestamp); + + let mut i = 0; + + while s.socket.poll_at(&mut s.cx) <= PollAt::Time(timestamp) { + let _ = s + .socket + .dispatch(&mut s.cx, |_, (mut ip_repr, udp_repr, dhcp_repr)| { + assert_eq!(ip_repr.next_header, IpProtocol::Udp); + assert_eq!( + ip_repr.payload_len, + udp_repr.header_len() + dhcp_repr.buffer_len() + ); + + // We validated the payload len, change it to 0 to make equality testing easier + ip_repr.payload_len = 0; + + net_trace!("recv: {:?}", ip_repr); + net_trace!(" {:?}", udp_repr); + net_trace!(" {:?}", dhcp_repr); + + let got_repr = (ip_repr, udp_repr, dhcp_repr); + match reprs.get(i) { + Some(want_repr) => assert_eq!(want_repr, &got_repr), + None => panic!("Too many reprs emitted"), + } + i += 1; + Ok::<_, ()>(()) + }); + } + + assert_eq!(i, reprs.len()); + } + + macro_rules! send { + ($socket:ident, $repr:expr) => + (send!($socket, time 0, $repr)); + ($socket:ident, time $time:expr, $repr:expr) => + (send(&mut $socket, Instant::from_millis($time), $repr)); + } + + macro_rules! recv { + ($socket:ident, $reprs:expr) => ({ + recv!($socket, time 0, $reprs); + }); + ($socket:ident, time $time:expr, $reprs:expr) => ({ + recv(&mut $socket, Instant::from_millis($time), &$reprs); + }); + } + + // =========================================================================================// + // Constants + + const TXID: u32 = 0x12345678; + + const MY_IP: Ipv4Address = Ipv4Address([192, 168, 1, 42]); + const SERVER_IP: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + const DNS_IP_1: Ipv4Address = Ipv4Address([1, 1, 1, 1]); + const DNS_IP_2: Ipv4Address = Ipv4Address([1, 1, 1, 2]); + const DNS_IP_3: Ipv4Address = Ipv4Address([1, 1, 1, 3]); + const DNS_IPS: &[Ipv4Address] = &[DNS_IP_1, DNS_IP_2, DNS_IP_3]; + + const MASK_24: Ipv4Address = Ipv4Address([255, 255, 255, 0]); + + const MY_MAC: EthernetAddress = EthernetAddress([0x02, 0x02, 0x02, 0x02, 0x02, 0x02]); + + const IP_BROADCAST: Ipv4Repr = Ipv4Repr { + src_addr: Ipv4Address::UNSPECIFIED, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_BROADCAST_ADDRESSED: Ipv4Repr = Ipv4Repr { + src_addr: MY_IP, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_SERVER_BROADCAST: Ipv4Repr = Ipv4Repr { + src_addr: SERVER_IP, + dst_addr: Ipv4Address::BROADCAST, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_RECV: Ipv4Repr = Ipv4Repr { + src_addr: SERVER_IP, + dst_addr: MY_IP, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const IP_SEND: Ipv4Repr = Ipv4Repr { + src_addr: MY_IP, + dst_addr: SERVER_IP, + next_header: IpProtocol::Udp, + payload_len: 0, + hop_limit: 64, + }; + + const UDP_SEND: UdpRepr = UdpRepr { + src_port: DHCP_CLIENT_PORT, + dst_port: DHCP_SERVER_PORT, + }; + const UDP_RECV: UdpRepr = UdpRepr { + src_port: DHCP_SERVER_PORT, + dst_port: DHCP_CLIENT_PORT, + }; + + const DIFFERENT_CLIENT_PORT: u16 = 6800; + const DIFFERENT_SERVER_PORT: u16 = 6700; + + const UDP_SEND_DIFFERENT_PORT: UdpRepr = UdpRepr { + src_port: DIFFERENT_CLIENT_PORT, + dst_port: DIFFERENT_SERVER_PORT, + }; + const UDP_RECV_DIFFERENT_PORT: UdpRepr = UdpRepr { + src_port: DIFFERENT_SERVER_PORT, + dst_port: DIFFERENT_CLIENT_PORT, + }; + + const DHCP_DEFAULT: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Unknown(99), + transaction_id: TXID, + secs: 0, + client_hardware_address: MY_MAC, + client_ip: Ipv4Address::UNSPECIFIED, + your_ip: Ipv4Address::UNSPECIFIED, + server_ip: Ipv4Address::UNSPECIFIED, + router: None, + subnet_mask: None, + relay_agent_ip: Ipv4Address::UNSPECIFIED, + broadcast: false, + requested_ip: None, + client_identifier: None, + server_identifier: None, + parameter_request_list: None, + dns_servers: None, + max_size: None, + renew_duration: None, + rebind_duration: None, + lease_duration: None, + additional_options: &[], + }; + + const DHCP_DISCOVER: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Discover, + client_identifier: Some(MY_MAC), + parameter_request_list: Some(&[1, 3, 6]), + max_size: Some(1432), + ..DHCP_DEFAULT + }; + + fn dhcp_offer() -> DhcpRepr<'static> { + DhcpRepr { + message_type: DhcpMessageType::Offer, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()), + lease_duration: Some(1000), + + ..DHCP_DEFAULT + } + } + + const DHCP_REQUEST: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + server_identifier: Some(SERVER_IP), + max_size: Some(1432), + + requested_ip: Some(MY_IP), + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + fn dhcp_ack() -> DhcpRepr<'static> { + DhcpRepr { + message_type: DhcpMessageType::Ack, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + + your_ip: MY_IP, + router: Some(SERVER_IP), + subnet_mask: Some(MASK_24), + dns_servers: Some(Vec::from_slice(DNS_IPS).unwrap()), + lease_duration: Some(1000), + + ..DHCP_DEFAULT + } + } + + const DHCP_NAK: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Nak, + server_ip: SERVER_IP, + server_identifier: Some(SERVER_IP), + ..DHCP_DEFAULT + }; + + const DHCP_RENEW: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + // NO server_identifier in renew requests, only in first one! + client_ip: MY_IP, + max_size: Some(1432), + + requested_ip: None, + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + const DHCP_REBIND: DhcpRepr = DhcpRepr { + message_type: DhcpMessageType::Request, + client_identifier: Some(MY_MAC), + // NO server_identifier in renew requests, only in first one! + client_ip: MY_IP, + max_size: Some(1432), + + requested_ip: None, + parameter_request_list: Some(&[1, 3, 6]), + ..DHCP_DEFAULT + }; + + // =========================================================================================// + // Tests + + fn socket() -> TestSocket { + let mut s = Socket::new(); + assert_eq!(s.poll(), Some(Event::Deconfigured)); + TestSocket { + socket: s, + cx: Context::mock(), + } + } + + fn socket_different_port() -> TestSocket { + let mut s = Socket::new(); + s.set_ports(DIFFERENT_SERVER_PORT, DIFFERENT_CLIENT_PORT); + + assert_eq!(s.poll(), Some(Event::Deconfigured)); + TestSocket { + socket: s, + cx: Context::mock(), + } + } + + fn socket_bound() -> TestSocket { + let mut s = socket(); + s.state = ClientState::Renewing(RenewState { + config: Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + }, + renew_at: Instant::from_secs(500), + rebind_at: Instant::from_secs(875), + rebinding: false, + expires_at: Instant::from_secs(1000), + }); + + s + } + + #[test] + fn test_bind() { + let mut s = socket(); + + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, dhcp_offer())); + assert_eq!(s.poll(), None); + recv!(s, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV, dhcp_ack())); + + assert_eq!( + s.poll(), + Some(Event::Configured(Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + })) + ); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(500)); + assert_eq!(r.rebind_at, Instant::from_secs(875)); + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_bind_different_ports() { + let mut s = socket_different_port(); + + recv!(s, [(IP_BROADCAST, UDP_SEND_DIFFERENT_PORT, DHCP_DISCOVER)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV_DIFFERENT_PORT, dhcp_offer())); + assert_eq!(s.poll(), None); + recv!(s, [(IP_BROADCAST, UDP_SEND_DIFFERENT_PORT, DHCP_REQUEST)]); + assert_eq!(s.poll(), None); + send!(s, (IP_RECV, UDP_RECV_DIFFERENT_PORT, dhcp_ack())); + + assert_eq!( + s.poll(), + Some(Event::Configured(Config { + server: ServerInfo { + address: SERVER_IP, + identifier: SERVER_IP, + }, + address: Ipv4Cidr::new(MY_IP, 24), + dns_servers: Vec::from_slice(DNS_IPS).unwrap(), + router: Some(SERVER_IP), + packet: None, + })) + ); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(500)); + assert_eq!(r.rebind_at, Instant::from_secs(875)); + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_discover_retransmit() { + let mut s = socket(); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + recv!(s, time 1_000, []); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + recv!(s, time 11_000, []); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + + // check after retransmits it still works + send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + } + + #[test] + fn test_request_retransmit() { + let mut s = socket(); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 1_000, []); + recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 6_000, []); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 15_000, []); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + + // check after retransmits it still works + send!(s, time 20_000, (IP_RECV, UDP_RECV, dhcp_ack())); + + match &s.state { + ClientState::Renewing(r) => { + assert_eq!(r.renew_at, Instant::from_secs(20 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(20 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_request_timeout() { + let mut s = socket(); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 5_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 10_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 20_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + recv!(s, time 30_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + + // After 5 tries and 70 seconds, it gives up. + // 5 + 5 + 10 + 10 + 20 = 70 + recv!(s, time 70_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + + // check it still works + send!(s, time 60_000, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 60_000, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + } + + #[test] + fn test_request_nak() { + let mut s = socket(); + + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + send!(s, time 0, (IP_RECV, UDP_RECV, dhcp_offer())); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_REQUEST)]); + send!(s, time 0, (IP_SERVER_BROADCAST, UDP_RECV, DHCP_NAK)); + recv!(s, time 0, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + } + + #[test] + fn test_renew() { + let mut s = socket_bound(); + + recv!(s, []); + assert_eq!(s.poll(), None); + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + assert_eq!(s.poll(), None); + + match &s.state { + ClientState::Renewing(r) => { + // the expiration still hasn't been bumped, because + // we haven't received the ACK yet + assert_eq!(r.expires_at, Instant::from_secs(1000)); + } + _ => panic!("Invalid state"), + } + + send!(s, time 500_000, (IP_RECV, UDP_RECV, dhcp_ack())); + assert_eq!(s.poll(), None); + + match &s.state { + ClientState::Renewing(r) => { + // NOW the expiration gets bumped + assert_eq!(r.renew_at, Instant::from_secs(500 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(500 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_renew_rebind_retransmit() { + let mut s = socket_bound(); + + recv!(s, []); + // First renew attempt at T1 + recv!(s, time 499_000, []); + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way to T2 + recv!(s, time 687_000, []); + recv!(s, time 687_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way again to T2 + recv!(s, time 781_000, []); + recv!(s, time 781_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 60s later (minimum interval) + recv!(s, time 841_000, []); + recv!(s, time 841_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // No more renews due to minimum interval + recv!(s, time 874_000, []); + // First rebind attempt + recv!(s, time 875_000, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // Next rebind attempt half way to expiry + recv!(s, time 937_000, []); + recv!(s, time 937_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // Next rebind attempt 60s later (minimum interval) + recv!(s, time 997_000, []); + recv!(s, time 997_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + + // check it still works + send!(s, time 999_000, (IP_RECV, UDP_RECV, dhcp_ack())); + match &s.state { + ClientState::Renewing(r) => { + // NOW the expiration gets bumped + assert_eq!(r.renew_at, Instant::from_secs(999 + 500)); + assert_eq!(r.expires_at, Instant::from_secs(999 + 1000)); + } + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_renew_rebind_timeout() { + let mut s = socket_bound(); + + recv!(s, []); + // First renew attempt at T1 + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way to T2 + recv!(s, time 687_500, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt at half way again to T2 + recv!(s, time 781_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // Next renew attempt 60s later (minimum interval) + recv!(s, time 841_250, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + // TODO uncomment below part of test + // // First rebind attempt + // recv!(s, time 875_000, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // // Next rebind attempt half way to expiry + // recv!(s, time 937_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // // Next rebind attempt 60s later (minimum interval) + // recv!(s, time 997_500, [(IP_BROADCAST_ADDRESSED, UDP_SEND, DHCP_REBIND)]); + // No more rebinds due to minimum interval + recv!(s, time 1_000_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + match &s.state { + ClientState::Discovering(_) => {} + _ => panic!("Invalid state"), + } + } + + #[test] + fn test_renew_nak() { + let mut s = socket_bound(); + + recv!(s, time 500_000, [(IP_SEND, UDP_SEND, DHCP_RENEW)]); + send!(s, time 500_000, (IP_SERVER_BROADCAST, UDP_RECV, DHCP_NAK)); + recv!(s, time 500_000, [(IP_BROADCAST, UDP_SEND, DHCP_DISCOVER)]); + } +} diff --git a/src/socket/dns.rs b/src/socket/dns.rs new file mode 100644 index 000000000..ca267b055 --- /dev/null +++ b/src/socket/dns.rs @@ -0,0 +1,699 @@ +#[cfg(feature = "async")] +use core::task::Waker; + +use heapless::Vec; +use managed::ManagedSlice; + +use crate::config::{DNS_MAX_NAME_SIZE, DNS_MAX_RESULT_COUNT, DNS_MAX_SERVER_COUNT}; +use crate::socket::{Context, PollAt}; +use crate::time::{Duration, Instant}; +use crate::wire::dns::{Flags, Opcode, Packet, Question, Rcode, Record, RecordData, Repr, Type}; +use crate::wire::{self, IpAddress, IpProtocol, IpRepr, UdpRepr}; + +#[cfg(feature = "async")] +use super::WakerRegistration; + +const DNS_PORT: u16 = 53; +const MDNS_DNS_PORT: u16 = 5353; +const RETRANSMIT_DELAY: Duration = Duration::from_millis(1_000); +const MAX_RETRANSMIT_DELAY: Duration = Duration::from_millis(10_000); +const RETRANSMIT_TIMEOUT: Duration = Duration::from_millis(10_000); // Should generally be 2-10 secs + +#[cfg(feature = "proto-ipv6")] +const MDNS_IPV6_ADDR: IpAddress = IpAddress::Ipv6(crate::wire::Ipv6Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfb, +])); + +#[cfg(feature = "proto-ipv4")] +const MDNS_IPV4_ADDR: IpAddress = IpAddress::Ipv4(crate::wire::Ipv4Address([224, 0, 0, 251])); + +/// Error returned by [`Socket::start_query`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum StartQueryError { + NoFreeSlot, + InvalidName, + NameTooLong, +} + +impl core::fmt::Display for StartQueryError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + StartQueryError::NoFreeSlot => write!(f, "No free slot"), + StartQueryError::InvalidName => write!(f, "Invalid name"), + StartQueryError::NameTooLong => write!(f, "Name too long"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for StartQueryError {} + +/// Error returned by [`Socket::get_query_result`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum GetQueryResultError { + /// Query is not done yet. + Pending, + /// Query failed. + Failed, +} + +impl core::fmt::Display for GetQueryResultError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + GetQueryResultError::Pending => write!(f, "Query is not done yet"), + GetQueryResultError::Failed => write!(f, "Query failed"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for GetQueryResultError {} + +/// State for an in-progress DNS query. +/// +/// The only reason this struct is public is to allow the socket state +/// to be allocated externally. +#[derive(Debug)] +pub struct DnsQuery { + state: State, + + #[cfg(feature = "async")] + waker: WakerRegistration, +} + +impl DnsQuery { + fn set_state(&mut self, state: State) { + self.state = state; + #[cfg(feature = "async")] + self.waker.wake(); + } +} + +#[derive(Debug)] +#[allow(clippy::large_enum_variant)] +enum State { + Pending(PendingQuery), + Completed(CompletedQuery), + Failure, +} + +#[derive(Debug)] +struct PendingQuery { + name: Vec, + type_: Type, + + port: u16, // UDP port (src for request, dst for response) + txid: u16, // transaction ID + + timeout_at: Option, + retransmit_at: Instant, + delay: Duration, + + server_idx: usize, + mdns: MulticastDns, +} + +#[derive(Debug)] +pub enum MulticastDns { + Disabled, + #[cfg(feature = "socket-mdns")] + Enabled, +} + +#[derive(Debug)] +struct CompletedQuery { + addresses: Vec, +} + +/// A handle to an in-progress DNS query. +#[derive(Clone, Copy)] +pub struct QueryHandle(usize); + +/// A Domain Name System socket. +/// +/// A UDP socket is bound to a specific endpoint, and owns transmit and receive +/// packet buffers. +#[derive(Debug)] +pub struct Socket<'a> { + servers: Vec, + queries: ManagedSlice<'a, Option>, + + /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + hop_limit: Option, +} + +impl<'a> Socket<'a> { + /// Create a DNS socket. + /// + /// # Panics + /// + /// Panics if `servers.len() > MAX_SERVER_COUNT` + pub fn new(servers: &[IpAddress], queries: Q) -> Socket<'a> + where + Q: Into>>, + { + Socket { + servers: Vec::from_slice(servers).unwrap(), + queries: queries.into(), + hop_limit: None, + } + } + + /// Update the list of DNS servers, will replace all existing servers + /// + /// # Panics + /// + /// Panics if `servers.len() > MAX_SERVER_COUNT` + pub fn update_servers(&mut self, servers: &[IpAddress]) { + self.servers = Vec::from_slice(servers).unwrap(); + } + + /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// See also the [set_hop_limit](#method.set_hop_limit) method + pub fn hop_limit(&self) -> Option { + self.hop_limit + } + + /// Set the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. + /// + /// A socket without an explicitly set hop limit value uses the default [IANA recommended] + /// value (64). + /// + /// # Panics + /// + /// This function panics if a hop limit value of 0 is given. See [RFC 1122 § 3.2.1.7]. + /// + /// [IANA recommended]: https://www.iana.org/assignments/ip-parameters/ip-parameters.xhtml + /// [RFC 1122 § 3.2.1.7]: https://tools.ietf.org/html/rfc1122#section-3.2.1.7 + pub fn set_hop_limit(&mut self, hop_limit: Option) { + // A host MUST NOT send a datagram with a hop limit value of 0 + if let Some(0) = hop_limit { + panic!("the time-to-live value of a packet must not be zero") + } + + self.hop_limit = hop_limit + } + + fn find_free_query(&mut self) -> Option { + for (i, q) in self.queries.iter().enumerate() { + if q.is_none() { + return Some(QueryHandle(i)); + } + } + + match &mut self.queries { + ManagedSlice::Borrowed(_) => None, + #[cfg(feature = "alloc")] + ManagedSlice::Owned(queries) => { + queries.push(None); + let index = queries.len() - 1; + Some(QueryHandle(index)) + } + } + } + + /// Start a query. + /// + /// `name` is specified in human-friendly format, such as `"rust-lang.org"`. + /// It accepts names both with and without trailing dot, and they're treated + /// the same (there's no support for DNS search path). + pub fn start_query( + &mut self, + cx: &mut Context, + name: &str, + query_type: Type, + ) -> Result { + let mut name = name.as_bytes(); + + if name.is_empty() { + net_trace!("invalid name: zero length"); + return Err(StartQueryError::InvalidName); + } + + // Remove trailing dot, if any + if name[name.len() - 1] == b'.' { + name = &name[..name.len() - 1]; + } + + let mut raw_name: Vec = Vec::new(); + + let mut mdns = MulticastDns::Disabled; + #[cfg(feature = "socket-mdns")] + if name.split(|&c| c == b'.').last().unwrap() == b"local" { + net_trace!("Starting a mDNS query"); + mdns = MulticastDns::Enabled; + } + + for s in name.split(|&c| c == b'.') { + if s.len() > 63 { + net_trace!("invalid name: too long label"); + return Err(StartQueryError::InvalidName); + } + if s.is_empty() { + net_trace!("invalid name: zero length label"); + return Err(StartQueryError::InvalidName); + } + + // Push label + raw_name + .push(s.len() as u8) + .map_err(|_| StartQueryError::NameTooLong)?; + raw_name + .extend_from_slice(s) + .map_err(|_| StartQueryError::NameTooLong)?; + } + + // Push terminator. + raw_name + .push(0x00) + .map_err(|_| StartQueryError::NameTooLong)?; + + self.start_query_raw(cx, &raw_name, query_type, mdns) + } + + /// Start a query with a raw (wire-format) DNS name. + /// `b"\x09rust-lang\x03org\x00"` + /// + /// You probably want to use [`start_query`] instead. + pub fn start_query_raw( + &mut self, + cx: &mut Context, + raw_name: &[u8], + query_type: Type, + mdns: MulticastDns, + ) -> Result { + let handle = self.find_free_query().ok_or(StartQueryError::NoFreeSlot)?; + + self.queries[handle.0] = Some(DnsQuery { + state: State::Pending(PendingQuery { + name: Vec::from_slice(raw_name).map_err(|_| StartQueryError::NameTooLong)?, + type_: query_type, + txid: cx.rand().rand_u16(), + port: cx.rand().rand_source_port(), + delay: RETRANSMIT_DELAY, + timeout_at: None, + retransmit_at: Instant::ZERO, + server_idx: 0, + mdns, + }), + #[cfg(feature = "async")] + waker: WakerRegistration::new(), + }); + Ok(handle) + } + + /// Get the result of a query. + /// + /// If the query is completed, the query slot is automatically freed. + /// + /// # Panics + /// Panics if the QueryHandle corresponds to a free slot. + pub fn get_query_result( + &mut self, + handle: QueryHandle, + ) -> Result, GetQueryResultError> { + let slot = &mut self.queries[handle.0]; + let q = slot.as_mut().unwrap(); + match &mut q.state { + // Query is not done yet. + State::Pending(_) => Err(GetQueryResultError::Pending), + // Query is done + State::Completed(q) => { + let res = q.addresses.clone(); + *slot = None; // Free up the slot for recycling. + Ok(res) + } + State::Failure => { + *slot = None; // Free up the slot for recycling. + Err(GetQueryResultError::Failed) + } + } + } + + /// Cancels a query, freeing the slot. + /// + /// # Panics + /// + /// Panics if the QueryHandle corresponds to an already free slot. + pub fn cancel_query(&mut self, handle: QueryHandle) { + let slot = &mut self.queries[handle.0]; + if slot.is_none() { + panic!("Canceling query in a free slot.") + } + *slot = None; // Free up the slot for recycling. + } + + /// Assign a waker to a query slot + /// + /// The waker will be woken when the query completes, either successfully or failed. + /// + /// # Panics + /// + /// Panics if the QueryHandle corresponds to an already free slot. + #[cfg(feature = "async")] + pub fn register_query_waker(&mut self, handle: QueryHandle, waker: &Waker) { + self.queries[handle.0] + .as_mut() + .unwrap() + .waker + .register(waker); + } + + pub(crate) fn accepts(&self, ip_repr: &IpRepr, udp_repr: &UdpRepr) -> bool { + (udp_repr.src_port == DNS_PORT + && self + .servers + .iter() + .any(|server| *server == ip_repr.src_addr())) + || (udp_repr.src_port == MDNS_DNS_PORT) + } + + pub(crate) fn process( + &mut self, + _cx: &mut Context, + ip_repr: &IpRepr, + udp_repr: &UdpRepr, + payload: &[u8], + ) { + debug_assert!(self.accepts(ip_repr, udp_repr)); + + let size = payload.len(); + + net_trace!( + "receiving {} octets from {:?}:{}", + size, + ip_repr.src_addr(), + udp_repr.dst_port + ); + + let p = match Packet::new_checked(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("dns packet malformed"); + return; + } + }; + if p.opcode() != Opcode::Query { + net_trace!("unwanted opcode {:?}", p.opcode()); + return; + } + + if !p.flags().contains(Flags::RESPONSE) { + net_trace!("packet doesn't have response bit set"); + return; + } + + if p.question_count() != 1 { + net_trace!("bad question count {:?}", p.question_count()); + return; + } + + // Find pending query + for q in self.queries.iter_mut().flatten() { + if let State::Pending(pq) = &mut q.state { + if udp_repr.dst_port != pq.port || p.transaction_id() != pq.txid { + continue; + } + + if p.rcode() == Rcode::NXDomain { + net_trace!("rcode NXDomain"); + q.set_state(State::Failure); + continue; + } + + let payload = p.payload(); + let (mut payload, question) = match Question::parse(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("question malformed"); + return; + } + }; + + if question.type_ != pq.type_ { + net_trace!("question type mismatch"); + return; + } + + match eq_names(p.parse_name(question.name), p.parse_name(&pq.name)) { + Ok(true) => {} + Ok(false) => { + net_trace!("question name mismatch"); + return; + } + Err(_) => { + net_trace!("dns question name malformed"); + return; + } + } + + let mut addresses = Vec::new(); + + for _ in 0..p.answer_record_count() { + let (payload2, r) = match Record::parse(payload) { + Ok(x) => x, + Err(_) => { + net_trace!("dns answer record malformed"); + return; + } + }; + payload = payload2; + + match eq_names(p.parse_name(r.name), p.parse_name(&pq.name)) { + Ok(true) => {} + Ok(false) => { + net_trace!("answer name mismatch: {:?}", r); + continue; + } + Err(_) => { + net_trace!("dns answer record name malformed"); + return; + } + } + + match r.data { + #[cfg(feature = "proto-ipv4")] + RecordData::A(addr) => { + net_trace!("A: {:?}", addr); + if addresses.push(addr.into()).is_err() { + net_trace!("too many addresses in response, ignoring {:?}", addr); + } + } + #[cfg(feature = "proto-ipv6")] + RecordData::Aaaa(addr) => { + net_trace!("AAAA: {:?}", addr); + if addresses.push(addr.into()).is_err() { + net_trace!("too many addresses in response, ignoring {:?}", addr); + } + } + RecordData::Cname(name) => { + net_trace!("CNAME: {:?}", name); + + // When faced with a CNAME, recursive resolvers are supposed to + // resolve the CNAME and append the results for it. + // + // We update the query with the new name, so that we pick up the A/AAAA + // records for the CNAME when we parse them later. + // I believe it's mandatory the CNAME results MUST come *after* in the + // packet, so it's enough to do one linear pass over it. + if copy_name(&mut pq.name, p.parse_name(name)).is_err() { + net_trace!("dns answer cname malformed"); + return; + } + } + RecordData::Other(type_, data) => { + net_trace!("unknown: {:?} {:?}", type_, data) + } + } + } + + q.set_state(if addresses.is_empty() { + State::Failure + } else { + State::Completed(CompletedQuery { addresses }) + }); + + // If we get here, packet matched the current query, stop processing. + return; + } + } + + // If we get here, packet matched with no query. + net_trace!("no query matched"); + } + + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, + { + let hop_limit = self.hop_limit.unwrap_or(64); + + for q in self.queries.iter_mut().flatten() { + if let State::Pending(pq) = &mut q.state { + // As per RFC 6762 any DNS query ending in .local. MUST be sent as mdns + // so we internally overwrite the servers for any of those queries + // in this function. + let servers = match pq.mdns { + #[cfg(feature = "socket-mdns")] + MulticastDns::Enabled => &[ + #[cfg(feature = "proto-ipv6")] + MDNS_IPV6_ADDR, + #[cfg(feature = "proto-ipv4")] + MDNS_IPV4_ADDR, + ], + MulticastDns::Disabled => self.servers.as_slice(), + }; + + let timeout = if let Some(timeout) = pq.timeout_at { + timeout + } else { + let v = cx.now() + RETRANSMIT_TIMEOUT; + pq.timeout_at = Some(v); + v + }; + + // Check timeout + if timeout < cx.now() { + // DNS timeout + pq.timeout_at = Some(cx.now() + RETRANSMIT_TIMEOUT); + pq.retransmit_at = Instant::ZERO; + pq.delay = RETRANSMIT_DELAY; + + // Try next server. We check below whether we've tried all servers. + pq.server_idx += 1; + } + // Check if we've run out of servers to try. + if pq.server_idx >= servers.len() { + net_trace!("already tried all servers."); + q.set_state(State::Failure); + continue; + } + + // Check so the IP address is valid + if servers[pq.server_idx].is_unspecified() { + net_trace!("invalid unspecified DNS server addr."); + q.set_state(State::Failure); + continue; + } + + if pq.retransmit_at > cx.now() { + // query is waiting for retransmit + continue; + } + + let repr = Repr { + transaction_id: pq.txid, + flags: Flags::RECURSION_DESIRED, + opcode: Opcode::Query, + question: Question { + name: &pq.name, + type_: pq.type_, + }, + }; + + let mut payload = [0u8; 512]; + let payload = &mut payload[..repr.buffer_len()]; + repr.emit(&mut Packet::new_unchecked(payload)); + + let dst_port = match pq.mdns { + #[cfg(feature = "socket-mdns")] + MulticastDns::Enabled => MDNS_DNS_PORT, + MulticastDns::Disabled => DNS_PORT, + }; + + let udp_repr = UdpRepr { + src_port: pq.port, + dst_port, + }; + + let dst_addr = servers[pq.server_idx]; + let src_addr = cx.get_source_address(dst_addr).unwrap(); // TODO remove unwrap + let ip_repr = IpRepr::new( + src_addr, + dst_addr, + IpProtocol::Udp, + udp_repr.header_len() + payload.len(), + hop_limit, + ); + + net_trace!( + "sending {} octets to {} from port {}", + payload.len(), + ip_repr.dst_addr(), + udp_repr.src_port + ); + + emit(cx, (ip_repr, udp_repr, payload))?; + + pq.retransmit_at = cx.now() + pq.delay; + pq.delay = MAX_RETRANSMIT_DELAY.min(pq.delay * 2); + + return Ok(()); + } + } + + // Nothing to dispatch + Ok(()) + } + + pub(crate) fn poll_at(&self, _cx: &Context) -> PollAt { + self.queries + .iter() + .flatten() + .filter_map(|q| match &q.state { + State::Pending(pq) => Some(PollAt::Time(pq.retransmit_at)), + State::Completed(_) => None, + State::Failure => None, + }) + .min() + .unwrap_or(PollAt::Ingress) + } +} + +fn eq_names<'a>( + mut a: impl Iterator>, + mut b: impl Iterator>, +) -> wire::Result { + loop { + match (a.next(), b.next()) { + // Handle errors + (Some(Err(e)), _) => return Err(e), + (_, Some(Err(e))) => return Err(e), + + // Both finished -> equal + (None, None) => return Ok(true), + + // One finished before the other -> not equal + (None, _) => return Ok(false), + (_, None) => return Ok(false), + + // Got two labels, check if they're equal + (Some(Ok(la)), Some(Ok(lb))) => { + if la != lb { + return Ok(false); + } + } + } + } +} + +fn copy_name<'a, const N: usize>( + dest: &mut Vec, + name: impl Iterator>, +) -> Result<(), wire::Error> { + dest.truncate(0); + + for label in name { + let label = label?; + dest.push(label.len() as u8).map_err(|_| wire::Error)?; + dest.extend_from_slice(label).map_err(|_| wire::Error)?; + } + + // Write terminator 0x00 + dest.push(0).map_err(|_| wire::Error)?; + + Ok(()) +} diff --git a/src/socket/icmp.rs b/src/socket/icmp.rs index b06daaf14..b6867585e 100644 --- a/src/socket/icmp.rs +++ b/src/socket/icmp.rs @@ -1,27 +1,90 @@ use core::cmp; +#[cfg(feature = "async")] +use core::task::Waker; -use {Error, Result}; -use phy::{ChecksumCapabilities, DeviceCapabilities}; -use socket::{Socket, SocketMeta, SocketHandle, PollAt}; -use storage::{PacketBuffer, PacketMetadata}; -use wire::{IpAddress, IpEndpoint, IpProtocol, IpRepr}; +use crate::phy::ChecksumCapabilities; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::socket::{Context, PollAt}; +use crate::storage::Empty; +use crate::wire::IcmpRepr; #[cfg(feature = "proto-ipv4")] -use wire::{Ipv4Address, Ipv4Repr, Icmpv4Packet, Icmpv4Repr}; +use crate::wire::{Icmpv4Packet, Icmpv4Repr, Ipv4Repr}; #[cfg(feature = "proto-ipv6")] -use wire::{Ipv6Address, Ipv6Repr, Icmpv6Packet, Icmpv6Repr}; -use wire::IcmpRepr; -use wire::{UdpPacket, UdpRepr}; +use crate::wire::{Icmpv6Packet, Icmpv6Repr, Ipv6Repr}; +use crate::wire::{IpAddress, IpListenEndpoint, IpProtocol, IpRepr}; +use crate::wire::{UdpPacket, UdpRepr}; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + Unaddressable, + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + SendError::Unaddressable => write!(f, "unaddressable"), + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} /// Type of endpoint to bind the ICMP socket to. See [IcmpSocket::bind] for /// more details. /// /// [IcmpSocket::bind]: struct.IcmpSocket.html#method.bind -#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[derive(Debug, Default, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Endpoint { + #[default] Unspecified, Ident(u16), - Udp(IpEndpoint) + Udp(IpListenEndpoint), } impl Endpoint { @@ -29,58 +92,88 @@ impl Endpoint { match *self { Endpoint::Ident(_) => true, Endpoint::Udp(endpoint) => endpoint.port != 0, - Endpoint::Unspecified => false + Endpoint::Unspecified => false, } } } -impl Default for Endpoint { - fn default() -> Endpoint { Endpoint::Unspecified } -} - /// An ICMP packet metadata. -pub type IcmpPacketMetadata = PacketMetadata; +pub type PacketMetadata = crate::storage::PacketMetadata; /// An ICMP packet ring buffer. -pub type IcmpSocketBuffer<'a, 'b> = PacketBuffer<'a, 'b, IpAddress>; +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, IpAddress>; /// A ICMP socket /// /// An ICMP socket is bound to a specific [IcmpEndpoint] which may -/// be a sepecific UDP port to listen for ICMP error messages related +/// be a specific UDP port to listen for ICMP error messages related /// to the port or a specific ICMP identifier value. See [bind] for /// more details. /// /// [IcmpEndpoint]: enum.IcmpEndpoint.html /// [bind]: #method.bind #[derive(Debug)] -pub struct IcmpSocket<'a, 'b: 'a> { - pub(crate) meta: SocketMeta, - rx_buffer: IcmpSocketBuffer<'a, 'b>, - tx_buffer: IcmpSocketBuffer<'a, 'b>, +pub struct Socket<'a> { + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, /// The endpoint this socket is communicating with - endpoint: Endpoint, + endpoint: Endpoint, /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. - hop_limit: Option + hop_limit: Option, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, } -impl<'a, 'b> IcmpSocket<'a, 'b> { +impl<'a> Socket<'a> { /// Create an ICMP socket with the given buffers. - pub fn new(rx_buffer: IcmpSocketBuffer<'a, 'b>, - tx_buffer: IcmpSocketBuffer<'a, 'b>) -> IcmpSocket<'a, 'b> { - IcmpSocket { - meta: SocketMeta::default(), + pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { + Socket { rx_buffer: rx_buffer, tx_buffer: tx_buffer, - endpoint: Endpoint::default(), - hop_limit: None + endpoint: Default::default(), + hop_limit: None, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), } } - /// Return the socket handle. - #[inline] - pub fn handle(&self) -> SocketHandle { - self.meta.handle + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) } /// Return the time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. @@ -126,18 +219,17 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// diagnose connection problems. /// /// ``` - /// # use smoltcp::socket::{Socket, IcmpSocket, IcmpSocketBuffer, IcmpPacketMetadata}; - /// # let rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]); - /// # let tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]); - /// use smoltcp::wire::IpEndpoint; - /// use smoltcp::socket::IcmpEndpoint; + /// use smoltcp::wire::IpListenEndpoint; + /// use smoltcp::socket::icmp; + /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); /// /// let mut icmp_socket = // ... - /// # IcmpSocket::new(rx_buffer, tx_buffer); + /// # icmp::Socket::new(rx_buffer, tx_buffer); /// /// // Bind to ICMP error responses for UDP packets sent from port 53. - /// let endpoint = IpEndpoint::from(53); - /// icmp_socket.bind(IcmpEndpoint::Udp(endpoint)).unwrap(); + /// let endpoint = IpListenEndpoint::from(53); + /// icmp_socket.bind(icmp::Endpoint::Udp(endpoint)).unwrap(); /// ``` /// /// ## Bind to a specific ICMP identifier: @@ -148,16 +240,16 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// messages. /// /// ``` - /// # use smoltcp::socket::{Socket, IcmpSocket, IcmpSocketBuffer, IcmpPacketMetadata}; - /// # let rx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]); - /// # let tx_buffer = IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY], vec![0; 20]); - /// use smoltcp::socket::IcmpEndpoint; + /// use smoltcp::wire::IpListenEndpoint; + /// use smoltcp::socket::icmp; + /// # let rx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); + /// # let tx_buffer = icmp::PacketBuffer::new(vec![icmp::PacketMetadata::EMPTY], vec![0; 20]); /// /// let mut icmp_socket = // ... - /// # IcmpSocket::new(rx_buffer, tx_buffer); + /// # icmp::Socket::new(rx_buffer, tx_buffer); /// /// // Bind to ICMP messages with the ICMP identifier 0x1234 - /// icmp_socket.bind(IcmpEndpoint::Ident(0x1234)).unwrap(); + /// icmp_socket.bind(icmp::Endpoint::Ident(0x1234)).unwrap(); /// ``` /// /// [is_specified]: enum.IcmpEndpoint.html#method.is_specified @@ -165,15 +257,24 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// [IcmpEndpoint::Udp]: enum.IcmpEndpoint.html#variant.Udp /// [send]: #method.send /// [recv]: #method.recv - pub fn bind>(&mut self, endpoint: T) -> Result<()> { + pub fn bind>(&mut self, endpoint: T) -> Result<(), BindError> { let endpoint = endpoint.into(); if !endpoint.is_specified() { - return Err(Error::Unaddressable); + return Err(BindError::Unaddressable); } - if self.is_open() { return Err(Error::Illegal) } + if self.is_open() { + return Err(BindError::InvalidState); + } self.endpoint = endpoint; + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + Ok(()) } @@ -225,22 +326,51 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// This function returns `Err(Error::Exhausted)` if the transmit buffer is full, /// `Err(Error::Truncated)` if the requested size is larger than the packet buffer /// size, and `Err(Error::Unaddressable)` if the remote address is unspecified. - pub fn send(&mut self, size: usize, endpoint: IpAddress) -> Result<&mut [u8]> { + pub fn send(&mut self, size: usize, endpoint: IpAddress) -> Result<&mut [u8], SendError> { if endpoint.is_unspecified() { - return Err(Error::Unaddressable) + return Err(SendError::Unaddressable); } - let packet_buf = self.tx_buffer.enqueue(size, endpoint)?; + let packet_buf = self + .tx_buffer + .enqueue(size, endpoint) + .map_err(|_| SendError::BufferFull)?; - net_trace!("{}:{}: buffer to send {} octets", - self.meta.handle, endpoint, size); + net_trace!("icmp:{}: buffer to send {} octets", endpoint, size); Ok(packet_buf) } + /// Enqueue a packet to be send to a given remote address and pass the buffer + /// to the provided closure. The closure then returns the size of the data written + /// into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with( + &mut self, + max_size: usize, + endpoint: IpAddress, + f: F, + ) -> Result + where + F: FnOnce(&mut [u8]) -> usize, + { + if endpoint.is_unspecified() { + return Err(SendError::Unaddressable); + } + + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, endpoint, f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!("icmp:{}: buffer to send {} octets", endpoint, size); + Ok(size) + } + /// Enqueue a packet to be sent to a given remote address, and fill it from a slice. /// /// See also [send](#method.send). - pub fn send_slice(&mut self, data: &[u8], endpoint: IpAddress) -> Result<()> { + pub fn send_slice(&mut self, data: &[u8], endpoint: IpAddress) -> Result<(), SendError> { let packet_buf = self.send(data.len(), endpoint)?; packet_buf.copy_from_slice(data); Ok(()) @@ -250,11 +380,14 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// as a pointer to the payload. /// /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. - pub fn recv(&mut self) -> Result<(&[u8], IpAddress)> { - let (endpoint, packet_buf) = self.rx_buffer.dequeue()?; - - net_trace!("{}:{}: receive {} buffered octets", - self.meta.handle, endpoint, packet_buf.len()); + pub fn recv(&mut self) -> Result<(&[u8], IpAddress), RecvError> { + let (endpoint, packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "icmp:{}: receive {} buffered octets", + endpoint, + packet_buf.len() + ); Ok((packet_buf, endpoint)) } @@ -262,7 +395,7 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// and return the amount of octets copied as well as the `IpAddress` /// /// See also [recv](#method.recv). - pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress)> { + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpAddress), RecvError> { let (buffer, endpoint) = self.recv()?; let length = cmp::min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); @@ -271,26 +404,46 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { /// Filter determining which packets received by the interface are appended to /// the given sockets received buffer. - pub(crate) fn accepts(&self, ip_repr: &IpRepr, icmp_repr: &IcmpRepr, - cksum: &ChecksumCapabilities) -> bool { + pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) -> bool { match (&self.endpoint, icmp_repr) { // If we are bound to ICMP errors associated to a UDP port, only - // accept Destination Unreachable messages with the data containing - // a UDP packet send from the local port we are bound to. + // accept Destination Unreachable or Time Exceeded messages with + // the data containing a UDP packet send from the local port we + // are bound to. #[cfg(feature = "proto-ipv4")] - (&Endpoint::Udp(endpoint), &IcmpRepr::Ipv4(Icmpv4Repr::DstUnreachable { data, .. })) - if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => { + ( + &Endpoint::Udp(endpoint), + &IcmpRepr::Ipv4( + Icmpv4Repr::DstUnreachable { data, header, .. } + | Icmpv4Repr::TimeExceeded { data, header, .. }, + ), + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => { let packet = UdpPacket::new_unchecked(data); - match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), cksum) { + match UdpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { Ok(repr) => endpoint.port == repr.src_port, Err(_) => false, } } #[cfg(feature = "proto-ipv6")] - (&Endpoint::Udp(endpoint), &IcmpRepr::Ipv6(Icmpv6Repr::DstUnreachable { data, .. })) - if endpoint.addr.is_unspecified() || endpoint.addr == ip_repr.dst_addr() => { + ( + &Endpoint::Udp(endpoint), + &IcmpRepr::Ipv6( + Icmpv6Repr::DstUnreachable { data, header, .. } + | Icmpv6Repr::TimeExceeded { data, header, .. }, + ), + ) if endpoint.addr.is_none() || endpoint.addr == Some(ip_repr.dst_addr()) => { let packet = UdpPacket::new_unchecked(data); - match UdpRepr::parse(&packet, &ip_repr.src_addr(), &ip_repr.dst_addr(), cksum) { + match UdpRepr::parse( + &packet, + &header.src_addr.into(), + &header.dst_addr.into(), + &cx.checksum_caps(), + ) { Ok(repr) => endpoint.port == repr.src_port, Err(_) => false, } @@ -299,91 +452,164 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { // Echo Request/Reply with the identifier field matching the endpoint // port. #[cfg(feature = "proto-ipv4")] - (&Endpoint::Ident(bound_ident), - &IcmpRepr::Ipv4(Icmpv4Repr::EchoRequest { ident, .. })) | - (&Endpoint::Ident(bound_ident), - &IcmpRepr::Ipv4(Icmpv4Repr::EchoReply { ident, .. })) => - ident == bound_ident, + ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv4(Icmpv4Repr::EchoRequest { ident, .. }), + ) + | ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv4(Icmpv4Repr::EchoReply { ident, .. }), + ) => ident == bound_ident, #[cfg(feature = "proto-ipv6")] - (&Endpoint::Ident(bound_ident), - &IcmpRepr::Ipv6(Icmpv6Repr::EchoRequest { ident, .. })) | - (&Endpoint::Ident(bound_ident), - &IcmpRepr::Ipv6(Icmpv6Repr::EchoReply { ident, .. })) => - ident == bound_ident, + ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv6(Icmpv6Repr::EchoRequest { ident, .. }), + ) + | ( + &Endpoint::Ident(bound_ident), + &IcmpRepr::Ipv6(Icmpv6Repr::EchoReply { ident, .. }), + ) => ident == bound_ident, _ => false, } } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, icmp_repr: &IcmpRepr, - _cksum: &ChecksumCapabilities) -> Result<()> { + pub(crate) fn process(&mut self, _cx: &mut Context, ip_repr: &IpRepr, icmp_repr: &IcmpRepr) { match icmp_repr { #[cfg(feature = "proto-ipv4")] - &IcmpRepr::Ipv4(ref icmp_repr) => { - let packet_buf = self.rx_buffer.enqueue(icmp_repr.buffer_len(), - ip_repr.src_addr())?; - icmp_repr.emit(&mut Icmpv4Packet::new_unchecked(packet_buf), - &ChecksumCapabilities::default()); - - net_trace!("{}:{}: receiving {} octets", - self.meta.handle, icmp_repr.buffer_len(), packet_buf.len()); - }, + IcmpRepr::Ipv4(icmp_repr) => { + net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len()); + + match self + .rx_buffer + .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr()) + { + Ok(packet_buf) => { + icmp_repr.emit( + &mut Icmpv4Packet::new_unchecked(packet_buf), + &ChecksumCapabilities::default(), + ); + } + Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"), + } + } #[cfg(feature = "proto-ipv6")] - &IcmpRepr::Ipv6(ref icmp_repr) => { - let packet_buf = self.rx_buffer.enqueue(icmp_repr.buffer_len(), - ip_repr.src_addr())?; - icmp_repr.emit(&ip_repr.src_addr(), &ip_repr.dst_addr(), - &mut Icmpv6Packet::new_unchecked(packet_buf), - &ChecksumCapabilities::default()); - - net_trace!("{}:{}: receiving {} octets", - self.meta.handle, icmp_repr.buffer_len(), packet_buf.len()); - }, + IcmpRepr::Ipv6(icmp_repr) => { + net_trace!("icmp: receiving {} octets", icmp_repr.buffer_len()); + + match self + .rx_buffer + .enqueue(icmp_repr.buffer_len(), ip_repr.src_addr()) + { + Ok(packet_buf) => icmp_repr.emit( + &ip_repr.src_addr(), + &ip_repr.dst_addr(), + &mut Icmpv6Packet::new_unchecked(packet_buf), + &ChecksumCapabilities::default(), + ), + Err(_) => net_trace!("icmp: buffer full, dropped incoming packet"), + } + } } - Ok(()) + + #[cfg(feature = "async")] + self.rx_waker.wake(); } - pub(crate) fn dispatch(&mut self, _caps: &DeviceCapabilities, emit: F) -> Result<()> - where F: FnOnce((IpRepr, IcmpRepr)) -> Result<()> + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, IcmpRepr)) -> Result<(), E>, { - let handle = self.meta.handle; let hop_limit = self.hop_limit.unwrap_or(64); - self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| { - net_trace!("{}:{}: sending {} octets", - handle, remote_endpoint, packet_buf.len()); + let res = self.tx_buffer.dequeue_with(|remote_endpoint, packet_buf| { + net_trace!( + "icmp:{}: sending {} octets", + remote_endpoint, + packet_buf.len() + ); match *remote_endpoint { #[cfg(feature = "proto-ipv4")] - IpAddress::Ipv4(ipv4_addr) => { + IpAddress::Ipv4(dst_addr) => { + let src_addr = match cx.get_source_address_ipv4(dst_addr) { + Some(addr) => addr, + None => { + net_trace!( + "icmp:{}: not find suitable source address, dropping", + remote_endpoint + ); + return Ok(()); + } + }; let packet = Icmpv4Packet::new_unchecked(&*packet_buf); - let repr = Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored())?; + let repr = match Icmpv4Repr::parse(&packet, &ChecksumCapabilities::ignored()) { + Ok(x) => x, + Err(_) => { + net_trace!( + "icmp:{}: malformed packet in queue, dropping", + remote_endpoint + ); + return Ok(()); + } + }; let ip_repr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::default(), - dst_addr: ipv4_addr, - protocol: IpProtocol::Icmp, + src_addr, + dst_addr, + next_header: IpProtocol::Icmp, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit: hop_limit, }); - emit((ip_repr, IcmpRepr::Ipv4(repr))) - }, + emit(cx, (ip_repr, IcmpRepr::Ipv4(repr))) + } #[cfg(feature = "proto-ipv6")] - IpAddress::Ipv6(ipv6_addr) => { + IpAddress::Ipv6(dst_addr) => { + let src_addr = match cx.get_source_address_ipv6(dst_addr) { + Some(addr) => addr, + None => { + net_trace!( + "icmp:{}: not find suitable source address, dropping", + remote_endpoint + ); + return Ok(()); + } + }; let packet = Icmpv6Packet::new_unchecked(&*packet_buf); - let src_addr = Ipv6Address::default(); - let repr = Icmpv6Repr::parse(&src_addr.into(), &ipv6_addr.into(), &packet, &ChecksumCapabilities::ignored())?; + let repr = match Icmpv6Repr::parse( + &src_addr.into(), + &dst_addr.into(), + &packet, + &ChecksumCapabilities::ignored(), + ) { + Ok(x) => x, + Err(_) => { + net_trace!( + "icmp:{}: malformed packet in queue, dropping", + remote_endpoint + ); + return Ok(()); + } + }; let ip_repr = IpRepr::Ipv6(Ipv6Repr { - src_addr: src_addr, - dst_addr: ipv6_addr, + src_addr, + dst_addr, next_header: IpProtocol::Icmpv6, payload_len: repr.buffer_len(), - hop_limit: hop_limit, + hop_limit: hop_limit, }); - emit((ip_repr, IcmpRepr::Ipv6(repr))) - }, - _ => Err(Error::Unaddressable) + emit(cx, (ip_repr, IcmpRepr::Ipv6(repr))) + } } - }) + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -392,109 +618,128 @@ impl<'a, 'b> IcmpSocket<'a, 'b> { } } -impl<'a, 'b> Into> for IcmpSocket<'a, 'b> { - fn into(self) -> Socket<'a, 'b> { - Socket::Icmp(self) - } -} - #[cfg(test)] mod tests_common { - pub use phy::DeviceCapabilities; - pub use wire::IpAddress; pub use super::*; + pub use crate::phy::DeviceCapabilities; + pub use crate::wire::IpAddress; - pub fn buffer(packets: usize) -> IcmpSocketBuffer<'static, 'static> { - IcmpSocketBuffer::new(vec![IcmpPacketMetadata::EMPTY; packets], vec![0; 66 * packets]) + pub fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 66 * packets]) } - pub fn socket(rx_buffer: IcmpSocketBuffer<'static, 'static>, - tx_buffer: IcmpSocketBuffer<'static, 'static>) -> IcmpSocket<'static, 'static> { - IcmpSocket::new(rx_buffer, tx_buffer) + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new(rx_buffer, tx_buffer) } - pub const LOCAL_PORT: u16 = 53; + pub const LOCAL_PORT: u16 = 53; pub static UDP_REPR: UdpRepr = UdpRepr { src_port: 53, dst_port: 9090, - payload: &[0xff; 10] }; + + pub static UDP_PAYLOAD: &[u8] = &[0xff; 10]; } #[cfg(all(test, feature = "proto-ipv4"))] mod test_ipv4 { use super::tests_common::*; + use crate::wire::{Icmpv4DstUnreachable, IpEndpoint, Ipv4Address}; - use wire::Icmpv4DstUnreachable; - - const REMOTE_IPV4: Ipv4Address = Ipv4Address([0x7f, 0x00, 0x00, 0x02]); - const LOCAL_IPV4: Ipv4Address = Ipv4Address([0x7f, 0x00, 0x00, 0x01]); - const LOCAL_END_V4: IpEndpoint = IpEndpoint { addr: IpAddress::Ipv4(LOCAL_IPV4), port: LOCAL_PORT }; + const REMOTE_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 2]); + const LOCAL_IPV4: Ipv4Address = Ipv4Address([192, 168, 1, 1]); + const LOCAL_END_V4: IpEndpoint = IpEndpoint { + addr: IpAddress::Ipv4(LOCAL_IPV4), + port: LOCAL_PORT, + }; static ECHOV4_REPR: Icmpv4Repr = Icmpv4Repr::EchoRequest { - ident: 0x1234, - seq_no: 0x5678, - data: &[0xff; 16] + ident: 0x1234, + seq_no: 0x5678, + data: &[0xff; 16], }; static LOCAL_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::UNSPECIFIED, + src_addr: LOCAL_IPV4, dst_addr: REMOTE_IPV4, - protocol: IpProtocol::Icmp, + next_header: IpProtocol::Icmp, payload_len: 24, - hop_limit: 0x40 + hop_limit: 0x40, }); static REMOTE_IPV4_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { src_addr: REMOTE_IPV4, dst_addr: LOCAL_IPV4, - protocol: IpProtocol::Icmp, + next_header: IpProtocol::Icmp, payload_len: 24, - hop_limit: 0x40 + hop_limit: 0x40, }); #[test] fn test_send_unaddressable() { let mut socket = socket(buffer(0), buffer(1)); - assert_eq!(socket.send_slice(b"abcdef", IpAddress::default()), - Err(Error::Unaddressable)); + assert_eq!( + socket.send_slice(b"abcdef", IpAddress::Ipv4(Ipv4Address::default())), + Err(SendError::Unaddressable) + ); assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV4.into()), Ok(())); } #[test] fn test_send_dispatch() { let mut socket = socket(buffer(0), buffer(1)); - let caps = DeviceCapabilities::default(); + let mut cx = Context::mock(); + let checksum = ChecksumCapabilities::default(); - assert_eq!(socket.dispatch(&caps, |_| unreachable!()), - Err(Error::Exhausted)); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); // This buffer is too long - assert_eq!(socket.send_slice(&[0xff; 67], REMOTE_IPV4.into()), Err(Error::Truncated)); + assert_eq!( + socket.send_slice(&[0xff; 67], REMOTE_IPV4.into()), + Err(SendError::BufferFull) + ); assert!(socket.can_send()); let mut bytes = [0xff; 24]; let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); - ECHOV4_REPR.emit(&mut packet, &caps.checksum); + ECHOV4_REPR.emit(&mut packet, &checksum); - assert_eq!(socket.send_slice(&packet.into_inner()[..], REMOTE_IPV4.into()), Ok(())); - assert_eq!(socket.send_slice(b"123456", REMOTE_IPV4.into()), Err(Error::Exhausted)); + assert_eq!( + socket.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()), + Ok(()) + ); + assert_eq!( + socket.send_slice(b"123456", REMOTE_IPV4.into()), + Err(SendError::BufferFull) + ); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&caps, |(ip_repr, icmp_repr)| { - assert_eq!(ip_repr, LOCAL_IPV4_REPR); - assert_eq!(icmp_repr, ECHOV4_REPR.into()); - Err(Error::Unaddressable) - }), Err(Error::Unaddressable)); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV4_REPR); + assert_eq!(icmp_repr, ECHOV4_REPR.into()); + Err(()) + }), + Err(()) + ); // buffer is not taken off of the tx queue due to the error assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&caps, |(ip_repr, icmp_repr)| { - assert_eq!(ip_repr, LOCAL_IPV4_REPR); - assert_eq!(icmp_repr, ECHOV4_REPR.into()); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV4_REPR); + assert_eq!(icmp_repr, ECHOV4_REPR.into()); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); // buffer is taken off of the queue this time assert!(socket.can_send()); } @@ -502,120 +747,140 @@ mod test_ipv4 { #[test] fn test_set_hop_limit_v4() { let mut s = socket(buffer(0), buffer(1)); - let caps = DeviceCapabilities::default(); + let mut cx = Context::mock(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 24]; let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); - ECHOV4_REPR.emit(&mut packet, &caps.checksum); + ECHOV4_REPR.emit(&mut packet, &checksum); s.set_hop_limit(Some(0x2a)); - assert_eq!(s.send_slice(&packet.into_inner()[..], REMOTE_IPV4.into()), Ok(())); - assert_eq!(s.dispatch(&caps, |(ip_repr, _)| { - assert_eq!(ip_repr, IpRepr::Ipv4(Ipv4Repr { - src_addr: Ipv4Address::UNSPECIFIED, - dst_addr: REMOTE_IPV4, - protocol: IpProtocol::Icmp, - payload_len: ECHOV4_REPR.buffer_len(), - hop_limit: 0x2a, - })); + assert_eq!( + s.send_slice(&*packet.into_inner(), REMOTE_IPV4.into()), + Ok(()) + ); + assert_eq!( + s.dispatch(&mut cx, |_, (ip_repr, _)| { + assert_eq!( + ip_repr, + IpRepr::Ipv4(Ipv4Repr { + src_addr: LOCAL_IPV4, + dst_addr: REMOTE_IPV4, + next_header: IpProtocol::Icmp, + payload_len: ECHOV4_REPR.buffer_len(), + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); } #[test] fn test_recv_process() { let mut socket = socket(buffer(1), buffer(1)); + let mut cx = Context::mock(); assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); assert!(!socket.can_recv()); - assert_eq!(socket.recv(), Err(Error::Exhausted)); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 24]; - let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); - ECHOV4_REPR.emit(&mut packet, &caps.checksum); - let data = &packet.into_inner()[..]; + let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); + ECHOV4_REPR.emit(&mut packet, &checksum); + let data = &*packet.into_inner(); - assert!(socket.accepts(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &caps.checksum)); - assert_eq!(socket.process(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &caps.checksum), - Ok(())); + assert!(socket.accepts(&mut cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + socket.process(&mut cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()); assert!(socket.can_recv()); - assert!(socket.accepts(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &caps.checksum)); - assert_eq!(socket.process(&REMOTE_IPV4_REPR, &ECHOV4_REPR.into(), &caps.checksum), - Err(Error::Exhausted)); + assert!(socket.accepts(&mut cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into())); + socket.process(&mut cx, &REMOTE_IPV4_REPR, &ECHOV4_REPR.into()); - assert_eq!(socket.recv(), Ok((&data[..], REMOTE_IPV4.into()))); + assert_eq!(socket.recv(), Ok((data, REMOTE_IPV4.into()))); assert!(!socket.can_recv()); } #[test] fn test_accept_bad_id() { let mut socket = socket(buffer(1), buffer(1)); + let mut cx = Context::mock(); assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 20]; let mut packet = Icmpv4Packet::new_unchecked(&mut bytes); let icmp_repr = Icmpv4Repr::EchoRequest { - ident: 0x4321, + ident: 0x4321, seq_no: 0x5678, - data: &[0xff; 16] + data: &[0xff; 16], }; - icmp_repr.emit(&mut packet, &caps.checksum); + icmp_repr.emit(&mut packet, &checksum); // Ensure that a packet with an identifier that isn't the bound // ID is not accepted - assert!(!socket.accepts(&REMOTE_IPV4_REPR, &icmp_repr.into(), &caps.checksum)); + assert!(!socket.accepts(&mut cx, &REMOTE_IPV4_REPR, &icmp_repr.into())); } #[test] fn test_accepts_udp() { let mut socket = socket(buffer(1), buffer(1)); - assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4)), Ok(())); + let mut cx = Context::mock(); + assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V4.into())), Ok(())); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 18]; let mut packet = UdpPacket::new_unchecked(&mut bytes); - UDP_REPR.emit(&mut packet, &REMOTE_IPV4.into(), &LOCAL_IPV4.into(), &caps.checksum); + UDP_REPR.emit( + &mut packet, + &REMOTE_IPV4.into(), + &LOCAL_IPV4.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(UDP_PAYLOAD), + &checksum, + ); - let data = &packet.into_inner()[..]; + let data = &*packet.into_inner(); let icmp_repr = Icmpv4Repr::DstUnreachable { reason: Icmpv4DstUnreachable::PortUnreachable, header: Ipv4Repr { src_addr: LOCAL_IPV4, dst_addr: REMOTE_IPV4, - protocol: IpProtocol::Icmp, + next_header: IpProtocol::Icmp, payload_len: 12, - hop_limit: 0x40 + hop_limit: 0x40, }, - data: data + data: data, }; - let ip_repr = IpRepr::Unspecified { - src_addr: REMOTE_IPV4.into(), - dst_addr: LOCAL_IPV4.into(), - protocol: IpProtocol::Icmp, + let ip_repr = IpRepr::Ipv4(Ipv4Repr { + src_addr: REMOTE_IPV4, + dst_addr: LOCAL_IPV4, + next_header: IpProtocol::Icmp, payload_len: icmp_repr.buffer_len(), - hop_limit: 0x40 - }; + hop_limit: 0x40, + }); assert!(!socket.can_recv()); // Ensure we can accept ICMP error response to the bound // UDP port - assert!(socket.accepts(&ip_repr, &icmp_repr.into(), &caps.checksum)); - assert_eq!(socket.process(&ip_repr, &icmp_repr.into(), &caps.checksum), - Ok(())); + assert!(socket.accepts(&mut cx, &ip_repr, &icmp_repr.into())); + socket.process(&mut cx, &ip_repr, &icmp_repr.into()); assert!(socket.can_recv()); let mut bytes = [0x00; 46]; let mut packet = Icmpv4Packet::new_unchecked(&mut bytes[..]); - icmp_repr.emit(&mut packet, &caps.checksum); - assert_eq!(socket.recv(), Ok((&packet.into_inner()[..], REMOTE_IPV4.into()))); + icmp_repr.emit(&mut packet, &checksum); + assert_eq!( + socket.recv(), + Ok((&*packet.into_inner(), REMOTE_IPV4.into())) + ); assert!(!socket.can_recv()); } } @@ -624,25 +889,28 @@ mod test_ipv4 { mod test_ipv6 { use super::tests_common::*; - use wire::Icmpv6DstUnreachable; + use crate::wire::{Icmpv6DstUnreachable, IpEndpoint, Ipv6Address}; - const REMOTE_IPV6: Ipv6Address = Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1]); - const LOCAL_IPV6: Ipv6Address = Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 2]); - const LOCAL_END_V6: IpEndpoint = IpEndpoint { addr: IpAddress::Ipv6(LOCAL_IPV6), port: LOCAL_PORT }; + const REMOTE_IPV6: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]); + const LOCAL_IPV6: Ipv6Address = + Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); + const LOCAL_END_V6: IpEndpoint = IpEndpoint { + addr: IpAddress::Ipv6(LOCAL_IPV6), + port: LOCAL_PORT, + }; static ECHOV6_REPR: Icmpv6Repr = Icmpv6Repr::EchoRequest { - ident: 0x1234, - seq_no: 0x5678, - data: &[0xff; 16] + ident: 0x1234, + seq_no: 0x5678, + data: &[0xff; 16], }; static LOCAL_IPV6_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::UNSPECIFIED, + src_addr: LOCAL_IPV6, dst_addr: REMOTE_IPV6, next_header: IpProtocol::Icmpv6, payload_len: 24, - hop_limit: 0x40 + hop_limit: 0x40, }); static REMOTE_IPV6_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { @@ -650,50 +918,75 @@ mod test_ipv6 { dst_addr: LOCAL_IPV6, next_header: IpProtocol::Icmpv6, payload_len: 24, - hop_limit: 0x40 + hop_limit: 0x40, }); #[test] fn test_send_unaddressable() { let mut socket = socket(buffer(0), buffer(1)); - assert_eq!(socket.send_slice(b"abcdef", IpAddress::default()), - Err(Error::Unaddressable)); + assert_eq!( + socket.send_slice(b"abcdef", IpAddress::Ipv6(Ipv6Address::default())), + Err(SendError::Unaddressable) + ); assert_eq!(socket.send_slice(b"abcdef", REMOTE_IPV6.into()), Ok(())); } #[test] fn test_send_dispatch() { let mut socket = socket(buffer(0), buffer(1)); - let caps = DeviceCapabilities::default(); + let mut cx = Context::mock(); + let checksum = ChecksumCapabilities::default(); - assert_eq!(socket.dispatch(&caps, |_| unreachable!()), - Err(Error::Exhausted)); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); // This buffer is too long - assert_eq!(socket.send_slice(&[0xff; 67], REMOTE_IPV6.into()), Err(Error::Truncated)); + assert_eq!( + socket.send_slice(&[0xff; 67], REMOTE_IPV6.into()), + Err(SendError::BufferFull) + ); assert!(socket.can_send()); let mut bytes = vec![0xff; 24]; let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); - ECHOV6_REPR.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &caps.checksum); - - assert_eq!(socket.send_slice(&packet.into_inner()[..], REMOTE_IPV6.into()), Ok(())); - assert_eq!(socket.send_slice(b"123456", REMOTE_IPV6.into()), Err(Error::Exhausted)); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + + assert_eq!( + socket.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()), + Ok(()) + ); + assert_eq!( + socket.send_slice(b"123456", REMOTE_IPV6.into()), + Err(SendError::BufferFull) + ); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&caps, |(ip_repr, icmp_repr)| { - assert_eq!(ip_repr, LOCAL_IPV6_REPR); - assert_eq!(icmp_repr, ECHOV6_REPR.into()); - Err(Error::Unaddressable) - }), Err(Error::Unaddressable)); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV6_REPR); + assert_eq!(icmp_repr, ECHOV6_REPR.into()); + Err(()) + }), + Err(()) + ); // buffer is not taken off of the tx queue due to the error assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&caps, |(ip_repr, icmp_repr)| { - assert_eq!(ip_repr, LOCAL_IPV6_REPR); - assert_eq!(icmp_repr, ECHOV6_REPR.into()); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, icmp_repr)| { + assert_eq!(ip_repr, LOCAL_IPV6_REPR); + assert_eq!(icmp_repr, ECHOV6_REPR.into()); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); // buffer is taken off of the queue this time assert!(socket.can_send()); } @@ -701,87 +994,120 @@ mod test_ipv6 { #[test] fn test_set_hop_limit() { let mut s = socket(buffer(0), buffer(1)); - let caps = DeviceCapabilities::default(); + let mut cx = Context::mock(); + let checksum = ChecksumCapabilities::default(); let mut bytes = vec![0xff; 24]; let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); - ECHOV6_REPR.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &caps.checksum); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); s.set_hop_limit(Some(0x2a)); - assert_eq!(s.send_slice(&packet.into_inner()[..], REMOTE_IPV6.into()), Ok(())); - assert_eq!(s.dispatch(&caps, |(ip_repr, _)| { - assert_eq!(ip_repr, IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address::UNSPECIFIED, - dst_addr: REMOTE_IPV6, - next_header: IpProtocol::Icmpv6, - payload_len: ECHOV6_REPR.buffer_len(), - hop_limit: 0x2a, - })); + assert_eq!( + s.send_slice(&*packet.into_inner(), REMOTE_IPV6.into()), + Ok(()) + ); + assert_eq!( + s.dispatch(&mut cx, |_, (ip_repr, _)| { + assert_eq!( + ip_repr, + IpRepr::Ipv6(Ipv6Repr { + src_addr: LOCAL_IPV6, + dst_addr: REMOTE_IPV6, + next_header: IpProtocol::Icmpv6, + payload_len: ECHOV6_REPR.buffer_len(), + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); } #[test] fn test_recv_process() { let mut socket = socket(buffer(1), buffer(1)); + let mut cx = Context::mock(); assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); assert!(!socket.can_recv()); - assert_eq!(socket.recv(), Err(Error::Exhausted)); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 24]; - let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); - ECHOV6_REPR.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &caps.checksum); - let data = &packet.into_inner()[..]; - - assert!(socket.accepts(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &caps.checksum)); - assert_eq!(socket.process(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &caps.checksum), - Ok(())); + let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); + ECHOV6_REPR.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + let data = &*packet.into_inner(); + + assert!(socket.accepts(&mut cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(&mut cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); assert!(socket.can_recv()); - assert!(socket.accepts(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &caps.checksum)); - assert_eq!(socket.process(&REMOTE_IPV6_REPR, &ECHOV6_REPR.into(), &caps.checksum), - Err(Error::Exhausted)); + assert!(socket.accepts(&mut cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into())); + socket.process(&mut cx, &REMOTE_IPV6_REPR, &ECHOV6_REPR.into()); - assert_eq!(socket.recv(), Ok((&data[..], REMOTE_IPV6.into()))); + assert_eq!(socket.recv(), Ok((data, REMOTE_IPV6.into()))); assert!(!socket.can_recv()); } #[test] fn test_accept_bad_id() { let mut socket = socket(buffer(1), buffer(1)); + let mut cx = Context::mock(); assert_eq!(socket.bind(Endpoint::Ident(0x1234)), Ok(())); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 20]; let mut packet = Icmpv6Packet::new_unchecked(&mut bytes); let icmp_repr = Icmpv6Repr::EchoRequest { - ident: 0x4321, + ident: 0x4321, seq_no: 0x5678, - data: &[0xff; 16] + data: &[0xff; 16], }; - icmp_repr.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &caps.checksum); + icmp_repr.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); // Ensure that a packet with an identifier that isn't the bound // ID is not accepted - assert!(!socket.accepts(&REMOTE_IPV6_REPR, &icmp_repr.into(), &caps.checksum)); + assert!(!socket.accepts(&mut cx, &REMOTE_IPV6_REPR, &icmp_repr.into())); } #[test] fn test_accepts_udp() { let mut socket = socket(buffer(1), buffer(1)); - assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6)), Ok(())); + let mut cx = Context::mock(); + assert_eq!(socket.bind(Endpoint::Udp(LOCAL_END_V6.into())), Ok(())); - let caps = DeviceCapabilities::default(); + let checksum = ChecksumCapabilities::default(); let mut bytes = [0xff; 18]; let mut packet = UdpPacket::new_unchecked(&mut bytes); - UDP_REPR.emit(&mut packet, &REMOTE_IPV6.into(), &LOCAL_IPV6.into(), &caps.checksum); + UDP_REPR.emit( + &mut packet, + &REMOTE_IPV6.into(), + &LOCAL_IPV6.into(), + UDP_PAYLOAD.len(), + |buf| buf.copy_from_slice(UDP_PAYLOAD), + &checksum, + ); - let data = &packet.into_inner()[..]; + let data = &*packet.into_inner(); let icmp_repr = Icmpv6Repr::DstUnreachable { reason: Icmpv6DstUnreachable::PortUnreachable, @@ -790,31 +1116,38 @@ mod test_ipv6 { dst_addr: REMOTE_IPV6, next_header: IpProtocol::Icmpv6, payload_len: 12, - hop_limit: 0x40 + hop_limit: 0x40, }, - data: data + data: data, }; - let ip_repr = IpRepr::Unspecified { - src_addr: REMOTE_IPV6.into(), - dst_addr: LOCAL_IPV6.into(), - protocol: IpProtocol::Icmpv6, + let ip_repr = IpRepr::Ipv6(Ipv6Repr { + src_addr: REMOTE_IPV6, + dst_addr: LOCAL_IPV6, + next_header: IpProtocol::Icmpv6, payload_len: icmp_repr.buffer_len(), - hop_limit: 0x40 - }; + hop_limit: 0x40, + }); assert!(!socket.can_recv()); // Ensure we can accept ICMP error response to the bound // UDP port - assert!(socket.accepts(&ip_repr, &icmp_repr.into(), &caps.checksum)); - assert_eq!(socket.process(&ip_repr, &icmp_repr.into(), &caps.checksum), - Ok(())); + assert!(socket.accepts(&mut cx, &ip_repr, &icmp_repr.into())); + socket.process(&mut cx, &ip_repr, &icmp_repr.into()); assert!(socket.can_recv()); let mut bytes = [0x00; 66]; let mut packet = Icmpv6Packet::new_unchecked(&mut bytes[..]); - icmp_repr.emit(&LOCAL_IPV6.into(), &REMOTE_IPV6.into(), &mut packet, &caps.checksum); - assert_eq!(socket.recv(), Ok((&packet.into_inner()[..], REMOTE_IPV6.into()))); + icmp_repr.emit( + &LOCAL_IPV6.into(), + &REMOTE_IPV6.into(), + &mut packet, + &checksum, + ); + assert_eq!( + socket.recv(), + Ok((&*packet.into_inner(), REMOTE_IPV6.into())) + ); assert!(!socket.can_recv()); } } diff --git a/src/socket/meta.rs b/src/socket/meta.rs deleted file mode 100644 index 5ec9d74ad..000000000 --- a/src/socket/meta.rs +++ /dev/null @@ -1,87 +0,0 @@ -use wire::IpAddress; -use super::{SocketHandle, PollAt}; -use time::{Duration, Instant}; - -/// Neighbor dependency. -/// -/// This enum tracks whether the socket should be polled based on the neighbor it is -/// going to send packets to. -#[derive(Debug)] -enum NeighborState { - /// Socket can be polled immediately. - Active, - /// Socket should not be polled until either `silent_until` passes or `neighbor` appears - /// in the neighbor cache. - Waiting { - neighbor: IpAddress, - silent_until: Instant, - } -} - -impl Default for NeighborState { - fn default() -> Self { - NeighborState::Active - } -} - -/// Network socket metadata. -/// -/// This includes things that only external (to the socket, that is) code -/// is interested in, but which are more conveniently stored inside the socket itself. -#[derive(Debug, Default)] -pub struct Meta { - /// Handle of this socket within its enclosing `SocketSet`. - /// Mainly useful for debug output. - pub(crate) handle: SocketHandle, - /// See [NeighborState](struct.NeighborState.html). - neighbor_state: NeighborState, -} - -impl Meta { - /// Minimum delay between neighbor discovery requests for this particular socket, - /// in milliseconds. - /// - /// See also `iface::NeighborCache::SILENT_TIME`. - pub(crate) const DISCOVERY_SILENT_TIME: Duration = Duration { millis: 3_000 }; - - pub(crate) fn poll_at(&self, socket_poll_at: PollAt, has_neighbor: F) -> PollAt - where F: Fn(IpAddress) -> bool - { - match self.neighbor_state { - NeighborState::Active => - socket_poll_at, - NeighborState::Waiting { neighbor, .. } - if has_neighbor(neighbor) => - socket_poll_at, - NeighborState::Waiting { silent_until, .. } => - PollAt::Time(silent_until) - } - } - - pub(crate) fn egress_permitted(&mut self, has_neighbor: F) -> bool - where F: Fn(IpAddress) -> bool - { - match self.neighbor_state { - NeighborState::Active => - true, - NeighborState::Waiting { neighbor, .. } => { - if has_neighbor(neighbor) { - net_trace!("{}: neighbor {} discovered, unsilencing", - self.handle, neighbor); - self.neighbor_state = NeighborState::Active; - true - } else { - false - } - } - } - } - - pub(crate) fn neighbor_missing(&mut self, timestamp: Instant, neighbor: IpAddress) { - net_trace!("{}: neighbor {} missing, silencing until t+{}", - self.handle, neighbor, Self::DISCOVERY_SILENT_TIME); - self.neighbor_state = NeighborState::Waiting { - neighbor, silent_until: timestamp + Self::DISCOVERY_SILENT_TIME - }; - } -} diff --git a/src/socket/mod.rs b/src/socket/mod.rs index d35724b82..7d48b4234 100644 --- a/src/socket/mod.rs +++ b/src/socket/mod.rs @@ -11,54 +11,33 @@ The interface implemented by this module uses explicit buffering: you decide on size for a buffer, allocate it, and let the networking stack use it. */ -use core::marker::PhantomData; -use time::Instant; - -mod meta; +use crate::iface::Context; +use crate::time::Instant; + +#[cfg(feature = "socket-dhcpv4")] +pub mod dhcpv4; +#[cfg(feature = "socket-dns")] +pub mod dns; +#[cfg(feature = "socket-icmp")] +pub mod icmp; #[cfg(feature = "socket-raw")] -mod raw; -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -mod icmp; -#[cfg(feature = "socket-udp")] -mod udp; +pub mod raw; #[cfg(feature = "socket-tcp")] -mod tcp; -mod set; -mod ref_; - -pub(crate) use self::meta::Meta as SocketMeta; - -#[cfg(feature = "socket-raw")] -pub use self::raw::{RawPacketMetadata, - RawSocketBuffer, - RawSocket}; - -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -pub use self::icmp::{IcmpPacketMetadata, - IcmpSocketBuffer, - Endpoint as IcmpEndpoint, - IcmpSocket}; - +pub mod tcp; #[cfg(feature = "socket-udp")] -pub use self::udp::{UdpPacketMetadata, - UdpSocketBuffer, - UdpSocket}; +pub mod udp; -#[cfg(feature = "socket-tcp")] -pub use self::tcp::{SocketBuffer as TcpSocketBuffer, - State as TcpState, - TcpSocket}; +#[cfg(feature = "async")] +mod waker; -pub use self::set::{Set as SocketSet, Item as SocketSetItem, Handle as SocketHandle}; -pub use self::set::{Iter as SocketSetIter, IterMut as SocketSetIterMut}; - -pub use self::ref_::Ref as SocketRef; -pub(crate) use self::ref_::Session as SocketSession; +#[cfg(feature = "async")] +pub(crate) use self::waker::WakerRegistration; /// Gives an indication on the next time the socket should be polled. #[derive(Debug, PartialOrd, Ord, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub(crate) enum PollAt { - /// The socket needs to be polled immidiately. + /// The socket needs to be polled immediately. Now, /// The socket needs to be polled at given [Instant][struct.Instant]. Time(Instant), @@ -66,113 +45,97 @@ pub(crate) enum PollAt { Ingress, } -impl PollAt { - #[cfg(feature = "socket-tcp")] - fn is_ingress(&self) -> bool { - match self { - &PollAt::Ingress => true, - _ => false, - } - } -} - /// A network socket. /// /// This enumeration abstracts the various types of sockets based on the IP protocol. /// To downcast a `Socket` value to a concrete socket, use the [AnySocket] trait, -/// e.g. to get `UdpSocket`, call `UdpSocket::downcast(socket)`. +/// e.g. to get `udp::Socket`, call `udp::Socket::downcast(socket)`. /// /// It is usually more convenient to use [SocketSet::get] instead. /// /// [AnySocket]: trait.AnySocket.html /// [SocketSet::get]: struct.SocketSet.html#method.get #[derive(Debug)] -pub enum Socket<'a, 'b: 'a> { +pub enum Socket<'a> { #[cfg(feature = "socket-raw")] - Raw(RawSocket<'a, 'b>), - #[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] - Icmp(IcmpSocket<'a, 'b>), + Raw(raw::Socket<'a>), + #[cfg(feature = "socket-icmp")] + Icmp(icmp::Socket<'a>), #[cfg(feature = "socket-udp")] - Udp(UdpSocket<'a, 'b>), + Udp(udp::Socket<'a>), #[cfg(feature = "socket-tcp")] - Tcp(TcpSocket<'a>), - #[doc(hidden)] - __Nonexhaustive(PhantomData<(&'a (), &'b ())>) + Tcp(tcp::Socket<'a>), + #[cfg(feature = "socket-dhcpv4")] + Dhcpv4(dhcpv4::Socket<'a>), + #[cfg(feature = "socket-dns")] + Dns(dns::Socket<'a>), } -macro_rules! dispatch_socket { - ($self_:expr, |$socket:ident| $code:expr) => { - dispatch_socket!(@inner $self_, |$socket| $code); - }; - (mut $self_:expr, |$socket:ident| $code:expr) => { - dispatch_socket!(@inner mut $self_, |$socket| $code); - }; - (@inner $( $mut_:ident )* $self_:expr, |$socket:ident| $code:expr) => { - match $self_ { +impl<'a> Socket<'a> { + pub(crate) fn poll_at(&self, cx: &mut Context) -> PollAt { + match self { #[cfg(feature = "socket-raw")] - &$( $mut_ )* Socket::Raw(ref $( $mut_ )* $socket) => $code, - #[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] - &$( $mut_ )* Socket::Icmp(ref $( $mut_ )* $socket) => $code, + Socket::Raw(s) => s.poll_at(cx), + #[cfg(feature = "socket-icmp")] + Socket::Icmp(s) => s.poll_at(cx), #[cfg(feature = "socket-udp")] - &$( $mut_ )* Socket::Udp(ref $( $mut_ )* $socket) => $code, + Socket::Udp(s) => s.poll_at(cx), #[cfg(feature = "socket-tcp")] - &$( $mut_ )* Socket::Tcp(ref $( $mut_ )* $socket) => $code, - &$( $mut_ )* Socket::__Nonexhaustive(_) => unreachable!() + Socket::Tcp(s) => s.poll_at(cx), + #[cfg(feature = "socket-dhcpv4")] + Socket::Dhcpv4(s) => s.poll_at(cx), + #[cfg(feature = "socket-dns")] + Socket::Dns(s) => s.poll_at(cx), } - }; -} - -impl<'a, 'b> Socket<'a, 'b> { - /// Return the socket handle. - #[inline] - pub fn handle(&self) -> SocketHandle { - self.meta().handle - } - - pub(crate) fn meta(&self) -> &SocketMeta { - dispatch_socket!(self, |socket| &socket.meta) - } - - pub(crate) fn meta_mut(&mut self) -> &mut SocketMeta { - dispatch_socket!(mut self, |socket| &mut socket.meta) - } - - pub(crate) fn poll_at(&self) -> PollAt { - dispatch_socket!(self, |socket| socket.poll_at()) - } -} - -impl<'a, 'b> SocketSession for Socket<'a, 'b> { - fn finish(&mut self) { - dispatch_socket!(mut self, |socket| socket.finish()) } } /// A conversion trait for network sockets. -pub trait AnySocket<'a, 'b>: SocketSession + Sized { - fn downcast<'c>(socket_ref: SocketRef<'c, Socket<'a, 'b>>) -> - Option>; +pub trait AnySocket<'a> { + fn upcast(self) -> Socket<'a>; + fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self> + where + Self: Sized; + fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> + where + Self: Sized; } macro_rules! from_socket { ($socket:ty, $variant:ident) => { - impl<'a, 'b> AnySocket<'a, 'b> for $socket { - fn downcast<'c>(ref_: SocketRef<'c, Socket<'a, 'b>>) -> - Option> { - match SocketRef::into_inner(ref_) { - &mut Socket::$variant(ref mut socket) => Some(SocketRef::new(socket)), + impl<'a> AnySocket<'a> for $socket { + fn upcast(self) -> Socket<'a> { + Socket::$variant(self) + } + + fn downcast<'c>(socket: &'c Socket<'a>) -> Option<&'c Self> { + #[allow(unreachable_patterns)] + match socket { + Socket::$variant(socket) => Some(socket), + _ => None, + } + } + + fn downcast_mut<'c>(socket: &'c mut Socket<'a>) -> Option<&'c mut Self> { + #[allow(unreachable_patterns)] + match socket { + Socket::$variant(socket) => Some(socket), _ => None, } } } - } + }; } #[cfg(feature = "socket-raw")] -from_socket!(RawSocket<'a, 'b>, Raw); -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -from_socket!(IcmpSocket<'a, 'b>, Icmp); +from_socket!(raw::Socket<'a>, Raw); +#[cfg(feature = "socket-icmp")] +from_socket!(icmp::Socket<'a>, Icmp); #[cfg(feature = "socket-udp")] -from_socket!(UdpSocket<'a, 'b>, Udp); +from_socket!(udp::Socket<'a>, Udp); #[cfg(feature = "socket-tcp")] -from_socket!(TcpSocket<'a>, Tcp); +from_socket!(tcp::Socket<'a>, Tcp); +#[cfg(feature = "socket-dhcpv4")] +from_socket!(dhcpv4::Socket<'a>, Dhcpv4); +#[cfg(feature = "socket-dns")] +from_socket!(dns::Socket<'a>, Dns); diff --git a/src/socket/raw.rs b/src/socket/raw.rs index 336076c7d..98282c9cd 100644 --- a/src/socket/raw.rs +++ b/src/socket/raw.rs @@ -1,53 +1,151 @@ use core::cmp::min; +#[cfg(feature = "async")] +use core::task::Waker; -use {Error, Result}; -use phy::ChecksumCapabilities; -use socket::{Socket, SocketMeta, SocketHandle, PollAt}; -use storage::{PacketBuffer, PacketMetadata}; -use wire::{IpVersion, IpRepr, IpProtocol}; +use crate::iface::Context; +use crate::socket::PollAt; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; + +use crate::storage::Empty; +use crate::wire::{IpProtocol, IpRepr, IpVersion}; #[cfg(feature = "proto-ipv4")] -use wire::{Ipv4Repr, Ipv4Packet}; +use crate::wire::{Ipv4Packet, Ipv4Repr}; #[cfg(feature = "proto-ipv6")] -use wire::{Ipv6Repr, Ipv6Packet}; +use crate::wire::{Ipv6Packet, Ipv6Repr}; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} /// A UDP packet metadata. -pub type RawPacketMetadata = PacketMetadata<()>; +pub type PacketMetadata = crate::storage::PacketMetadata<()>; /// A UDP packet ring buffer. -pub type RawSocketBuffer<'a, 'b> = PacketBuffer<'a, 'b, ()>; +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, ()>; /// A raw IP socket. /// /// A raw socket is bound to a specific IP protocol, and owns /// transmit and receive packet buffers. #[derive(Debug)] -pub struct RawSocket<'a, 'b: 'a> { - pub(crate) meta: SocketMeta, - ip_version: IpVersion, +pub struct Socket<'a> { + ip_version: IpVersion, ip_protocol: IpProtocol, - rx_buffer: RawSocketBuffer<'a, 'b>, - tx_buffer: RawSocketBuffer<'a, 'b>, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, } -impl<'a, 'b> RawSocket<'a, 'b> { +impl<'a> Socket<'a> { /// Create a raw IP socket bound to the given IP version and datagram protocol, /// with the given buffers. - pub fn new(ip_version: IpVersion, ip_protocol: IpProtocol, - rx_buffer: RawSocketBuffer<'a, 'b>, - tx_buffer: RawSocketBuffer<'a, 'b>) -> RawSocket<'a, 'b> { - RawSocket { - meta: SocketMeta::default(), + pub fn new( + ip_version: IpVersion, + ip_protocol: IpProtocol, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, + ) -> Socket<'a> { + Socket { ip_version, ip_protocol, rx_buffer, tx_buffer, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), } } - /// Return the socket handle. - #[inline] - pub fn handle(&self) -> SocketHandle { - self.meta.handle + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) } /// Return the IP version the socket is bound to. @@ -107,21 +205,50 @@ impl<'a, 'b> RawSocket<'a, 'b> { /// If the buffer is filled in a way that does not match the socket's /// IP version or protocol, the packet will be silently dropped. /// - /// **Note:** The IP header is parsed and reserialized, and may not match + /// **Note:** The IP header is parsed and re-serialized, and may not match /// the header actually transmitted bit for bit. - pub fn send(&mut self, size: usize) -> Result<&mut [u8]> { - let packet_buf = self.tx_buffer.enqueue(size, ())?; + pub fn send(&mut self, size: usize) -> Result<&mut [u8], SendError> { + let packet_buf = self + .tx_buffer + .enqueue(size, ()) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "raw:{}:{}: buffer to send {} octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); + Ok(packet_buf) + } - net_trace!("{}:{}:{}: buffer to send {} octets", - self.meta.handle, self.ip_version, self.ip_protocol, - packet_buf.len()); - Ok(packet_buf.as_mut()) + /// Enqueue a packet to be send and pass the buffer to the provided closure. + /// The closure then returns the size of the data written into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with(&mut self, max_size: usize, f: F) -> Result + where + F: FnOnce(&mut [u8]) -> usize, + { + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, (), f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "raw:{}:{}: buffer to send {} octets", + self.ip_version, + self.ip_protocol, + size + ); + + Ok(size) } /// Enqueue a packet to send, and fill it from a slice. /// /// See also [send](#method.send). - pub fn send_slice(&mut self, data: &[u8]) -> Result<()> { + pub fn send_slice(&mut self, data: &[u8]) -> Result<(), SendError> { self.send(data.len())?.copy_from_slice(data); Ok(()) } @@ -130,60 +257,122 @@ impl<'a, 'b> RawSocket<'a, 'b> { /// /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. /// - /// **Note:** The IP header is parsed and reserialized, and may not match + /// **Note:** The IP header is parsed and re-serialized, and may not match /// the header actually received bit for bit. - pub fn recv(&mut self) -> Result<&[u8]> { - let ((), packet_buf) = self.rx_buffer.dequeue()?; - - net_trace!("{}:{}:{}: receive {} buffered octets", - self.meta.handle, self.ip_version, self.ip_protocol, - packet_buf.len()); + pub fn recv(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "raw:{}:{}: receive {} buffered octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); Ok(packet_buf) } /// Dequeue a packet, and copy the payload into the given slice. /// /// See also [recv](#method.recv). - pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { let buffer = self.recv()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok(length) } + /// Peek at a packet in the receive buffer and return a pointer to the + /// payload without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv](#method.recv). + /// + /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. + pub fn peek(&mut self) -> Result<&[u8], RecvError> { + let ((), packet_buf) = self.rx_buffer.peek().map_err(|_| RecvError::Exhausted)?; + + net_trace!( + "raw:{}:{}: receive {} buffered octets", + self.ip_version, + self.ip_protocol, + packet_buf.len() + ); + + Ok(packet_buf) + } + + /// Peek at a packet in the receive buffer, copy the payload into the given slice, + /// and return the amount of octets copied without removing the packet from the receive buffer. + /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). + /// + /// See also [peek](#method.peek). + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result { + let buffer = self.peek()?; + let length = min(data.len(), buffer.len()); + data[..length].copy_from_slice(&buffer[..length]); + Ok(length) + } + pub(crate) fn accepts(&self, ip_repr: &IpRepr) -> bool { - if ip_repr.version() != self.ip_version { return false } - if ip_repr.protocol() != self.ip_protocol { return false } + if ip_repr.version() != self.ip_version { + return false; + } + if ip_repr.next_header() != self.ip_protocol { + return false; + } true } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, payload: &[u8], - checksum_caps: &ChecksumCapabilities) -> Result<()> { + pub(crate) fn process(&mut self, cx: &mut Context, ip_repr: &IpRepr, payload: &[u8]) { debug_assert!(self.accepts(ip_repr)); - let header_len = ip_repr.buffer_len(); - let total_len = header_len + payload.len(); - let packet_buf = self.rx_buffer.enqueue(total_len, ())?; - ip_repr.emit(&mut packet_buf.as_mut()[..header_len], &checksum_caps); - packet_buf.as_mut()[header_len..].copy_from_slice(payload); + let header_len = ip_repr.header_len(); + let total_len = header_len + payload.len(); - net_trace!("{}:{}:{}: receiving {} octets", - self.meta.handle, self.ip_version, self.ip_protocol, - packet_buf.len()); - Ok(()) + net_trace!( + "raw:{}:{}: receiving {} octets", + self.ip_version, + self.ip_protocol, + total_len + ); + + match self.rx_buffer.enqueue(total_len, ()) { + Ok(buf) => { + ip_repr.emit(&mut buf[..header_len], &cx.checksum_caps()); + buf[header_len..].copy_from_slice(payload); + } + Err(_) => net_trace!( + "raw:{}:{}: buffer full, dropped incoming packet", + self.ip_version, + self.ip_protocol + ), + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); } - pub(crate) fn dispatch(&mut self, checksum_caps: &ChecksumCapabilities, emit: F) -> - Result<()> - where F: FnOnce((IpRepr, &[u8])) -> Result<()> { - fn prepare<'a>(protocol: IpProtocol, buffer: &'a mut [u8], - _checksum_caps: &ChecksumCapabilities) -> Result<(IpRepr, &'a [u8])> { - match IpVersion::of_packet(buffer.as_ref())? { + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, &[u8])) -> Result<(), E>, + { + let ip_protocol = self.ip_protocol; + let ip_version = self.ip_version; + let _checksum_caps = &cx.checksum_caps(); + let res = self.tx_buffer.dequeue_with(|&mut (), buffer| { + match IpVersion::of_packet(buffer) { #[cfg(feature = "proto-ipv4")] - IpVersion::Ipv4 => { - let mut packet = Ipv4Packet::new_checked(buffer.as_mut())?; - if packet.protocol() != protocol { return Err(Error::Unaddressable) } + Ok(IpVersion::Ipv4) => { + let mut packet = match Ipv4Packet::new_checked(buffer) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + if packet.next_header() != ip_protocol { + net_trace!("raw: sent packet with wrong ip protocol, dropping."); + return Ok(()); + } if _checksum_caps.ipv4.tx() { packet.fill_checksum(); } else { @@ -192,46 +381,60 @@ impl<'a, 'b> RawSocket<'a, 'b> { packet.set_checksum(0); } - let packet = Ipv4Packet::new_checked(&*packet.into_inner())?; - let ipv4_repr = Ipv4Repr::parse(&packet, _checksum_caps)?; - Ok((IpRepr::Ipv4(ipv4_repr), packet.payload())) + let packet = Ipv4Packet::new_unchecked(&*packet.into_inner()); + let ipv4_repr = match Ipv4Repr::parse(&packet, _checksum_caps) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv4 packet in queue, dropping."); + return Ok(()); + } + }; + net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + emit(cx, (IpRepr::Ipv4(ipv4_repr), packet.payload())) } #[cfg(feature = "proto-ipv6")] - IpVersion::Ipv6 => { - let packet = Ipv6Packet::new_checked(buffer.as_mut())?; - if packet.next_header() != protocol { return Err(Error::Unaddressable) } + Ok(IpVersion::Ipv6) => { + let packet = match Ipv6Packet::new_checked(buffer) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + if packet.next_header() != ip_protocol { + net_trace!("raw: sent ipv6 packet with wrong ip protocol, dropping."); + return Ok(()); + } let packet = Ipv6Packet::new_unchecked(&*packet.into_inner()); - let ipv6_repr = Ipv6Repr::parse(&packet)?; - Ok((IpRepr::Ipv6(ipv6_repr), packet.payload())) - } - IpVersion::Unspecified => unreachable!(), - IpVersion::__Nonexhaustive => unreachable!() - } - } - - let handle = self.meta.handle; - let ip_protocol = self.ip_protocol; - let ip_version = self.ip_version; - self.tx_buffer.dequeue_with(|&mut (), packet_buf| { - match prepare(ip_protocol, packet_buf.as_mut(), &checksum_caps) { - Ok((ip_repr, raw_packet)) => { - net_trace!("{}:{}:{}: sending {} octets", - handle, ip_version, ip_protocol, - ip_repr.buffer_len() + raw_packet.len()); - emit((ip_repr, raw_packet)) + let ipv6_repr = match Ipv6Repr::parse(&packet) { + Ok(x) => x, + Err(_) => { + net_trace!("raw: malformed ipv6 packet in queue, dropping."); + return Ok(()); + } + }; + + net_trace!("raw:{}:{}: sending", ip_version, ip_protocol); + emit(cx, (IpRepr::Ipv6(ipv6_repr), packet.payload())) } - Err(error) => { - net_debug!("{}:{}:{}: dropping outgoing packet ({})", - handle, ip_version, ip_protocol, - error); - // Return Ok(()) so the packet is dequeued. + Err(_) => { + net_trace!("raw: sent packet with invalid IP version, dropping."); Ok(()) } } - }) + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -240,96 +443,89 @@ impl<'a, 'b> RawSocket<'a, 'b> { } } -impl<'a, 'b> Into> for RawSocket<'a, 'b> { - fn into(self) -> Socket<'a, 'b> { - Socket::Raw(self) - } -} - #[cfg(test)] mod test { - use wire::IpRepr; + use super::*; + use crate::wire::IpRepr; #[cfg(feature = "proto-ipv4")] - use wire::{Ipv4Address, Ipv4Repr}; + use crate::wire::{Ipv4Address, Ipv4Repr}; #[cfg(feature = "proto-ipv6")] - use wire::{Ipv6Address, Ipv6Repr}; - use super::*; + use crate::wire::{Ipv6Address, Ipv6Repr}; - fn buffer(packets: usize) -> RawSocketBuffer<'static, 'static> { - RawSocketBuffer::new(vec![RawPacketMetadata::EMPTY; packets], vec![0; 48 * packets]) + fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; packets], vec![0; 48 * packets]) } #[cfg(feature = "proto-ipv4")] mod ipv4_locals { use super::*; - pub fn socket(rx_buffer: RawSocketBuffer<'static, 'static>, - tx_buffer: RawSocketBuffer<'static, 'static>) - -> RawSocket<'static, 'static> { - RawSocket::new(IpVersion::Ipv4, IpProtocol::Unknown(IP_PROTO), - rx_buffer, tx_buffer) + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new( + IpVersion::Ipv4, + IpProtocol::Unknown(IP_PROTO), + rx_buffer, + tx_buffer, + ) } pub const IP_PROTO: u8 = 63; pub const HEADER_REPR: IpRepr = IpRepr::Ipv4(Ipv4Repr { src_addr: Ipv4Address([10, 0, 0, 1]), dst_addr: Ipv4Address([10, 0, 0, 2]), - protocol: IpProtocol::Unknown(IP_PROTO), + next_header: IpProtocol::Unknown(IP_PROTO), payload_len: 4, - hop_limit: 64 + hop_limit: 64, }); pub const PACKET_BYTES: [u8; 24] = [ - 0x45, 0x00, 0x00, 0x18, - 0x00, 0x00, 0x40, 0x00, - 0x40, 0x3f, 0x00, 0x00, - 0x0a, 0x00, 0x00, 0x01, - 0x0a, 0x00, 0x00, 0x02, - 0xaa, 0x00, 0x00, 0xff - ]; - pub const PACKET_PAYLOAD: [u8; 4] = [ - 0xaa, 0x00, 0x00, 0xff + 0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00, 0x40, 0x3f, 0x00, 0x00, 0x0a, 0x00, + 0x00, 0x01, 0x0a, 0x00, 0x00, 0x02, 0xaa, 0x00, 0x00, 0xff, ]; + pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; } #[cfg(feature = "proto-ipv6")] mod ipv6_locals { use super::*; - pub fn socket(rx_buffer: RawSocketBuffer<'static, 'static>, - tx_buffer: RawSocketBuffer<'static, 'static>) - -> RawSocket<'static, 'static> { - RawSocket::new(IpVersion::Ipv6, IpProtocol::Unknown(IP_PROTO), - rx_buffer, tx_buffer) + pub fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new( + IpVersion::Ipv6, + IpProtocol::Unknown(IP_PROTO), + rx_buffer, + tx_buffer, + ) } pub const IP_PROTO: u8 = 63; pub const HEADER_REPR: IpRepr = IpRepr::Ipv6(Ipv6Repr { - src_addr: Ipv6Address([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]), - dst_addr: Ipv6Address([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02]), + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ]), next_header: IpProtocol::Unknown(IP_PROTO), payload_len: 4, - hop_limit: 64 + hop_limit: 64, }); pub const PACKET_BYTES: [u8; 44] = [ - 0x60, 0x00, 0x00, 0x00, - 0x00, 0x04, 0x3f, 0x40, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02, - 0xaa, 0x00, 0x00, 0xff + 0x60, 0x00, 0x00, 0x00, 0x00, 0x04, 0x3f, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xaa, 0x00, + 0x00, 0xff, ]; - pub const PACKET_PAYLOAD: [u8; 4] = [ - 0xaa, 0x00, 0x00, 0xff - ]; + pub const PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; } macro_rules! reusable_ip_specific_tests { @@ -340,44 +536,52 @@ mod test { #[test] fn test_send_truncated() { let mut socket = $socket(buffer(0), buffer(1)); - assert_eq!(socket.send_slice(&[0; 56][..]), Err(Error::Truncated)); + assert_eq!(socket.send_slice(&[0; 56][..]), Err(SendError::BufferFull)); } #[test] fn test_send_dispatch() { - let checksum_caps = &ChecksumCapabilities::default(); let mut socket = $socket(buffer(0), buffer(1)); + let mut cx = Context::mock(); assert!(socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), - Err(Error::Exhausted)); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); assert_eq!(socket.send_slice(&$packet[..]), Ok(())); - assert_eq!(socket.send_slice(b""), Err(Error::Exhausted)); + assert_eq!(socket.send_slice(b""), Err(SendError::BufferFull)); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |(ip_repr, ip_payload)| { - assert_eq!(ip_repr, $hdr); - assert_eq!(ip_payload, &$payload); - Err(Error::Unaddressable) - }), Err(Error::Unaddressable)); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| { + assert_eq!(ip_repr, $hdr); + assert_eq!(ip_payload, &$payload); + Err(()) + }), + Err(()) + ); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(&checksum_caps, |(ip_repr, ip_payload)| { - assert_eq!(ip_repr, $hdr); - assert_eq!(ip_payload, &$payload); + assert_eq!( + socket.dispatch(&mut cx, |_, (ip_repr, ip_payload)| { + assert_eq!(ip_repr, $hdr); + assert_eq!(ip_payload, &$payload); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); assert!(socket.can_send()); } #[test] fn test_recv_truncated_slice() { let mut socket = $socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); assert!(socket.accepts(&$hdr)); - assert_eq!(socket.process(&$hdr, &$payload, - &ChecksumCapabilities::default()), Ok(())); + socket.process(&mut cx, &$hdr, &$payload); let mut slice = [0; 4]; assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); @@ -387,66 +591,100 @@ mod test { #[test] fn test_recv_truncated_packet() { let mut socket = $socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); let mut buffer = vec![0; 128]; buffer[..$packet.len()].copy_from_slice(&$packet[..]); assert!(socket.accepts(&$hdr)); - assert_eq!(socket.process(&$hdr, &buffer, &ChecksumCapabilities::default()), - Err(Error::Truncated)); + socket.process(&mut cx, &$hdr, &buffer); + } + + #[test] + fn test_peek_truncated_slice() { + let mut socket = $socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + + assert!(socket.accepts(&$hdr)); + socket.process(&mut cx, &$hdr, &$payload); + + let mut slice = [0; 4]; + assert_eq!(socket.peek_slice(&mut slice[..]), Ok(4)); + assert_eq!(&slice, &$packet[..slice.len()]); + assert_eq!(socket.recv_slice(&mut slice[..]), Ok(4)); + assert_eq!(&slice, &$packet[..slice.len()]); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); } } - } + }; } #[cfg(feature = "proto-ipv4")] - reusable_ip_specific_tests!(ipv4, ipv4_locals::socket, ipv4_locals::HEADER_REPR, - ipv4_locals::PACKET_BYTES, ipv4_locals::PACKET_PAYLOAD); + reusable_ip_specific_tests!( + ipv4, + ipv4_locals::socket, + ipv4_locals::HEADER_REPR, + ipv4_locals::PACKET_BYTES, + ipv4_locals::PACKET_PAYLOAD + ); #[cfg(feature = "proto-ipv6")] - reusable_ip_specific_tests!(ipv6, ipv6_locals::socket, ipv6_locals::HEADER_REPR, - ipv6_locals::PACKET_BYTES, ipv6_locals::PACKET_PAYLOAD); - + reusable_ip_specific_tests!( + ipv6, + ipv6_locals::socket, + ipv6_locals::HEADER_REPR, + ipv6_locals::PACKET_BYTES, + ipv6_locals::PACKET_PAYLOAD + ); #[test] #[cfg(feature = "proto-ipv4")] fn test_send_illegal() { - let checksum_caps = &ChecksumCapabilities::default(); #[cfg(feature = "proto-ipv4")] { let mut socket = ipv4_locals::socket(buffer(0), buffer(2)); + let mut cx = Context::mock(); - let mut wrong_version = ipv4_locals::PACKET_BYTES.clone(); + let mut wrong_version = ipv4_locals::PACKET_BYTES; Ipv4Packet::new_unchecked(&mut wrong_version).set_version(6); assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), - Ok(())); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); - let mut wrong_protocol = ipv4_locals::PACKET_BYTES.clone(); - Ipv4Packet::new_unchecked(&mut wrong_protocol).set_protocol(IpProtocol::Tcp); + let mut wrong_protocol = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut wrong_protocol).set_next_header(IpProtocol::Tcp); assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), - Ok(())); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); } #[cfg(feature = "proto-ipv6")] { let mut socket = ipv6_locals::socket(buffer(0), buffer(2)); + let mut cx = Context::mock(); - let mut wrong_version = ipv6_locals::PACKET_BYTES.clone(); + let mut wrong_version = ipv6_locals::PACKET_BYTES; Ipv6Packet::new_unchecked(&mut wrong_version[..]).set_version(4); assert_eq!(socket.send_slice(&wrong_version[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), - Ok(())); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); - let mut wrong_protocol = ipv6_locals::PACKET_BYTES.clone(); + let mut wrong_protocol = ipv6_locals::PACKET_BYTES; Ipv6Packet::new_unchecked(&mut wrong_protocol[..]).set_next_header(IpProtocol::Tcp); assert_eq!(socket.send_slice(&wrong_protocol[..]), Ok(())); - assert_eq!(socket.dispatch(&checksum_caps, |_| unreachable!()), - Ok(())); + assert_eq!( + socket.dispatch(&mut cx, |_, _| unreachable!()), + Ok::<_, ()>(()) + ); } } @@ -456,21 +694,26 @@ mod test { { let mut socket = ipv4_locals::socket(buffer(1), buffer(0)); assert!(!socket.can_recv()); + let mut cx = Context::mock(); - let mut cksumd_packet = ipv4_locals::PACKET_BYTES.clone(); + let mut cksumd_packet = ipv4_locals::PACKET_BYTES; Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum(); - assert_eq!(socket.recv(), Err(Error::Exhausted)); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), - Ok(())); + socket.process( + &mut cx, + &ipv4_locals::HEADER_REPR, + &ipv4_locals::PACKET_PAYLOAD, + ); assert!(socket.can_recv()); assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv4_locals::HEADER_REPR, &ipv4_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), - Err(Error::Exhausted)); + socket.process( + &mut cx, + &ipv4_locals::HEADER_REPR, + &ipv4_locals::PACKET_PAYLOAD, + ); assert_eq!(socket.recv(), Ok(&cksumd_packet[..])); assert!(!socket.can_recv()); } @@ -478,37 +721,103 @@ mod test { { let mut socket = ipv6_locals::socket(buffer(1), buffer(0)); assert!(!socket.can_recv()); + let mut cx = Context::mock(); - assert_eq!(socket.recv(), Err(Error::Exhausted)); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), - Ok(())); + socket.process( + &mut cx, + &ipv6_locals::HEADER_REPR, + &ipv6_locals::PACKET_PAYLOAD, + ); assert!(socket.can_recv()); assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); - assert_eq!(socket.process(&ipv6_locals::HEADER_REPR, &ipv6_locals::PACKET_PAYLOAD, - &ChecksumCapabilities::default()), - Err(Error::Exhausted)); + socket.process( + &mut cx, + &ipv6_locals::HEADER_REPR, + &ipv6_locals::PACKET_PAYLOAD, + ); assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..])); assert!(!socket.can_recv()); } } + #[test] + fn test_peek_process() { + #[cfg(feature = "proto-ipv4")] + { + let mut socket = ipv4_locals::socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + + let mut cksumd_packet = ipv4_locals::PACKET_BYTES; + Ipv4Packet::new_unchecked(&mut cksumd_packet).fill_checksum(); + + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process( + &mut cx, + &ipv4_locals::HEADER_REPR, + &ipv4_locals::PACKET_PAYLOAD, + ); + + assert!(socket.accepts(&ipv4_locals::HEADER_REPR)); + socket.process( + &mut cx, + &ipv4_locals::HEADER_REPR, + &ipv4_locals::PACKET_PAYLOAD, + ); + assert_eq!(socket.peek(), Ok(&cksumd_packet[..])); + assert_eq!(socket.recv(), Ok(&cksumd_packet[..])); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } + #[cfg(feature = "proto-ipv6")] + { + let mut socket = ipv6_locals::socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process( + &mut cx, + &ipv6_locals::HEADER_REPR, + &ipv6_locals::PACKET_PAYLOAD, + ); + + assert!(socket.accepts(&ipv6_locals::HEADER_REPR)); + socket.process( + &mut cx, + &ipv6_locals::HEADER_REPR, + &ipv6_locals::PACKET_PAYLOAD, + ); + assert_eq!(socket.peek(), Ok(&ipv6_locals::PACKET_BYTES[..])); + assert_eq!(socket.recv(), Ok(&ipv6_locals::PACKET_BYTES[..])); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); + } + } + #[test] fn test_doesnt_accept_wrong_proto() { #[cfg(feature = "proto-ipv4")] { - let socket = RawSocket::new(IpVersion::Ipv4, - IpProtocol::Unknown(ipv4_locals::IP_PROTO+1), buffer(1), buffer(1)); + let socket = Socket::new( + IpVersion::Ipv4, + IpProtocol::Unknown(ipv4_locals::IP_PROTO + 1), + buffer(1), + buffer(1), + ); assert!(!socket.accepts(&ipv4_locals::HEADER_REPR)); #[cfg(feature = "proto-ipv6")] assert!(!socket.accepts(&ipv6_locals::HEADER_REPR)); } #[cfg(feature = "proto-ipv6")] { - let socket = RawSocket::new(IpVersion::Ipv6, - IpProtocol::Unknown(ipv6_locals::IP_PROTO+1), buffer(1), buffer(1)); + let socket = Socket::new( + IpVersion::Ipv6, + IpProtocol::Unknown(ipv6_locals::IP_PROTO + 1), + buffer(1), + buffer(1), + ); assert!(!socket.accepts(&ipv6_locals::HEADER_REPR)); #[cfg(feature = "proto-ipv4")] assert!(!socket.accepts(&ipv4_locals::HEADER_REPR)); diff --git a/src/socket/ref_.rs b/src/socket/ref_.rs deleted file mode 100644 index db3205c84..000000000 --- a/src/socket/ref_.rs +++ /dev/null @@ -1,89 +0,0 @@ -use core::ops::{Deref, DerefMut}; - -#[cfg(feature = "socket-raw")] -use socket::RawSocket; -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -use socket::IcmpSocket; -#[cfg(feature = "socket-udp")] -use socket::UdpSocket; -#[cfg(feature = "socket-tcp")] -use socket::TcpSocket; - -/// A trait for tracking a socket usage session. -/// -/// Allows implementation of custom drop logic that runs only if the socket was changed -/// in specific ways. For example, drop logic for UDP would check if the local endpoint -/// has changed, and if yes, notify the socket set. -#[doc(hidden)] -pub trait Session { - fn finish(&mut self) {} -} - -#[cfg(feature = "socket-raw")] -impl<'a, 'b> Session for RawSocket<'a, 'b> {} -#[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] -impl<'a, 'b> Session for IcmpSocket<'a, 'b> {} -#[cfg(feature = "socket-udp")] -impl<'a, 'b> Session for UdpSocket<'a, 'b> {} -#[cfg(feature = "socket-tcp")] -impl<'a> Session for TcpSocket<'a> {} - -/// A smart pointer to a socket. -/// -/// Allows the network stack to efficiently determine if the socket state was changed in any way. -pub struct Ref<'a, T: Session + 'a> { - /// Reference to the socket. - /// - /// This is almost always `Some` except when dropped in `into_inner` which removes the socket - /// reference. This properly tracks the initialization state without any additional bytes as - /// the `None` variant occupies the `0` pattern which is invalid for the reference. - socket: Option<&'a mut T>, -} - -impl<'a, T: Session + 'a> Ref<'a, T> { - /// Wrap a pointer to a socket to make a smart pointer. - /// - /// Calling this function is only necessary if your code is using [into_inner]. - /// - /// [into_inner]: #method.into_inner - pub fn new(socket: &'a mut T) -> Self { - Ref { socket: Some(socket) } - } - - /// Unwrap a smart pointer to a socket. - /// - /// The finalization code is not run. Prompt operation of the network stack depends - /// on wrapping the returned pointer back and dropping it. - /// - /// Calling this function is only necessary to achieve composability if you *must* - /// map a `&mut SocketRef<'a, XSocket>` to a `&'a mut XSocket` (note the lifetimes); - /// be sure to call [new] afterwards. - /// - /// [new]: #method.new - pub fn into_inner(mut ref_: Self) -> &'a mut T { - ref_.socket.take().unwrap() - } -} - -impl<'a, T: Session> Deref for Ref<'a, T> { - type Target = T; - - fn deref(&self) -> &Self::Target { - // Deref is only used while the socket is still in place (into inner has not been called). - self.socket.as_ref().unwrap() - } -} - -impl<'a, T: Session> DerefMut for Ref<'a, T> { - fn deref_mut(&mut self) -> &mut Self::Target { - self.socket.as_mut().unwrap() - } -} - -impl<'a, T: Session> Drop for Ref<'a, T> { - fn drop(&mut self) { - if let Some(socket) = self.socket.take() { - Session::finish(socket); - } - } -} diff --git a/src/socket/set.rs b/src/socket/set.rs deleted file mode 100644 index b41a5383c..000000000 --- a/src/socket/set.rs +++ /dev/null @@ -1,221 +0,0 @@ -use core::{fmt, slice}; -use managed::ManagedSlice; - -use super::{Socket, SocketRef, AnySocket}; -#[cfg(feature = "socket-tcp")] -use super::TcpState; - -/// An item of a socket set. -/// -/// The only reason this struct is public is to allow the socket set storage -/// to be allocated externally. -#[derive(Debug)] -pub struct Item<'a, 'b: 'a> { - socket: Socket<'a, 'b>, - refs: usize -} - -/// A handle, identifying a socket in a set. -#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Default, Hash)] -pub struct Handle(usize); - -impl fmt::Display for Handle { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "#{}", self.0) - } -} - -/// An extensible set of sockets. -/// -/// The lifetimes `'b` and `'c` are used when storing a `Socket<'b, 'c>`. -#[derive(Debug)] -pub struct Set<'a, 'b: 'a, 'c: 'a + 'b> { - sockets: ManagedSlice<'a, Option>> -} - -impl<'a, 'b: 'a, 'c: 'a + 'b> Set<'a, 'b, 'c> { - /// Create a socket set using the provided storage. - pub fn new(sockets: SocketsT) -> Set<'a, 'b, 'c> - where SocketsT: Into>>> { - let sockets = sockets.into(); - Set { - sockets: sockets - } - } - - /// Add a socket to the set with the reference count 1, and return its handle. - /// - /// # Panics - /// This function panics if the storage is fixed-size (not a `Vec`) and is full. - pub fn add(&mut self, socket: T) -> Handle - where T: Into> - { - fn put<'b, 'c>(index: usize, slot: &mut Option>, - mut socket: Socket<'b, 'c>) -> Handle { - net_trace!("[{}]: adding", index); - let handle = Handle(index); - socket.meta_mut().handle = handle; - *slot = Some(Item { socket: socket, refs: 1 }); - handle - } - - let socket = socket.into(); - - for (index, slot) in self.sockets.iter_mut().enumerate() { - if slot.is_none() { - return put(index, slot, socket) - } - } - - match self.sockets { - ManagedSlice::Borrowed(_) => { - panic!("adding a socket to a full SocketSet") - } - #[cfg(any(feature = "std", feature = "alloc"))] - ManagedSlice::Owned(ref mut sockets) => { - sockets.push(None); - let index = sockets.len() - 1; - return put(index, &mut sockets[index], socket) - } - } - } - - /// Get a socket from the set by its handle, as mutable. - /// - /// # Panics - /// This function may panic if the handle does not belong to this socket set - /// or the socket has the wrong type. - pub fn get>(&mut self, handle: Handle) -> SocketRef { - match self.sockets[handle.0].as_mut() { - Some(item) => { - T::downcast(SocketRef::new(&mut item.socket)) - .expect("handle refers to a socket of a wrong type") - } - None => panic!("handle does not refer to a valid socket") - } - } - - /// Remove a socket from the set, without changing its state. - /// - /// # Panics - /// This function may panic if the handle does not belong to this socket set. - pub fn remove(&mut self, handle: Handle) -> Socket<'b, 'c> { - net_trace!("[{}]: removing", handle.0); - match self.sockets[handle.0].take() { - Some(item) => item.socket, - None => panic!("handle does not refer to a valid socket") - } - } - - /// Increase reference count by 1. - /// - /// # Panics - /// This function may panic if the handle does not belong to this socket set. - pub fn retain(&mut self, handle: Handle) { - self.sockets[handle.0] - .as_mut() - .expect("handle does not refer to a valid socket") - .refs += 1 - } - - /// Decrease reference count by 1. - /// - /// # Panics - /// This function may panic if the handle does not belong to this socket set, - /// or if the reference count is already zero. - pub fn release(&mut self, handle: Handle) { - let refs = &mut self.sockets[handle.0] - .as_mut() - .expect("handle does not refer to a valid socket") - .refs; - if *refs == 0 { panic!("decreasing reference count past zero") } - *refs -= 1 - } - - /// Prune the sockets in this set. - /// - /// Pruning affects sockets with reference count 0. Open sockets are closed. - /// Closed sockets are removed and dropped. - pub fn prune(&mut self) { - for (index, item) in self.sockets.iter_mut().enumerate() { - let mut may_remove = false; - if let &mut Some(Item { refs: 0, ref mut socket }) = item { - match socket { - #[cfg(feature = "socket-raw")] - &mut Socket::Raw(_) => - may_remove = true, - #[cfg(all(feature = "socket-icmp", any(feature = "proto-ipv4", feature = "proto-ipv6")))] - &mut Socket::Icmp(_) => - may_remove = true, - #[cfg(feature = "socket-udp")] - &mut Socket::Udp(_) => - may_remove = true, - #[cfg(feature = "socket-tcp")] - &mut Socket::Tcp(ref mut socket) => - if socket.state() == TcpState::Closed { - may_remove = true - } else { - socket.close() - }, - &mut Socket::__Nonexhaustive(_) => unreachable!() - } - } - if may_remove { - net_trace!("[{}]: pruning", index); - *item = None - } - } - } - - /// Iterate every socket in this set. - pub fn iter<'d>(&'d self) -> Iter<'d, 'b, 'c> { - Iter { lower: self.sockets.iter() } - } - - /// Iterate every socket in this set, as SocketRef. - pub fn iter_mut<'d>(&'d mut self) -> IterMut<'d, 'b, 'c> { - IterMut { lower: self.sockets.iter_mut() } - } -} - -/// Immutable socket set iterator. -/// -/// This struct is created by the [iter](struct.SocketSet.html#method.iter) -/// on [socket sets](struct.SocketSet.html). -pub struct Iter<'a, 'b: 'a, 'c: 'a + 'b> { - lower: slice::Iter<'a, Option>> -} - -impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for Iter<'a, 'b, 'c> { - type Item = &'a Socket<'b, 'c>; - - fn next(&mut self) -> Option { - while let Some(item_opt) = self.lower.next() { - if let Some(item) = item_opt.as_ref() { - return Some(&item.socket) - } - } - None - } -} - -/// Mutable socket set iterator. -/// -/// This struct is created by the [iter_mut](struct.SocketSet.html#method.iter_mut) -/// on [socket sets](struct.SocketSet.html). -pub struct IterMut<'a, 'b: 'a, 'c: 'a + 'b> { - lower: slice::IterMut<'a, Option>>, -} - -impl<'a, 'b: 'a, 'c: 'a + 'b> Iterator for IterMut<'a, 'b, 'c> { - type Item = SocketRef<'a, Socket<'b, 'c>>; - - fn next(&mut self) -> Option { - while let Some(item_opt) = self.lower.next() { - if let Some(item) = item_opt.as_mut() { - return Some(SocketRef::new(&mut item.socket)) - } - } - None - } -} diff --git a/src/socket/tcp.rs b/src/socket/tcp.rs index f6c0f97a4..0384d865a 100644 --- a/src/socket/tcp.rs +++ b/src/socket/tcp.rs @@ -2,14 +2,102 @@ // the parts of RFC 1122 that discuss TCP. Consult RFC 7414 when implementing // a new feature. +use core::fmt::Display; +#[cfg(feature = "async")] +use core::task::Waker; use core::{cmp, fmt, mem}; -use {Error, Result}; -use phy::DeviceCapabilities; -use time::{Duration, Instant}; -use socket::{Socket, SocketMeta, SocketHandle, PollAt}; -use storage::{Assembler, RingBuffer}; -use wire::{IpProtocol, IpRepr, IpAddress, IpEndpoint, TcpSeqNumber, TcpRepr, TcpControl}; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::socket::{Context, PollAt}; +use crate::storage::{Assembler, RingBuffer}; +use crate::time::{Duration, Instant}; +use crate::wire::{ + IpAddress, IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, TcpControl, TcpRepr, TcpSeqNumber, + TCP_HEADER_LEN, +}; + +macro_rules! tcp_trace { + ($($arg:expr),*) => (net_log!(trace, $($arg),*)); +} + +/// Error returned by [`Socket::listen`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ListenError { + InvalidState, + Unaddressable, +} + +impl Display for ListenError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ListenError::InvalidState => write!(f, "invalid state"), + ListenError::Unaddressable => write!(f, "unaddressable destination"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ListenError {} + +/// Error returned by [`Socket::connect`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum ConnectError { + InvalidState, + Unaddressable, +} + +impl Display for ConnectError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + ConnectError::InvalidState => write!(f, "invalid state"), + ConnectError::Unaddressable => write!(f, "unaddressable destination"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for ConnectError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + InvalidState, +} + +impl Display for SendError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + SendError::InvalidState => write!(f, "invalid state"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + InvalidState, + Finished, +} + +impl Display for RecvError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + RecvError::InvalidState => write!(f, "invalid state"), + RecvError::Finished => write!(f, "operation finished"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} /// A TCP socket ring buffer. pub type SocketBuffer<'a> = RingBuffer<'a, u8>; @@ -18,6 +106,7 @@ pub type SocketBuffer<'a> = RingBuffer<'a, u8>; /// /// [RFC 793]: https://tools.ietf.org/html/rfc793 #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum State { Closed, Listen, @@ -29,87 +118,195 @@ pub enum State { CloseWait, Closing, LastAck, - TimeWait + TimeWait, } impl fmt::Display for State { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &State::Closed => write!(f, "CLOSED"), - &State::Listen => write!(f, "LISTEN"), - &State::SynSent => write!(f, "SYN-SENT"), - &State::SynReceived => write!(f, "SYN-RECEIVED"), - &State::Established => write!(f, "ESTABLISHED"), - &State::FinWait1 => write!(f, "FIN-WAIT-1"), - &State::FinWait2 => write!(f, "FIN-WAIT-2"), - &State::CloseWait => write!(f, "CLOSE-WAIT"), - &State::Closing => write!(f, "CLOSING"), - &State::LastAck => write!(f, "LAST-ACK"), - &State::TimeWait => write!(f, "TIME-WAIT") + match *self { + State::Closed => write!(f, "CLOSED"), + State::Listen => write!(f, "LISTEN"), + State::SynSent => write!(f, "SYN-SENT"), + State::SynReceived => write!(f, "SYN-RECEIVED"), + State::Established => write!(f, "ESTABLISHED"), + State::FinWait1 => write!(f, "FIN-WAIT-1"), + State::FinWait2 => write!(f, "FIN-WAIT-2"), + State::CloseWait => write!(f, "CLOSE-WAIT"), + State::Closing => write!(f, "CLOSING"), + State::LastAck => write!(f, "LAST-ACK"), + State::TimeWait => write!(f, "TIME-WAIT"), + } + } +} + +// Conservative initial RTT estimate. +const RTTE_INITIAL_RTT: u32 = 300; +const RTTE_INITIAL_DEV: u32 = 100; + +// Minimum "safety margin" for the RTO that kicks in when the +// variance gets very low. +const RTTE_MIN_MARGIN: u32 = 5; + +const RTTE_MIN_RTO: u32 = 10; +const RTTE_MAX_RTO: u32 = 10000; + +#[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct RttEstimator { + // Using u32 instead of Duration to save space (Duration is i64) + rtt: u32, + deviation: u32, + timestamp: Option<(Instant, TcpSeqNumber)>, + max_seq_sent: Option, + rto_count: u8, +} + +impl Default for RttEstimator { + fn default() -> Self { + Self { + rtt: RTTE_INITIAL_RTT, + deviation: RTTE_INITIAL_DEV, + timestamp: None, + max_seq_sent: None, + rto_count: 0, + } + } +} + +impl RttEstimator { + fn retransmission_timeout(&self) -> Duration { + let margin = RTTE_MIN_MARGIN.max(self.deviation * 4); + let ms = (self.rtt + margin).clamp(RTTE_MIN_RTO, RTTE_MAX_RTO); + Duration::from_millis(ms as u64) + } + + fn sample(&mut self, new_rtt: u32) { + // "Congestion Avoidance and Control", Van Jacobson, Michael J. Karels, 1988 + self.rtt = (self.rtt * 7 + new_rtt + 7) / 8; + let diff = (self.rtt as i32 - new_rtt as i32).unsigned_abs(); + self.deviation = (self.deviation * 3 + diff + 3) / 4; + + self.rto_count = 0; + + let rto = self.retransmission_timeout().total_millis(); + tcp_trace!( + "rtte: sample={:?} rtt={:?} dev={:?} rto={:?}", + new_rtt, + self.rtt, + self.deviation, + rto + ); + } + + fn on_send(&mut self, timestamp: Instant, seq: TcpSeqNumber) { + if self + .max_seq_sent + .map(|max_seq_sent| seq > max_seq_sent) + .unwrap_or(true) + { + self.max_seq_sent = Some(seq); + if self.timestamp.is_none() { + self.timestamp = Some((timestamp, seq)); + tcp_trace!("rtte: sampling at seq={:?}", seq); + } + } + } + + fn on_ack(&mut self, timestamp: Instant, seq: TcpSeqNumber) { + if let Some((sent_timestamp, sent_seq)) = self.timestamp { + if seq >= sent_seq { + self.sample((timestamp - sent_timestamp).total_millis() as u32); + self.timestamp = None; + } + } + } + + fn on_retransmit(&mut self) { + if self.timestamp.is_some() { + tcp_trace!("rtte: abort sampling due to retransmit"); + } + self.timestamp = None; + self.rto_count = self.rto_count.saturating_add(1); + if self.rto_count >= 3 { + // This happens in 2 scenarios: + // - The RTT is higher than the initial estimate + // - The network conditions change, suddenly making the RTT much higher + // In these cases, the estimator can get stuck, because it can't sample because + // all packets sent would incur a retransmit. To avoid this, force an estimate + // increase if we see 3 consecutive retransmissions without any successful sample. + self.rto_count = 0; + self.rtt = RTTE_MAX_RTO.min(self.rtt * 2); + let rto = self.retransmission_timeout().total_millis(); + tcp_trace!( + "rtte: too many retransmissions, increasing: rtt={:?} dev={:?} rto={:?}", + self.rtt, + self.deviation, + rto + ); } } } #[derive(Debug, Clone, Copy, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] enum Timer { Idle { keep_alive_at: Option, }, Retransmit { expires_at: Instant, - delay: Duration + delay: Duration, }, FastRetransmit, Close { - expires_at: Instant - } + expires_at: Instant, + }, } -const RETRANSMIT_DELAY: Duration = Duration { millis: 100 }; -const CLOSE_DELAY: Duration = Duration { millis: 10_000 }; +const ACK_DELAY_DEFAULT: Duration = Duration::from_millis(10); +const CLOSE_DELAY: Duration = Duration::from_millis(10_000); -impl Default for Timer { - fn default() -> Timer { - Timer::Idle { keep_alive_at: None } +impl Timer { + fn new() -> Timer { + Timer::Idle { + keep_alive_at: None, + } } -} -impl Timer { fn should_keep_alive(&self, timestamp: Instant) -> bool { match *self { - Timer::Idle { keep_alive_at: Some(keep_alive_at) } - if timestamp >= keep_alive_at => { - true - } - _ => false + Timer::Idle { + keep_alive_at: Some(keep_alive_at), + } if timestamp >= keep_alive_at => true, + _ => false, } } fn should_retransmit(&self, timestamp: Instant) -> Option { match *self { - Timer::Retransmit { expires_at, delay } - if timestamp >= expires_at => { + Timer::Retransmit { expires_at, delay } if timestamp >= expires_at => { Some(timestamp - expires_at + delay) - }, + } Timer::FastRetransmit => Some(Duration::from_millis(0)), - _ => None + _ => None, } } fn should_close(&self, timestamp: Instant) -> bool { match *self { - Timer::Close { expires_at } - if timestamp >= expires_at => { - true - } - _ => false + Timer::Close { expires_at } if timestamp >= expires_at => true, + _ => false, } } fn poll_at(&self) -> PollAt { match *self { - Timer::Idle { keep_alive_at: Some(keep_alive_at) } => PollAt::Time(keep_alive_at), - Timer::Idle { keep_alive_at: None } => PollAt::Ingress, + Timer::Idle { + keep_alive_at: Some(keep_alive_at), + } => PollAt::Time(keep_alive_at), + Timer::Idle { + keep_alive_at: None, + } => PollAt::Ingress, Timer::Retransmit { expires_at, .. } => PollAt::Time(expires_at), Timer::FastRetransmit => PollAt::Now, Timer::Close { expires_at } => PollAt::Time(expires_at), @@ -118,46 +315,40 @@ impl Timer { fn set_for_idle(&mut self, timestamp: Instant, interval: Option) { *self = Timer::Idle { - keep_alive_at: interval.map(|interval| timestamp + interval) + keep_alive_at: interval.map(|interval| timestamp + interval), } } fn set_keep_alive(&mut self) { - match *self { - Timer::Idle { ref mut keep_alive_at } - if keep_alive_at.is_none() => { + if let Timer::Idle { keep_alive_at } = self { + if keep_alive_at.is_none() { *keep_alive_at = Some(Instant::from_millis(0)) } - _ => () } } fn rewind_keep_alive(&mut self, timestamp: Instant, interval: Option) { - match self { - &mut Timer::Idle { ref mut keep_alive_at } => { - *keep_alive_at = interval.map(|interval| timestamp + interval) - } - _ => () + if let Timer::Idle { keep_alive_at } = self { + *keep_alive_at = interval.map(|interval| timestamp + interval) } } - fn set_for_retransmit(&mut self, timestamp: Instant) { + fn set_for_retransmit(&mut self, timestamp: Instant, delay: Duration) { match *self { Timer::Idle { .. } | Timer::FastRetransmit { .. } => { *self = Timer::Retransmit { - expires_at: timestamp + RETRANSMIT_DELAY, - delay: RETRANSMIT_DELAY, + expires_at: timestamp + delay, + delay, } } - Timer::Retransmit { expires_at, delay } - if timestamp >= expires_at => { + Timer::Retransmit { expires_at, delay } if timestamp >= expires_at => { *self = Timer::Retransmit { expires_at: timestamp + delay, - delay: delay * 2 + delay: delay * 2, } } Timer::Retransmit { .. } => (), - Timer::Close { .. } => () + Timer::Close { .. } => (), } } @@ -167,18 +358,38 @@ impl Timer { fn set_for_close(&mut self, timestamp: Instant) { *self = Timer::Close { - expires_at: timestamp + CLOSE_DELAY + expires_at: timestamp + CLOSE_DELAY, } } fn is_retransmit(&self) -> bool { match *self { - Timer::Retransmit {..} | Timer::FastRetransmit => true, + Timer::Retransmit { .. } | Timer::FastRetransmit => true, _ => false, } } } +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +enum AckDelayTimer { + Idle, + Waiting(Instant), + Immediate, +} + +#[derive(Debug, Copy, Clone, Eq, PartialEq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +struct Tuple { + local: IpEndpoint, + remote: IpEndpoint, +} + +impl Display for Tuple { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{}:{}", self.local, self.remote) + } +} + /// A Transmission Control Protocol socket. /// /// A TCP socket may passively listen for connections or actively connect to another endpoint. @@ -186,39 +397,31 @@ impl Timer { /// accept several connections, as many sockets must be allocated, or any new connection /// attempts will be reset. #[derive(Debug)] -pub struct TcpSocket<'a> { - pub(crate) meta: SocketMeta, - state: State, - timer: Timer, - assembler: Assembler, - rx_buffer: SocketBuffer<'a>, +pub struct Socket<'a> { + state: State, + timer: Timer, + rtte: RttEstimator, + assembler: Assembler, + rx_buffer: SocketBuffer<'a>, rx_fin_received: bool, - tx_buffer: SocketBuffer<'a>, + tx_buffer: SocketBuffer<'a>, /// Interval after which, if no inbound packets are received, the connection is aborted. - timeout: Option, + timeout: Option, /// Interval at which keep-alive packets will be sent. - keep_alive: Option, + keep_alive: Option, /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. - hop_limit: Option, + hop_limit: Option, /// Address passed to listen(). Listen address is set when listen() is called and /// used every time the socket is reset back to the LISTEN state. - listen_address: IpAddress, - /// Current local endpoint. This is used for both filtering the incoming packets and - /// setting the source address. When listening or initiating connection on/from - /// an unspecified address, this field is updated with the chosen source address before - /// any packets are sent. - local_endpoint: IpEndpoint, - /// Current remote endpoint. This is used for both filtering the incoming packets and - /// setting the destination address. If the remote endpoint is unspecified, it means that - /// aborting the connection will not send an RST, and, in TIME-WAIT state, will not - /// send an ACK. - remote_endpoint: IpEndpoint, + listen_endpoint: IpListenEndpoint, + /// Current 4-tuple (local and remote endpoints). + tuple: Option, /// The sequence number corresponding to the beginning of the transmit buffer. /// I.e. an ACK(local_seq_no+n) packet removes n bytes from the transmit buffer. - local_seq_no: TcpSeqNumber, + local_seq_no: TcpSeqNumber, /// The sequence number corresponding to the beginning of the receive buffer. /// I.e. userspace reading n bytes adds n to remote_seq_no. - remote_seq_no: TcpSeqNumber, + remote_seq_no: TcpSeqNumber, /// The last sequence number sent. /// I.e. in an idle socket, local_seq_no+tx_buffer.len(). remote_last_seq: TcpSeqNumber, @@ -230,33 +433,52 @@ pub struct TcpSocket<'a> { /// The sending window scaling factor advertised to remotes which support RFC 1323. /// It is zero if the window <= 64KiB and/or the remote does not support it. remote_win_shift: u8, - /// The speculative remote window size. - /// I.e. the actual remote window size minus the count of in-flight octets. - remote_win_len: usize, + /// The remote window size, relative to local_seq_no + /// I.e. we're allowed to send octets until local_seq_no+remote_win_len + remote_win_len: usize, /// The receive window scaling factor for remotes which support RFC 1323, None if unsupported. remote_win_scale: Option, /// Whether or not the remote supports selective ACK as described in RFC 2018. remote_has_sack: bool, /// The maximum number of data octets that the remote side may receive. - remote_mss: usize, + remote_mss: usize, /// The timestamp of the last packet received. - remote_last_ts: Option, - /// The sequence number of the last packet recived, used for sACK + remote_last_ts: Option, + /// The sequence number of the last packet received, used for sACK local_rx_last_seq: Option, - /// The ACK number of the last packet recived. + /// The ACK number of the last packet received. local_rx_last_ack: Option, - /// The number of packets recived directly after + /// The number of packets received directly after /// each other which have the same ACK number. local_rx_dup_acks: u8, + + /// Duration for Delayed ACK. If None no ACKs will be delayed. + ack_delay: Option, + /// Delayed ack timer. If set, packets containing exclusively + /// ACK or window updates (ie, no data) won't be sent until expiry. + ack_delay_timer: AckDelayTimer, + + /// Used for rate-limiting: No more challenge ACKs will be sent until this instant. + challenge_ack_timer: Instant, + + /// Nagle's Algorithm enabled. + nagle: bool, + + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, } const DEFAULT_MSS: usize = 536; -impl<'a> TcpSocket<'a> { +impl<'a> Socket<'a> { #[allow(unused_comparisons)] // small usize platforms always pass rx_capacity check /// Create a socket using the given buffers. - pub fn new(rx_buffer: T, tx_buffer: T) -> TcpSocket<'a> - where T: Into> { + pub fn new(rx_buffer: T, tx_buffer: T) -> Socket<'a> + where + T: Into>, + { let (rx_buffer, tx_buffer) = (rx_buffer.into(), tx_buffer.into()); let rx_capacity = rx_buffer.capacity(); @@ -267,44 +489,80 @@ impl<'a> TcpSocket<'a> { if rx_capacity > (1 << 30) { panic!("receiving buffer too large, cannot exceed 1 GiB") } - let rx_cap_log2 = mem::size_of::() * 8 - - rx_capacity.leading_zeros() as usize; - - TcpSocket { - meta: SocketMeta::default(), - state: State::Closed, - timer: Timer::default(), - assembler: Assembler::new(rx_buffer.capacity()), - tx_buffer: tx_buffer, - rx_buffer: rx_buffer, + let rx_cap_log2 = mem::size_of::() * 8 - rx_capacity.leading_zeros() as usize; + + Socket { + state: State::Closed, + timer: Timer::new(), + rtte: RttEstimator::default(), + assembler: Assembler::new(), + tx_buffer, + rx_buffer, rx_fin_received: false, - timeout: None, - keep_alive: None, - hop_limit: None, - listen_address: IpAddress::default(), - local_endpoint: IpEndpoint::default(), - remote_endpoint: IpEndpoint::default(), - local_seq_no: TcpSeqNumber::default(), - remote_seq_no: TcpSeqNumber::default(), + timeout: None, + keep_alive: None, + hop_limit: None, + listen_endpoint: IpListenEndpoint::default(), + tuple: None, + local_seq_no: TcpSeqNumber::default(), + remote_seq_no: TcpSeqNumber::default(), remote_last_seq: TcpSeqNumber::default(), remote_last_ack: None, remote_last_win: 0, - remote_win_len: 0, + remote_win_len: 0, remote_win_shift: rx_cap_log2.saturating_sub(16) as u8, remote_win_scale: None, remote_has_sack: false, - remote_mss: DEFAULT_MSS, - remote_last_ts: None, + remote_mss: DEFAULT_MSS, + remote_last_ts: None, local_rx_last_ack: None, local_rx_last_seq: None, local_rx_dup_acks: 0, + ack_delay: Some(ACK_DELAY_DEFAULT), + ack_delay_timer: AckDelayTimer::Idle, + challenge_ack_timer: Instant::from_secs(0), + nagle: true, + + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), } } - /// Return the socket handle. - #[inline] - pub fn handle(&self) -> SocketHandle { - self.meta.handle + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) } /// Return the timeout duration. @@ -314,14 +572,30 @@ impl<'a> TcpSocket<'a> { self.timeout } + /// Return the ACK delay duration. + /// + /// See also the [set_ack_delay](#method.set_ack_delay) method. + pub fn ack_delay(&self) -> Option { + self.ack_delay + } + + /// Return whether Nagle's Algorithm is enabled. + /// + /// See also the [set_nagle_enabled](#method.set_nagle_enabled) method. + pub fn nagle_enabled(&self) -> bool { + self.nagle + } + /// Return the current window field value, including scaling according to RFC 1323. /// /// Used in internal calculations as well as packet generation. /// #[inline] fn scaled_window(&self) -> u16 { - cmp::min(self.rx_buffer.window() >> self.remote_win_shift as usize, - (1 << 16) - 1) as u16 + cmp::min( + self.rx_buffer.window() >> self.remote_win_shift as usize, + (1 << 16) - 1, + ) as u16 } /// Set the timeout duration. @@ -339,6 +613,29 @@ impl<'a> TcpSocket<'a> { self.timeout = duration } + /// Set the ACK delay duration. + /// + /// By default, the ACK delay is set to 10ms. + pub fn set_ack_delay(&mut self, duration: Option) { + self.ack_delay = duration + } + + /// Enable or disable Nagle's Algorithm. + /// + /// Also known as "tinygram prevention". By default, it is enabled. + /// Disabling it is equivalent to Linux's TCP_NODELAY flag. + /// + /// When enabled, Nagle's Algorithm prevents sending segments smaller than MSS if + /// there is data in flight (sent but not acknowledged). In other words, it ensures + /// at most only one segment smaller than MSS is in flight at a time. + /// + /// It ensures better network utilization by preventing sending many very small packets, + /// at the cost of increased latency in some situations, particularly when the remote peer + /// has ACK delay enabled. + pub fn set_nagle_enabled(&mut self, enabled: bool) { + self.nagle = enabled + } + /// Return the keep-alive interval. /// /// See also the [set_keep_alive](#method.set_keep_alive) method. @@ -348,7 +645,7 @@ impl<'a> TcpSocket<'a> { /// Set the keep-alive interval. /// - /// An idle socket with a keep-alive interval set will transmit a "challenge ACK" packet + /// An idle socket with a keep-alive interval set will transmit a "keep-alive ACK" packet /// every time it receives no communication during that interval. As a result, three things /// may happen: /// @@ -394,16 +691,16 @@ impl<'a> TcpSocket<'a> { self.hop_limit = hop_limit } - /// Return the local endpoint. + /// Return the local endpoint, or None if not connected. #[inline] - pub fn local_endpoint(&self) -> IpEndpoint { - self.local_endpoint + pub fn local_endpoint(&self) -> Option { + Some(self.tuple?.local) } - /// Return the remote endpoint. + /// Return the remote endpoint, or None if not connected. #[inline] - pub fn remote_endpoint(&self) -> IpEndpoint { - self.remote_endpoint + pub fn remote_endpoint(&self) -> Option { + Some(self.tuple?.remote) } /// Return the connection state, in terms of the TCP state machine. @@ -413,31 +710,36 @@ impl<'a> TcpSocket<'a> { } fn reset(&mut self) { - let rx_cap_log2 = mem::size_of::() * 8 - - self.rx_buffer.capacity().leading_zeros() as usize; + let rx_cap_log2 = + mem::size_of::() * 8 - self.rx_buffer.capacity().leading_zeros() as usize; - self.state = State::Closed; - self.timer = Timer::default(); - self.assembler = Assembler::new(self.rx_buffer.capacity()); + self.state = State::Closed; + self.timer = Timer::new(); + self.rtte = RttEstimator::default(); + self.assembler = Assembler::new(); self.tx_buffer.clear(); self.rx_buffer.clear(); self.rx_fin_received = false; - self.keep_alive = None; - self.timeout = None; - self.hop_limit = None; - self.listen_address = IpAddress::default(); - self.local_endpoint = IpEndpoint::default(); - self.remote_endpoint = IpEndpoint::default(); - self.local_seq_no = TcpSeqNumber::default(); - self.remote_seq_no = TcpSeqNumber::default(); + self.listen_endpoint = IpListenEndpoint::default(); + self.tuple = None; + self.local_seq_no = TcpSeqNumber::default(); + self.remote_seq_no = TcpSeqNumber::default(); self.remote_last_seq = TcpSeqNumber::default(); self.remote_last_ack = None; self.remote_last_win = 0; - self.remote_win_len = 0; + self.remote_win_len = 0; self.remote_win_scale = None; self.remote_win_shift = rx_cap_log2.saturating_sub(16) as u8; - self.remote_mss = DEFAULT_MSS; - self.remote_last_ts = None; + self.remote_mss = DEFAULT_MSS; + self.remote_last_ts = None; + self.ack_delay_timer = AckDelayTimer::Idle; + self.challenge_ack_timer = Instant::from_secs(0); + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } } /// Start listening on the given endpoint. @@ -445,17 +747,22 @@ impl<'a> TcpSocket<'a> { /// This function returns `Err(Error::Illegal)` if the socket was already open /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` /// if the port in the given endpoint is zero. - pub fn listen(&mut self, local_endpoint: T) -> Result<()> - where T: Into { + pub fn listen(&mut self, local_endpoint: T) -> Result<(), ListenError> + where + T: Into, + { let local_endpoint = local_endpoint.into(); - if local_endpoint.port == 0 { return Err(Error::Unaddressable) } + if local_endpoint.port == 0 { + return Err(ListenError::Unaddressable); + } - if self.is_open() { return Err(Error::Illegal) } + if self.is_open() { + return Err(ListenError::InvalidState); + } self.reset(); - self.listen_address = local_endpoint.addr; - self.local_endpoint = local_endpoint; - self.remote_endpoint = IpEndpoint::default(); + self.listen_endpoint = local_endpoint; + self.tuple = None; self.set_state(State::Listen); Ok(()) } @@ -465,8 +772,33 @@ impl<'a> TcpSocket<'a> { /// The local port must be provided explicitly. Assuming `fn get_ephemeral_port() -> u16` /// allocates a port between 49152 and 65535, a connection may be established as follows: /// - /// ```rust,ignore - /// socket.connect((IpAddress::v4(10, 0, 0, 1), 80), get_ephemeral_port()) + /// ```no_run + /// # #[cfg(all( + /// # feature = "medium-ethernet", + /// # feature = "proto-ipv4", + /// # ))] + /// # { + /// # use smoltcp::socket::tcp::{Socket, SocketBuffer}; + /// # use smoltcp::iface::Interface; + /// # use smoltcp::wire::IpAddress; + /// # + /// # fn get_ephemeral_port() -> u16 { + /// # 49152 + /// # } + /// # + /// # let mut socket = Socket::new( + /// # SocketBuffer::new(vec![0; 1200]), + /// # SocketBuffer::new(vec![0; 1200]) + /// # ); + /// # + /// # let mut iface: Interface = todo!(); + /// # + /// socket.connect( + /// iface.context(), + /// (IpAddress::v4(10, 0, 0, 1), 80), + /// get_ephemeral_port() + /// ).unwrap(); + /// # } /// ``` /// /// The local address may optionally be provided. @@ -474,36 +806,72 @@ impl<'a> TcpSocket<'a> { /// This function returns an error if the socket was open; see [is_open](#method.is_open). /// It also returns an error if the local or remote port is zero, or if the remote address /// is unspecified. - pub fn connect(&mut self, remote_endpoint: T, local_endpoint: U) -> Result<()> - where T: Into, U: Into { - let remote_endpoint = remote_endpoint.into(); - let local_endpoint = local_endpoint.into(); - - if self.is_open() { return Err(Error::Illegal) } - if !remote_endpoint.is_specified() { return Err(Error::Unaddressable) } - if local_endpoint.port == 0 { return Err(Error::Unaddressable) } - - // If local address is not provided, use an unspecified address but a specified protocol. - // This lets us lower IpRepr later to determine IP header size and calculate MSS, - // but without committing to a specific address right away. - let local_addr = match local_endpoint.addr { - IpAddress::Unspecified => remote_endpoint.addr.to_unspecified(), - ip => ip, + pub fn connect( + &mut self, + cx: &mut Context, + remote_endpoint: T, + local_endpoint: U, + ) -> Result<(), ConnectError> + where + T: Into, + U: Into, + { + let remote_endpoint: IpEndpoint = remote_endpoint.into(); + let local_endpoint: IpListenEndpoint = local_endpoint.into(); + + if self.is_open() { + return Err(ConnectError::InvalidState); + } + if remote_endpoint.port == 0 || remote_endpoint.addr.is_unspecified() { + return Err(ConnectError::Unaddressable); + } + if local_endpoint.port == 0 { + return Err(ConnectError::Unaddressable); + } + + // If local address is not provided, choose it automatically. + let local_endpoint = IpEndpoint { + addr: match local_endpoint.addr { + Some(addr) => { + if addr.is_unspecified() { + return Err(ConnectError::Unaddressable); + } + addr + } + None => cx + .get_source_address(remote_endpoint.addr) + .ok_or(ConnectError::Unaddressable)?, + }, + port: local_endpoint.port, }; - let local_endpoint = IpEndpoint { addr: local_addr, ..local_endpoint }; - // Carry over the local sequence number. - let local_seq_no = self.local_seq_no; + if local_endpoint.addr.version() != remote_endpoint.addr.version() { + return Err(ConnectError::Unaddressable); + } self.reset(); - self.local_endpoint = local_endpoint; - self.remote_endpoint = remote_endpoint; - self.local_seq_no = local_seq_no; - self.remote_last_seq = local_seq_no; + self.tuple = Some(Tuple { + local: local_endpoint, + remote: remote_endpoint, + }); self.set_state(State::SynSent); + + let seq = Self::random_seq_no(cx); + self.local_seq_no = seq; + self.remote_last_seq = seq; Ok(()) } + #[cfg(test)] + fn random_seq_no(_cx: &mut Context) -> TcpSeqNumber { + TcpSeqNumber(10000) + } + + #[cfg(not(test))] + fn random_seq_no(cx: &mut Context) -> TcpSeqNumber { + TcpSeqNumber(cx.rand().rand_u32() as i32) + } + /// Close the transmit half of the full-duplex connection. /// /// Note that there is no corresponding function for the receive half of the full-duplex @@ -512,23 +880,23 @@ impl<'a> TcpSocket<'a> { pub fn close(&mut self) { match self.state { // In the LISTEN state there is no established connection. - State::Listen => - self.set_state(State::Closed), + State::Listen => self.set_state(State::Closed), // In the SYN-SENT state the remote endpoint is not yet synchronized and, upon // receiving an RST, will abort the connection. - State::SynSent => - self.set_state(State::Closed), + State::SynSent => self.set_state(State::Closed), // In the SYN-RECEIVED, ESTABLISHED and CLOSE-WAIT states the transmit half // of the connection is open, and needs to be explicitly closed with a FIN. - State::SynReceived | State::Established => - self.set_state(State::FinWait1), - State::CloseWait => - self.set_state(State::LastAck), + State::SynReceived | State::Established => self.set_state(State::FinWait1), + State::CloseWait => self.set_state(State::LastAck), // In the FIN-WAIT-1, FIN-WAIT-2, CLOSING, LAST-ACK, TIME-WAIT and CLOSED states, // the transmit half of the connection is already closed, and no further // action is needed. - State::FinWait1 | State::FinWait2 | State::Closing | - State::TimeWait | State::LastAck | State::Closed => () + State::FinWait1 + | State::FinWait2 + | State::Closing + | State::TimeWait + | State::LastAck + | State::Closed => (), } } @@ -550,7 +918,7 @@ impl<'a> TcpSocket<'a> { pub fn is_listening(&self) -> bool { match self.state { State::Listen => true, - _ => false + _ => false, } } @@ -567,7 +935,7 @@ impl<'a> TcpSocket<'a> { match self.state { State::Closed => false, State::TimeWait => false, - _ => true + _ => true, } } @@ -581,7 +949,7 @@ impl<'a> TcpSocket<'a> { /// If a connection is established, [abort](#method.close) will send a reset to /// the remote endpoint. /// - /// In terms of the TCP state machine, the socket must be in the `CLOSED`, `TIME-WAIT`, + /// In terms of the TCP state machine, the socket must not be in the `CLOSED`, `TIME-WAIT`, /// or `LISTEN` state. #[inline] pub fn is_active(&self) -> bool { @@ -589,7 +957,7 @@ impl<'a> TcpSocket<'a> { State::Closed => false, State::TimeWait => false, State::Listen => false, - _ => true + _ => true, } } @@ -609,7 +977,7 @@ impl<'a> TcpSocket<'a> { // In CLOSE-WAIT, the remote endpoint has closed our receive half of the connection // but we still can transmit indefinitely. State::CloseWait => true, - _ => false + _ => false, } } @@ -629,16 +997,18 @@ impl<'a> TcpSocket<'a> { // we still can receive indefinitely. State::FinWait1 | State::FinWait2 => true, // If we have something in the receive buffer, we can receive that. - _ if self.rx_buffer.len() > 0 => true, - _ => false + _ if !self.rx_buffer.is_empty() => true, + _ => false, } } /// Check whether the transmit half of the full-duplex connection is open - /// (see [may_send](#method.may_send), and the transmit buffer is not full. + /// (see [may_send](#method.may_send)), and the transmit buffer is not full. #[inline] pub fn can_send(&self) -> bool { - if !self.may_send() { return false } + if !self.may_send() { + return false; + } !self.tx_buffer.is_full() } @@ -656,30 +1026,40 @@ impl<'a> TcpSocket<'a> { } /// Check whether the receive half of the full-duplex connection buffer is open - /// (see [may_recv](#method.may_recv), and the receive buffer is not empty. + /// (see [may_recv](#method.may_recv)), and the receive buffer is not empty. #[inline] pub fn can_recv(&self) -> bool { - if !self.may_recv() { return false } + if !self.may_recv() { + return false; + } !self.rx_buffer.is_empty() } - fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) { - if !self.may_send() { return Err(Error::Illegal) } + fn send_impl<'b, F, R>(&'b mut self, f: F) -> Result + where + F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R), + { + if !self.may_send() { + return Err(SendError::InvalidState); + } // The connection might have been idle for a long time, and so remote_last_ts // would be far in the past. Unless we clear it here, we'll abort the connection // down over in dispatch() by erroneously detecting it as timed out. - if self.tx_buffer.is_empty() { self.remote_last_ts = None } + if self.tx_buffer.is_empty() { + self.remote_last_ts = None + } let _old_length = self.tx_buffer.len(); let (size, result) = f(&mut self.tx_buffer); if size > 0 { #[cfg(any(test, feature = "verbose"))] - net_trace!("{}:{}:{}: tx buffer: enqueueing {} octets (now {})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - size, _old_length + size); + tcp_trace!( + "tx buffer: enqueueing {} octets (now {})", + size, + _old_length + size + ); } Ok(result) } @@ -687,13 +1067,13 @@ impl<'a> TcpSocket<'a> { /// Call `f` with the largest contiguous slice of octets in the transmit buffer, /// and enqueue the amount of elements returned by `f`. /// - /// This function returns `Err(Error::Illegal) if the transmit half of + /// This function returns `Err(Error::Illegal)` if the transmit half of /// the connection is not open; see [may_send](#method.may_send). - pub fn send<'b, F, R>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut [u8]) -> (usize, R) { - self.send_impl(|tx_buffer| { - tx_buffer.enqueue_many_with(f) - }) + pub fn send<'b, F, R>(&'b mut self, f: F) -> Result + where + F: FnOnce(&'b mut [u8]) -> (usize, R), + { + self.send_impl(|tx_buffer| tx_buffer.enqueue_many_with(f)) } /// Enqueue a sequence of octets to be sent, and fill it from a slice. @@ -702,29 +1082,31 @@ impl<'a> TcpSocket<'a> { /// by the amount of free space in the transmit buffer; down to zero. /// /// See also [send](#method.send). - pub fn send_slice(&mut self, data: &[u8]) -> Result { + pub fn send_slice(&mut self, data: &[u8]) -> Result { self.send_impl(|tx_buffer| { let size = tx_buffer.enqueue_slice(data); (size, size) }) } - fn recv_error_check(&mut self) -> Result<()> { + fn recv_error_check(&mut self) -> Result<(), RecvError> { // We may have received some data inside the initial SYN, but until the connection // is fully open we must not dequeue any data, as it may be overwritten by e.g. // another (stale) SYN. (We do not support TCP Fast Open.) if !self.may_recv() { if self.rx_fin_received { - return Err(Error::Finished) + return Err(RecvError::Finished); } - return Err(Error::Illegal) + return Err(RecvError::InvalidState); } Ok(()) } - fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R) { + fn recv_impl<'b, F, R>(&'b mut self, f: F) -> Result + where + F: FnOnce(&'b mut SocketBuffer<'a>) -> (usize, R), + { self.recv_error_check()?; let _old_length = self.rx_buffer.len(); @@ -732,9 +1114,11 @@ impl<'a> TcpSocket<'a> { self.remote_seq_no += size; if size > 0 { #[cfg(any(test, feature = "verbose"))] - net_trace!("{}:{}:{}: rx buffer: dequeueing {} octets (now {})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - size, _old_length - size); + tcp_trace!( + "rx buffer: dequeueing {} octets (now {})", + size, + _old_length - size + ); } Ok(result) } @@ -749,11 +1133,11 @@ impl<'a> TcpSocket<'a> { /// /// In all other cases, `Err(Error::Illegal)` is returned and previously received data (if any) /// may be incomplete (truncated). - pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut [u8]) -> (usize, R) { - self.recv_impl(|rx_buffer| { - rx_buffer.dequeue_many_with(f) - }) + pub fn recv<'b, F, R>(&'b mut self, f: F) -> Result + where + F: FnOnce(&'b mut [u8]) -> (usize, R), + { + self.recv_impl(|rx_buffer| rx_buffer.dequeue_many_with(f)) } /// Dequeue a sequence of received octets, and fill a slice from it. @@ -762,7 +1146,7 @@ impl<'a> TcpSocket<'a> { /// by the amount of occupied space in the receive buffer; down to zero. /// /// See also [recv](#method.recv). - pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result { self.recv_impl(|rx_buffer| { let size = rx_buffer.dequeue_slice(data); (size, size) @@ -773,15 +1157,13 @@ impl<'a> TcpSocket<'a> { /// the receive buffer, and return a pointer to it. /// /// This function otherwise behaves identically to [recv](#method.recv). - pub fn peek(&mut self, size: usize) -> Result<&[u8]> { + pub fn peek(&mut self, size: usize) -> Result<&[u8], RecvError> { self.recv_error_check()?; let buffer = self.rx_buffer.get_allocated(0, size); - if buffer.len() > 0 { + if !buffer.is_empty() { #[cfg(any(test, feature = "verbose"))] - net_trace!("{}:{}:{}: rx buffer: peeking at {} octets", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - buffer.len()); + tcp_trace!("rx buffer: peeking at {} octets", buffer.len()); } Ok(buffer) } @@ -790,7 +1172,7 @@ impl<'a> TcpSocket<'a> { /// the receive buffer, and fill a slice from it. /// /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). - pub fn peek_slice(&mut self, data: &mut [u8]) -> Result { + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result { let buffer = self.peek(data.len())?; let data = &mut data[..buffer.len()]; data.copy_from_slice(buffer); @@ -815,40 +1197,42 @@ impl<'a> TcpSocket<'a> { fn set_state(&mut self, state: State) { if self.state != state { - if self.remote_endpoint.addr.is_unspecified() { - net_trace!("{}:{}: state={}=>{}", - self.meta.handle, self.local_endpoint, - self.state, state); - } else { - net_trace!("{}:{}:{}: state={}=>{}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - self.state, state); - } + tcp_trace!("state={}=>{}", self.state, state); + } + + self.state = state; + + #[cfg(feature = "async")] + { + // Wake all tasks waiting. Even if we haven't received/sent data, this + // is needed because return values of functions may change depending on the state. + // For example, a pending read has to fail with an error if the socket is closed. + self.rx_waker.wake(); + self.tx_waker.wake(); } - self.state = state } pub(crate) fn reply(ip_repr: &IpRepr, repr: &TcpRepr) -> (IpRepr, TcpRepr<'static>) { let reply_repr = TcpRepr { - src_port: repr.dst_port, - dst_port: repr.src_port, - control: TcpControl::None, - seq_number: TcpSeqNumber(0), - ack_number: None, - window_len: 0, + src_port: repr.dst_port, + dst_port: repr.src_port, + control: TcpControl::None, + seq_number: TcpSeqNumber(0), + ack_number: None, + window_len: 0, window_scale: None, max_seg_size: None, sack_permitted: false, - sack_ranges: [None, None, None], - payload: &[] - }; - let ip_reply_repr = IpRepr::Unspecified { - src_addr: ip_repr.dst_addr(), - dst_addr: ip_repr.src_addr(), - protocol: IpProtocol::Tcp, - payload_len: reply_repr.buffer_len(), - hop_limit: 64 + sack_ranges: [None, None, None], + payload: &[], }; + let ip_reply_repr = IpRepr::new( + ip_repr.dst_addr(), + ip_repr.src_addr(), + IpProtocol::Tcp, + reply_repr.buffer_len(), + 64, + ); (ip_reply_repr, reply_repr) } @@ -861,7 +1245,7 @@ impl<'a> TcpSocket<'a> { // of why we sometimes send an RST and sometimes an RST|ACK reply_repr.control = TcpControl::Rst; reply_repr.seq_number = repr.ack_number.unwrap_or_default(); - if repr.control == TcpControl::Syn { + if repr.control == TcpControl::Syn && repr.ack_number.is_none() { reply_repr.ack_number = Some(repr.seq_number + repr.segment_len()); } @@ -897,11 +1281,11 @@ impl<'a> TcpSocket<'a> { reply_repr.sack_ranges[0] = None; if let Some(last_seg_seq) = self.local_rx_last_seq.map(|s| s.0 as u32) { - reply_repr.sack_ranges[0] = self.assembler.iter_data( - reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) + reply_repr.sack_ranges[0] = self + .assembler + .iter_data(reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) .map(|(left, right)| (left as u32, right as u32)) - .skip_while(|(left, right)| *left > last_seg_seq || *right < last_seg_seq) - .next(); + .find(|(left, right)| *left <= last_seg_seq && *right >= last_seg_seq); } if reply_repr.sack_ranges[0].is_none() { @@ -909,11 +1293,12 @@ impl<'a> TcpSocket<'a> { // number has advanced, or there was no previous sACK. // // While the RFC says we SHOULD keep a list of reported sACK ranges, and iterate - // through those, that is currently infeasable. Instead, we offer the range with + // through those, that is currently infeasible. Instead, we offer the range with // the lowest sequence number (if one exists) to hint at what segments would // most quickly advance the acknowledgement number. - reply_repr.sack_ranges[0] = self.assembler.iter_data( - reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) + reply_repr.sack_ranges[0] = self + .assembler + .iter_data(reply_repr.ack_number.map(|s| s.0 as usize).unwrap_or(0)) .map(|(left, right)| (left as u32, right as u32)) .next(); } @@ -924,31 +1309,57 @@ impl<'a> TcpSocket<'a> { (ip_reply_repr, reply_repr) } - pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { - if self.state == State::Closed { return false } + fn challenge_ack_reply( + &mut self, + cx: &mut Context, + ip_repr: &IpRepr, + repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + if cx.now() < self.challenge_ack_timer { + return None; + } + + // Rate-limit to 1 per second max. + self.challenge_ack_timer = cx.now() + Duration::from_secs(1); + + return Some(self.ack_reply(ip_repr, repr)); + } + + pub(crate) fn accepts(&self, _cx: &mut Context, ip_repr: &IpRepr, repr: &TcpRepr) -> bool { + if self.state == State::Closed { + return false; + } // If we're still listening for SYNs and the packet has an ACK, it cannot // be destined to this socket, but another one may well listen on the same // local endpoint. - if self.state == State::Listen && repr.ack_number.is_some() { return false } - - // Reject packets with a wrong destination. - if self.local_endpoint.port != repr.dst_port { return false } - if !self.local_endpoint.addr.is_unspecified() && - self.local_endpoint.addr != ip_repr.dst_addr() { return false } - - // Reject packets from a source to which we aren't connected. - if self.remote_endpoint.port != 0 && - self.remote_endpoint.port != repr.src_port { return false } - if !self.remote_endpoint.addr.is_unspecified() && - self.remote_endpoint.addr != ip_repr.src_addr() { return false } + if self.state == State::Listen && repr.ack_number.is_some() { + return false; + } - true + if let Some(tuple) = &self.tuple { + // Reject packets not matching the 4-tuple + ip_repr.dst_addr() == tuple.local.addr + && repr.dst_port == tuple.local.port + && ip_repr.src_addr() == tuple.remote.addr + && repr.src_port == tuple.remote.port + } else { + // We're listening, reject packets not matching the listen endpoint. + let addr_ok = match self.listen_endpoint.addr { + Some(addr) => ip_repr.dst_addr() == addr, + None => true, + }; + addr_ok && repr.dst_port != 0 && repr.dst_port == self.listen_endpoint.port + } } - pub(crate) fn process(&mut self, timestamp: Instant, ip_repr: &IpRepr, repr: &TcpRepr) -> - Result)>> { - debug_assert!(self.accepts(ip_repr, repr)); + pub(crate) fn process( + &mut self, + cx: &mut Context, + ip_repr: &IpRepr, + repr: &TcpRepr, + ) -> Option<(IpRepr, TcpRepr<'static>)> { + debug_assert!(self.accepts(cx, ip_repr, repr)); // Consider how much the sequence number space differs from the transmit buffer space. let (sent_syn, sent_fin) = match self.state { @@ -956,111 +1367,157 @@ impl<'a> TcpSocket<'a> { State::SynSent | State::SynReceived => (true, false), // In FIN-WAIT-1, LAST-ACK, or CLOSING, we've just sent a FIN. State::FinWait1 | State::LastAck | State::Closing => (false, true), - // In all other states we've already got acknowledgemetns for + // In all other states we've already got acknowledgements for // all of the control flags we sent. - _ => (false, false) + _ => (false, false), }; let control_len = (sent_syn as usize) + (sent_fin as usize); // Reject unacceptable acknowledgements. - match (self.state, repr) { + match (self.state, repr.control, repr.ack_number) { // An RST received in response to initial SYN is acceptable if it acknowledges // the initial SYN. - (State::SynSent, &TcpRepr { - control: TcpControl::Rst, ack_number: None, .. - }) => { - net_debug!("{}:{}:{}: unacceptable RST (expecting RST|ACK) \ - in response to initial SYN", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - return Err(Error::Dropped) - } - (State::SynSent, &TcpRepr { - control: TcpControl::Rst, ack_number: Some(ack_number), .. - }) => { + (State::SynSent, TcpControl::Rst, None) => { + net_debug!("unacceptable RST (expecting RST|ACK) in response to initial SYN"); + return None; + } + (State::SynSent, TcpControl::Rst, Some(ack_number)) => { if ack_number != self.local_seq_no + 1 { - net_debug!("{}:{}:{}: unacceptable RST|ACK in response to initial SYN", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - return Err(Error::Dropped) + net_debug!("unacceptable RST|ACK in response to initial SYN"); + return None; } } // Any other RST need only have a valid sequence number. - (_, &TcpRepr { control: TcpControl::Rst, .. }) => (), + (_, TcpControl::Rst, _) => (), // The initial SYN cannot contain an acknowledgement. - (State::Listen, &TcpRepr { ack_number: None, .. }) => (), - // This case is handled above. - (State::Listen, &TcpRepr { ack_number: Some(_), .. }) => unreachable!(), + (State::Listen, _, None) => (), + // This case is handled in `accepts()`. + (State::Listen, _, Some(_)) => unreachable!(), // Every packet after the initial SYN must be an acknowledgement. - (_, &TcpRepr { ack_number: None, .. }) => { - net_debug!("{}:{}:{}: expecting an ACK", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - return Err(Error::Dropped) + (_, _, None) => { + net_debug!("expecting an ACK"); + return None; + } + // SYN|ACK in the SYN-SENT state must have the exact ACK number. + (State::SynSent, TcpControl::Syn, Some(ack_number)) => { + if ack_number != self.local_seq_no + 1 { + net_debug!("unacceptable SYN|ACK in response to initial SYN"); + return Some(Self::rst_reply(ip_repr, repr)); + } + } + // ACKs in the SYN-SENT state are invalid. + (State::SynSent, TcpControl::None, Some(ack_number)) => { + // If the sequence number matches, ignore it instead of RSTing. + // I'm not sure why, I think it may be a workaround for broken TCP + // servers, or a defense against reordering. Either way, if Linux + // does it, we do too. + if ack_number == self.local_seq_no + 1 { + net_debug!( + "expecting a SYN|ACK, received an ACK with the right ack_number, ignoring." + ); + return None; + } + + net_debug!( + "expecting a SYN|ACK, received an ACK with the wrong ack_number, sending RST." + ); + return Some(Self::rst_reply(ip_repr, repr)); + } + // Anything else in the SYN-SENT state is invalid. + (State::SynSent, _, _) => { + net_debug!("expecting a SYN|ACK"); + return None; + } + // ACK in the SYN-RECEIVED state must have the exact ACK number, or we RST it. + (State::SynReceived, _, Some(ack_number)) => { + if ack_number != self.local_seq_no + 1 { + net_debug!("unacceptable ACK in response to SYN|ACK"); + return Some(Self::rst_reply(ip_repr, repr)); + } } // Every acknowledgement must be for transmitted but unacknowledged data. - (_, &TcpRepr { ack_number: Some(ack_number), .. }) => { + (_, _, Some(ack_number)) => { let unacknowledged = self.tx_buffer.len() + control_len; - if ack_number < self.local_seq_no { - net_debug!("{}:{}:{}: duplicate ACK ({} not in {}...{})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - ack_number, self.local_seq_no, self.local_seq_no + unacknowledged); - return Err(Error::Dropped) + + // Acceptable ACK range (both inclusive) + let mut ack_min = self.local_seq_no; + let ack_max = self.local_seq_no + unacknowledged; + + // If we have sent a SYN, it MUST be acknowledged. + if sent_syn { + ack_min += 1; + } + + if ack_number < ack_min { + net_debug!( + "duplicate ACK ({} not in {}...{})", + ack_number, + ack_min, + ack_max + ); + return None; } - if ack_number > self.local_seq_no + unacknowledged { - net_debug!("{}:{}:{}: unacceptable ACK ({} not in {}...{})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - ack_number, self.local_seq_no, self.local_seq_no + unacknowledged); - return Ok(Some(self.ack_reply(ip_repr, &repr))) + if ack_number > ack_max { + net_debug!( + "unacceptable ACK ({} not in {}...{})", + ack_number, + ack_min, + ack_max + ); + return self.challenge_ack_reply(cx, ip_repr, repr); } } } - let window_start = self.remote_seq_no + self.rx_buffer.len(); - let window_end = self.remote_seq_no + self.rx_buffer.capacity(); + let window_start = self.remote_seq_no + self.rx_buffer.len(); + let window_end = self.remote_seq_no + self.rx_buffer.capacity(); let segment_start = repr.seq_number; - let segment_end = repr.seq_number + repr.segment_len(); + let segment_end = repr.seq_number + repr.segment_len(); let payload_offset; match self.state { // In LISTEN and SYN-SENT states, we have not yet synchronized with the remote end. - State::Listen | State::SynSent => - payload_offset = 0, + State::Listen | State::SynSent => payload_offset = 0, // In all other states, segments must occupy a valid portion of the receive window. _ => { let mut segment_in_window = true; if window_start == window_end && segment_start != segment_end { - net_debug!("{}:{}:{}: non-zero-length segment with zero receive window, \ - will only send an ACK", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + net_debug!( + "non-zero-length segment with zero receive window, will only send an ACK" + ); segment_in_window = false; } if segment_start == segment_end && segment_end == window_start - 1 { - net_debug!("{}:{}:{}: received a keep-alive or window probe packet, \ - will send an ACK", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + net_debug!("received a keep-alive or window probe packet, will send an ACK"); segment_in_window = false; - } else if !((window_start <= segment_start && segment_start <= window_end) && - (window_start <= segment_end && segment_end <= window_end)) { - net_debug!("{}:{}:{}: segment not in receive window \ - ({}..{} not intersecting {}..{}), will send challenge ACK", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - segment_start, segment_end, window_start, window_end); + } else if !((window_start <= segment_start && segment_start <= window_end) + && (window_start <= segment_end && segment_end <= window_end)) + { + net_debug!( + "segment not in receive window ({}..{} not intersecting {}..{}), will send challenge ACK", + segment_start, + segment_end, + window_start, + window_end + ); segment_in_window = false; } if segment_in_window { // We've checked that segment_start >= window_start above. - payload_offset = (segment_start - window_start) as usize; + payload_offset = segment_start - window_start; self.local_rx_last_seq = Some(repr.seq_number); } else { // If we're in the TIME-WAIT state, restart the TIME-WAIT timeout, since // the remote end may not have realized we've closed the connection. if self.state == State::TimeWait { - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now()); } - return Ok(Some(self.ack_reply(ip_repr, &repr))) + return self.challenge_ack_reply(cx, ip_repr, repr); } } } @@ -1069,22 +1526,29 @@ impl<'a> TcpSocket<'a> { // from the sequence space. let mut ack_len = 0; let mut ack_of_fin = false; + let mut ack_all = false; if repr.control != TcpControl::Rst { if let Some(ack_number) = repr.ack_number { - ack_len = ack_number - self.local_seq_no; - // There could have been no data sent before the SYN, so we always remove it - // from the sequence space. - if sent_syn { - ack_len -= 1 - } - // We could've sent data before the FIN, so only remove FIN from the sequence - // space if all of that data is acknowledged. - if sent_fin && self.tx_buffer.len() + 1 == ack_len { - ack_len -= 1; - net_trace!("{}:{}:{}: received ACK of FIN", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - ack_of_fin = true; + // Sequence number corresponding to the first byte in `tx_buffer`. + // This normally equals `local_seq_no`, but is 1 higher if we have sent a SYN, + // as the SYN occupies 1 sequence number "before" the data. + let tx_buffer_start_seq = self.local_seq_no + (sent_syn as usize); + + if ack_number >= tx_buffer_start_seq { + ack_len = ack_number - tx_buffer_start_seq; + + // We could've sent data before the FIN, so only remove FIN from the sequence + // space if all of that data is acknowledged. + if sent_fin && self.tx_buffer.len() + 1 == ack_len { + ack_len -= 1; + tcp_trace!("received ACK of FIN"); + ack_of_fin = true; + } + + ack_all = self.remote_last_seq == ack_number } + + self.rtte.on_ack(cx.now(), ack_number); } } @@ -1101,97 +1565,106 @@ impl<'a> TcpSocket<'a> { // Validate and update the state. match (self.state, control) { // RSTs are not accepted in the LISTEN state. - (State::Listen, TcpControl::Rst) => - return Err(Error::Dropped), + (State::Listen, TcpControl::Rst) => return None, // RSTs in SYN-RECEIVED flip the socket back to the LISTEN state. (State::SynReceived, TcpControl::Rst) => { - net_trace!("{}:{}:{}: received RST", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - self.local_endpoint.addr = self.listen_address; - self.remote_endpoint = IpEndpoint::default(); + tcp_trace!("received RST"); + self.tuple = None; self.set_state(State::Listen); - return Ok(None) + return None; } // RSTs in any other state close the socket. (_, TcpControl::Rst) => { - net_trace!("{}:{}:{}: received RST", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + tcp_trace!("received RST"); self.set_state(State::Closed); - self.local_endpoint = IpEndpoint::default(); - self.remote_endpoint = IpEndpoint::default(); - return Ok(None) + self.tuple = None; + return None; } // SYN packets in the LISTEN state change it to SYN-RECEIVED. (State::Listen, TcpControl::Syn) => { - net_trace!("{}:{}: received SYN", - self.meta.handle, self.local_endpoint); - self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); - self.remote_endpoint = IpEndpoint::new(ip_repr.src_addr(), repr.src_port); - // FIXME: use something more secure here - self.local_seq_no = TcpSeqNumber(-repr.seq_number.0); - self.remote_seq_no = repr.seq_number + 1; - self.remote_last_seq = self.local_seq_no; - self.remote_has_sack = repr.sack_permitted; + tcp_trace!("received SYN"); if let Some(max_seg_size) = repr.max_seg_size { + if max_seg_size == 0 { + tcp_trace!("received SYNACK with zero MSS, ignoring"); + return None; + } self.remote_mss = max_seg_size as usize } + + self.tuple = Some(Tuple { + local: IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port), + remote: IpEndpoint::new(ip_repr.src_addr(), repr.src_port), + }); + self.local_seq_no = Self::random_seq_no(cx); + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no; + self.remote_has_sack = repr.sack_permitted; self.remote_win_scale = repr.window_scale; - // No window scaling means don't do any window shifting + // Remote doesn't support window scaling, don't do it. if self.remote_win_scale.is_none() { self.remote_win_shift = 0; } self.set_state(State::SynReceived); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // ACK packets in the SYN-RECEIVED state change it to ESTABLISHED. (State::SynReceived, TcpControl::None) => { self.set_state(State::Established); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // FIN packets in the SYN-RECEIVED state change it to CLOSE-WAIT. // It's not obvious from RFC 793 that this is permitted, but // 7th and 8th steps in the "SEGMENT ARRIVES" event describe this behavior. (State::SynReceived, TcpControl::Fin) => { - self.remote_seq_no += 1; + self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::CloseWait); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // SYN|ACK packets in the SYN-SENT state change it to ESTABLISHED. (State::SynSent, TcpControl::Syn) => { - net_trace!("{}:{}:{}: received SYN|ACK", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - self.local_endpoint = IpEndpoint::new(ip_repr.dst_addr(), repr.dst_port); - self.remote_seq_no = repr.seq_number + 1; - self.remote_last_seq = self.local_seq_no + 1; - self.remote_last_ack = Some(repr.seq_number); + tcp_trace!("received SYN|ACK"); if let Some(max_seg_size) = repr.max_seg_size { + if max_seg_size == 0 { + tcp_trace!("received SYNACK with zero MSS, ignoring"); + return None; + } self.remote_mss = max_seg_size as usize; } + + self.remote_seq_no = repr.seq_number + 1; + self.remote_last_seq = self.local_seq_no + 1; + self.remote_last_ack = Some(repr.seq_number); + self.remote_win_scale = repr.window_scale; + // Remote doesn't support window scaling, don't do it. + if self.remote_win_scale.is_none() { + self.remote_win_shift = 0; + } + self.set_state(State::Established); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // ACK packets in ESTABLISHED state reset the retransmit timer, // except for duplicate ACK packets which preserve it. (State::Established, TcpControl::None) => { - if !self.timer.is_retransmit() || ack_len != 0 { - self.timer.set_for_idle(timestamp, self.keep_alive); + if !self.timer.is_retransmit() || ack_all { + self.timer.set_for_idle(cx.now(), self.keep_alive); } - }, + } // FIN packets in ESTABLISHED state indicate the remote side has closed. (State::Established, TcpControl::Fin) => { - self.remote_seq_no += 1; + self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::CloseWait); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // ACK packets in FIN-WAIT-1 state change it to FIN-WAIT-2, if we've already @@ -1200,80 +1673,96 @@ impl<'a> TcpSocket<'a> { if ack_of_fin { self.set_state(State::FinWait2); } - self.timer.set_for_idle(timestamp, self.keep_alive); + if ack_all { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } } // FIN packets in FIN-WAIT-1 state change it to CLOSING, or to TIME-WAIT // if they also acknowledge our FIN. (State::FinWait1, TcpControl::Fin) => { - self.remote_seq_no += 1; + self.remote_seq_no += 1; self.rx_fin_received = true; if ack_of_fin { self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now()); } else { self.set_state(State::Closing); - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } } // Data packets in FIN-WAIT-2 reset the idle timer. (State::FinWait2, TcpControl::None) => { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // FIN packets in FIN-WAIT-2 state change it to TIME-WAIT. (State::FinWait2, TcpControl::Fin) => { - self.remote_seq_no += 1; + self.remote_seq_no += 1; self.rx_fin_received = true; self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now()); } // ACK packets in CLOSING state change it to TIME-WAIT. (State::Closing, TcpControl::None) => { if ack_of_fin { self.set_state(State::TimeWait); - self.timer.set_for_close(timestamp); + self.timer.set_for_close(cx.now()); } else { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } } // ACK packets in CLOSE-WAIT state reset the retransmit timer. (State::CloseWait, TcpControl::None) => { - self.timer.set_for_idle(timestamp, self.keep_alive); + self.timer.set_for_idle(cx.now(), self.keep_alive); } // ACK packets in LAST-ACK state change it to CLOSED. (State::LastAck, TcpControl::None) => { - // Clear the remote endpoint, or we'll send an RST there. - self.set_state(State::Closed); - self.local_endpoint = IpEndpoint::default(); - self.remote_endpoint = IpEndpoint::default(); + if ack_of_fin { + // Clear the remote endpoint, or we'll send an RST there. + self.set_state(State::Closed); + self.tuple = None; + } else { + self.timer.set_for_idle(cx.now(), self.keep_alive); + } } _ => { - net_debug!("{}:{}:{}: unexpected packet {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, repr); - return Err(Error::Dropped) + net_debug!("unexpected packet {}", repr); + return None; } } // Update remote state. - self.remote_last_ts = Some(timestamp); + self.remote_last_ts = Some(cx.now()); // RFC 1323: The window field (SEG.WND) in the header of every incoming segment, with the // exception of SYN segments, is left-shifted by Snd.Wind.Scale bits before updating SND.WND. - self.remote_win_len = (repr.window_len as usize) << (self.remote_win_scale.unwrap_or(0) as usize); + let scale = match repr.control { + TcpControl::Syn => 0, + _ => self.remote_win_scale.unwrap_or(0), + }; + let new_remote_win_len = (repr.window_len as usize) << (scale as usize); + let is_window_update = new_remote_win_len != self.remote_win_len; + self.remote_win_len = new_remote_win_len; if ack_len > 0 { // Dequeue acknowledged octets. debug_assert!(self.tx_buffer.len() >= ack_len); - net_trace!("{}:{}:{}: tx buffer: dequeueing {} octets (now {})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - ack_len, self.tx_buffer.len() - ack_len); + tcp_trace!( + "tx buffer: dequeueing {} octets (now {})", + ack_len, + self.tx_buffer.len() - ack_len + ); self.tx_buffer.dequeue_allocated(ack_len); + + // There's new room available in tx_buffer, wake the waiting task if any. + #[cfg(feature = "async")] + self.tx_waker.wake(); } if let Some(ack_number) = repr.ack_number { @@ -1282,35 +1771,42 @@ impl<'a> TcpSocket<'a> { // Detect and react to duplicate ACKs by: // 1. Check if duplicate ACK and change self.local_rx_dup_acks accordingly - // 2. If exactly 3 duplicate ACKs recived, set for fast retransmit + // 2. If exactly 3 duplicate ACKs received, set for fast retransmit // 3. Update the last received ACK (self.local_rx_last_ack) match self.local_rx_last_ack { // Duplicate ACK if payload empty and ACK doesn't move send window -> - // Increment duplicate ACK count and set for retransmit if we just recived + // Increment duplicate ACK count and set for retransmit if we just received // the third duplicate ACK - Some(ref last_rx_ack) if - repr.payload.len() == 0 && - *last_rx_ack == ack_number && - ack_number < self.remote_last_seq => { + Some(last_rx_ack) + if repr.payload.is_empty() + && last_rx_ack == ack_number + && ack_number < self.remote_last_seq + && !is_window_update => + { // Increment duplicate ACK count self.local_rx_dup_acks = self.local_rx_dup_acks.saturating_add(1); - net_debug!("{}:{}:{}: received duplicate ACK for seq {} (duplicate nr {}{})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, ack_number, - self.local_rx_dup_acks, if self.local_rx_dup_acks == u8::max_value() { "+" } else { "" }); + net_debug!( + "received duplicate ACK for seq {} (duplicate nr {}{})", + ack_number, + self.local_rx_dup_acks, + if self.local_rx_dup_acks == u8::max_value() { + "+" + } else { + "" + } + ); if self.local_rx_dup_acks == 3 { self.timer.set_for_fast_retransmit(); - net_debug!("{}:{}:{}: started fast retransmit", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + net_debug!("started fast retransmit"); } - }, - // No duplicate ACK -> Reset state and update last recived ACK + } + // No duplicate ACK -> Reset state and update last received ACK _ => { if self.local_rx_dup_acks > 0 { self.local_rx_dup_acks = 0; - net_debug!("{}:{}:{}: reset duplicate ACK count", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + net_debug!("reset duplicate ACK count"); } self.local_rx_last_ack = Some(ack_number); } @@ -1329,42 +1825,74 @@ impl<'a> TcpSocket<'a> { } let payload_len = repr.payload.len(); - if payload_len == 0 { return Ok(None) } + if payload_len == 0 { + return None; + } let assembler_was_empty = self.assembler.is_empty(); // Try adding payload octets to the assembler. - match self.assembler.add(payload_offset, payload_len) { - Ok(()) => { - debug_assert!(self.assembler.total_size() == self.rx_buffer.capacity()); - // Place payload octets into the buffer. - net_trace!("{}:{}:{}: rx buffer: receiving {} octets at offset {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - payload_len, payload_offset); - self.rx_buffer.write_unallocated(payload_offset, repr.payload); - } - Err(()) => { - net_debug!("{}:{}:{}: assembler: too many holes to add {} octets at offset {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - payload_len, payload_offset); - return Err(Error::Dropped) - } - } + let Ok(contig_len) = self.assembler.add_then_remove_front(payload_offset, payload_len) else { + net_debug!( + "assembler: too many holes to add {} octets at offset {}", + payload_len, + payload_offset + ); + return None; + }; - if let Some(contig_len) = self.assembler.remove_front() { - debug_assert!(self.assembler.total_size() == self.rx_buffer.capacity()); + // Place payload octets into the buffer. + tcp_trace!( + "rx buffer: receiving {} octets at offset {}", + payload_len, + payload_offset + ); + let len_written = self + .rx_buffer + .write_unallocated(payload_offset, repr.payload); + debug_assert!(len_written == payload_len); + + if contig_len != 0 { // Enqueue the contiguous data octets in front of the buffer. - net_trace!("{}:{}:{}: rx buffer: enqueueing {} octets (now {})", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - contig_len, self.rx_buffer.len() + contig_len); + tcp_trace!( + "rx buffer: enqueueing {} octets (now {})", + contig_len, + self.rx_buffer.len() + contig_len + ); self.rx_buffer.enqueue_unallocated(contig_len); + + // There's new data in rx_buffer, notify waiting task if any. + #[cfg(feature = "async")] + self.rx_waker.wake(); } if !self.assembler.is_empty() { // Print the ranges recorded in the assembler. - net_trace!("{}:{}:{}: assembler: {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - self.assembler); + tcp_trace!("assembler: {}", self.assembler); + } + + // Handle delayed acks + if let Some(ack_delay) = self.ack_delay { + if self.ack_to_transmit() || self.window_to_update() { + self.ack_delay_timer = match self.ack_delay_timer { + AckDelayTimer::Idle => { + tcp_trace!("starting delayed ack timer"); + + AckDelayTimer::Waiting(cx.now() + ack_delay) + } + // RFC1122 says "in a stream of full-sized segments there SHOULD be an ACK + // for at least every second segment". + // For now, we send an ACK every second received packet, full-sized or not. + AckDelayTimer::Waiting(_) => { + tcp_trace!("delayed ack timer already started, forcing expiry"); + AckDelayTimer::Immediate + } + AckDelayTimer::Immediate => { + tcp_trace!("delayed ack timer already force-expired"); + AckDelayTimer::Immediate + } + }; + } } // Per RFC 5681, we should send an immediate ACK when either: @@ -1374,37 +1902,89 @@ impl<'a> TcpSocket<'a> { // Note that we change the transmitter state here. // This is fine because smoltcp assumes that it can always transmit zero or one // packets for every packet it receives. - net_trace!("{}:{}:{}: ACKing incoming segment", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - Ok(Some(self.ack_reply(ip_repr, &repr))) + tcp_trace!("ACKing incoming segment"); + Some(self.ack_reply(ip_repr, repr)) } else { - Ok(None) + None } } fn timed_out(&self, timestamp: Instant) -> bool { match (self.remote_last_ts, self.timeout) { - (Some(remote_last_ts), Some(timeout)) => - timestamp >= remote_last_ts + timeout, - (_, _) => - false + (Some(remote_last_ts), Some(timeout)) => timestamp >= remote_last_ts + timeout, + (_, _) => false, } } - fn seq_to_transmit(&self) -> bool { - let control; - match self.state { - State::SynSent | State::SynReceived => - control = TcpControl::Syn, - State::FinWait1 | State::LastAck => - control = TcpControl::Fin, - _ => control = TcpControl::None + fn seq_to_transmit(&self, cx: &mut Context) -> bool { + let ip_header_len = match self.tuple.unwrap().local.addr { + #[cfg(feature = "proto-ipv4")] + IpAddress::Ipv4(_) => crate::wire::IPV4_HEADER_LEN, + #[cfg(feature = "proto-ipv6")] + IpAddress::Ipv6(_) => crate::wire::IPV6_HEADER_LEN, + }; + + // Max segment size we're able to send due to MTU limitations. + let local_mss = cx.ip_mtu() - ip_header_len - TCP_HEADER_LEN; + + // The effective max segment size, taking into account our and remote's limits. + let effective_mss = local_mss.min(self.remote_mss); + + // Have we sent data that hasn't been ACKed yet? + let data_in_flight = self.remote_last_seq != self.local_seq_no; + + // If we want to send a SYN and we haven't done so, do it! + if matches!(self.state, State::SynSent | State::SynReceived) && !data_in_flight { + return true; } - if self.remote_win_len > 0 { - self.remote_last_seq < self.local_seq_no + self.tx_buffer.len() + control.len() + // max sequence number we can send. + let max_send_seq = + self.local_seq_no + core::cmp::min(self.remote_win_len, self.tx_buffer.len()); + + // Max amount of octets we can send. + let max_send = if max_send_seq >= self.remote_last_seq { + max_send_seq - self.remote_last_seq } else { - false + 0 + }; + + // Can we send at least 1 octet? + let mut can_send = max_send != 0; + // Can we send at least 1 full segment? + let can_send_full = max_send >= effective_mss; + + // Do we have to send a FIN? + let want_fin = match self.state { + State::FinWait1 => true, + State::Closing => true, + State::LastAck => true, + _ => false, + }; + + // If we're applying the Nagle algorithm we don't want to send more + // until one of: + // * There's no data in flight + // * We can send a full packet + // * We have all the data we'll ever send (we're closing send) + if self.nagle && data_in_flight && !can_send_full && !want_fin { + can_send = false; + } + + // Can we actually send the FIN? We can send it if: + // 1. We have unsent data that fits in the remote window. + // 2. We have no unsent data. + // This condition matches only if #2, because #1 is already covered by can_data and we're ORing them. + let can_fin = want_fin && self.remote_last_seq == self.local_seq_no + self.tx_buffer.len(); + + can_send || can_fin + } + + fn delayed_ack_expired(&self, timestamp: Instant) -> bool { + match self.ack_delay_timer { + AckDelayTimer::Idle => true, + AckDelayTimer::Waiting(t) => t <= timestamp, + AckDelayTimer::Immediate => true, } } @@ -1418,16 +1998,22 @@ impl<'a> TcpSocket<'a> { fn window_to_update(&self) -> bool { match self.state { - State::SynSent | State::SynReceived | State::Established | State::FinWait1 | State::FinWait2 => - (self.rx_buffer.window() >> self.remote_win_shift) as u16 > self.remote_last_win, + State::SynSent + | State::SynReceived + | State::Established + | State::FinWait1 + | State::FinWait2 => self.scaled_window() > self.remote_last_win, _ => false, } } - pub(crate) fn dispatch(&mut self, timestamp: Instant, caps: &DeviceCapabilities, - emit: F) -> Result<()> - where F: FnOnce((IpRepr, TcpRepr)) -> Result<()> { - if !self.remote_endpoint.is_specified() { return Err(Error::Exhausted) } + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, (IpRepr, TcpRepr)) -> Result<(), E>, + { + if self.tuple.is_none() { + return Ok(()); + } if self.remote_last_ts.is_none() { // We get here in exactly two cases: @@ -1437,84 +2023,87 @@ impl<'a> TcpSocket<'a> { // period of time, it isn't anymore, and the local endpoint is talking. // So, we start counting the timeout not from the last received packet // but from the first transmitted one. - self.remote_last_ts = Some(timestamp); + self.remote_last_ts = Some(cx.now()); } // Check if any state needs to be changed because of a timer. - if self.timed_out(timestamp) { + if self.timed_out(cx.now()) { // If a timeout expires, we should abort the connection. - net_debug!("{}:{}:{}: timeout exceeded", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + net_debug!("timeout exceeded"); self.set_state(State::Closed); - } else if !self.seq_to_transmit() { - if let Some(retransmit_delta) = self.timer.should_retransmit(timestamp) { + } else if !self.seq_to_transmit(cx) { + if let Some(retransmit_delta) = self.timer.should_retransmit(cx.now()) { // If a retransmit timer expired, we should resend data starting at the last ACK. - net_debug!("{}:{}:{}: retransmitting at t+{}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - retransmit_delta); + net_debug!("retransmitting at t+{}", retransmit_delta); + + // Rewind "last sequence number sent", as if we never + // had sent them. This will cause all data in the queue + // to be sent again. self.remote_last_seq = self.local_seq_no; + + // Clear the `should_retransmit` state. If we can't retransmit right + // now for whatever reason (like zero window), this avoids an + // infinite polling loop where `poll_at` returns `Now` but `dispatch` + // can't actually do anything. + self.timer.set_for_idle(cx.now(), self.keep_alive); + + // Inform RTTE, so that it can avoid bogus measurements. + self.rtte.on_retransmit(); } } // Decide whether we're sending a packet. - if self.seq_to_transmit() { + if self.seq_to_transmit(cx) { // If we have data to transmit and it fits into partner's window, do it. - net_trace!("{}:{}:{}: outgoing segment will send data or flags", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.ack_to_transmit() { + tcp_trace!("outgoing segment will send data or flags"); + } else if self.ack_to_transmit() && self.delayed_ack_expired(cx.now()) { // If we have data to acknowledge, do it. - net_trace!("{}:{}:{}: outgoing segment will acknowledge", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.window_to_update() { + tcp_trace!("outgoing segment will acknowledge"); + } else if self.window_to_update() && self.delayed_ack_expired(cx.now()) { // If we have window length increase to advertise, do it. - net_trace!("{}:{}:{}: outgoing segment will update window", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + tcp_trace!("outgoing segment will update window"); } else if self.state == State::Closed { // If we need to abort the connection, do it. - net_trace!("{}:{}:{}: outgoing segment will abort connection", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_retransmit(timestamp).is_some() { - // If we have packets to retransmit, do it. - net_trace!("{}:{}:{}: retransmit timer expired", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_keep_alive(timestamp) { + tcp_trace!("outgoing segment will abort connection"); + } else if self.timer.should_keep_alive(cx.now()) { // If we need to transmit a keep-alive packet, do it. - net_trace!("{}:{}:{}: keep-alive timer expired", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if self.timer.should_close(timestamp) { + tcp_trace!("keep-alive timer expired"); + } else if self.timer.should_close(cx.now()) { // If we have spent enough time in the TIME-WAIT state, close the socket. - net_trace!("{}:{}:{}: TIME-WAIT timer expired", - self.meta.handle, self.local_endpoint, self.remote_endpoint); + tcp_trace!("TIME-WAIT timer expired"); self.reset(); - return Err(Error::Exhausted) + return Ok(()); } else { - return Err(Error::Exhausted) + return Ok(()); } + // NOTE(unwrap): we check tuple is not None the first thing in this function. + let tuple = self.tuple.unwrap(); + // Construct the lowered IP representation. // We might need this to calculate the MSS, so do it early. - let mut ip_repr = IpRepr::Unspecified { - src_addr: self.local_endpoint.addr, - dst_addr: self.remote_endpoint.addr, - protocol: IpProtocol::Tcp, - hop_limit: self.hop_limit.unwrap_or(64), - payload_len: 0 - }.lower(&[])?; + let mut ip_repr = IpRepr::new( + tuple.local.addr, + tuple.remote.addr, + IpProtocol::Tcp, + 0, + self.hop_limit.unwrap_or(64), + ); // Construct the basic TCP representation, an empty ACK packet. // We'll adjust this to be more specific as needed. let mut repr = TcpRepr { - src_port: self.local_endpoint.port, - dst_port: self.remote_endpoint.port, - control: TcpControl::None, - seq_number: self.remote_last_seq, - ack_number: Some(self.remote_seq_no + self.rx_buffer.len()), - window_len: self.scaled_window(), + src_port: tuple.local.port, + dst_port: tuple.remote.port, + control: TcpControl::None, + seq_number: self.remote_last_seq, + ack_number: Some(self.remote_seq_no + self.rx_buffer.len()), + window_len: self.scaled_window(), window_scale: None, max_seg_size: None, sack_permitted: false, - sack_ranges: [None, None, None], - payload: &[] + sack_ranges: [None, None, None], + payload: &[], }; match self.state { @@ -1525,51 +2114,77 @@ impl<'a> TcpSocket<'a> { } // We never transmit anything in the LISTEN state. - State::Listen => return Err(Error::Exhausted), + State::Listen => return Ok(()), // We transmit a SYN in the SYN-SENT state. // We transmit a SYN|ACK in the SYN-RECEIVED state. State::SynSent | State::SynReceived => { repr.control = TcpControl::Syn; + // window len must NOT be scaled in SYNs. + repr.window_len = self.rx_buffer.window().min((1 << 16) - 1) as u16; if self.state == State::SynSent { repr.ack_number = None; repr.window_scale = Some(self.remote_win_shift); repr.sack_permitted = true; } else { repr.sack_permitted = self.remote_has_sack; - repr.window_scale = self.remote_win_scale.map( - |_| self.remote_win_shift); + repr.window_scale = self.remote_win_scale.map(|_| self.remote_win_shift); } } // We transmit data in all states where we may have data in the buffer, - // or the transmit half of the connection is still open: - // the ESTABLISHED, FIN-WAIT-1, CLOSE-WAIT and LAST-ACK states. - State::Established | State::FinWait1 | State::CloseWait | State::LastAck => { + // or the transmit half of the connection is still open. + State::Established + | State::FinWait1 + | State::Closing + | State::CloseWait + | State::LastAck => { // Extract as much data as the remote side can receive in this packet // from the transmit buffer. + + // Right edge of window, ie the max sequence number we're allowed to send. + let win_right_edge = self.local_seq_no + self.remote_win_len; + + // Max amount of octets we're allowed to send according to the remote window. + let win_limit = if win_right_edge >= self.remote_last_seq { + win_right_edge - self.remote_last_seq + } else { + // This can happen if we've sent some data and later the remote side + // has shrunk its window so that data is no longer inside the window. + // This should be very rare and is strongly discouraged by the RFCs, + // but it does happen in practice. + // http://www.tcpipguide.com/free/t_TCPWindowManagementIssues.htm + 0 + }; + + // Maximum size we're allowed to send. This can be limited by 3 factors: + // 1. remote window + // 2. MSS the remote is willing to accept, probably determined by their MTU + // 3. MSS we can send, determined by our MTU. + let size = win_limit + .min(self.remote_mss) + .min(cx.ip_mtu() - ip_repr.header_len() - TCP_HEADER_LEN); + let offset = self.remote_last_seq - self.local_seq_no; - let size = cmp::min(self.remote_win_len, self.remote_mss); repr.payload = self.tx_buffer.get_allocated(offset, size); + // If we've sent everything we had in the buffer, follow it with the PSH or FIN // flags, depending on whether the transmit half of the connection is open. if offset + repr.payload.len() == self.tx_buffer.len() { match self.state { - State::FinWait1 | State::LastAck => - repr.control = TcpControl::Fin, - State::Established | State::CloseWait if repr.payload.len() > 0 => - repr.control = TcpControl::Psh, - _ => () + State::FinWait1 | State::LastAck | State::Closing => { + repr.control = TcpControl::Fin + } + State::Established | State::CloseWait if !repr.payload.is_empty() => { + repr.control = TcpControl::Psh + } + _ => (), } } } - // We do not transmit anything in the FIN-WAIT-2 state. - State::FinWait2 => return Err(Error::Exhausted), - - // We do not transmit data or control flags in the CLOSING or TIME-WAIT states, - // but we may retransmit an ACK. - State::Closing | State::TimeWait => () + // In FIN-WAIT-2 and TIME-WAIT states we may only transmit ACKs for incoming data or FIN + State::FinWait2 | State::TimeWait => {} } // There might be more than one reason to send a packet. E.g. the keep-alive timer @@ -1577,9 +2192,9 @@ impl<'a> TcpSocket<'a> { // sequence space will elicit an ACK, we only need to send an explicit packet if we // couldn't fill the sequence space with anything. let is_keep_alive; - if self.timer.should_keep_alive(timestamp) && repr.is_empty() { + if self.timer.should_keep_alive(cx.now()) && repr.is_empty() { repr.seq_number = repr.seq_number - 1; - repr.payload = b"\x00"; // RFC 1122 says we should do this + repr.payload = b"\x00"; // RFC 1122 says we should do this is_keep_alive = true; } else { is_keep_alive = false; @@ -1587,34 +2202,30 @@ impl<'a> TcpSocket<'a> { // Trace a summary of what will be sent. if is_keep_alive { - net_trace!("{}:{}:{}: sending a keep-alive", - self.meta.handle, self.local_endpoint, self.remote_endpoint); - } else if repr.payload.len() > 0 { - net_trace!("{}:{}:{}: tx buffer: sending {} octets at offset {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - repr.payload.len(), self.remote_last_seq - self.local_seq_no); + tcp_trace!("sending a keep-alive"); + } else if !repr.payload.is_empty() { + tcp_trace!( + "tx buffer: sending {} octets at offset {}", + repr.payload.len(), + self.remote_last_seq - self.local_seq_no + ); } - if repr.control != TcpControl::None || repr.payload.len() == 0 { - let flags = - match (repr.control, repr.ack_number) { - (TcpControl::Syn, None) => "SYN", - (TcpControl::Syn, Some(_)) => "SYN|ACK", - (TcpControl::Fin, Some(_)) => "FIN|ACK", - (TcpControl::Rst, Some(_)) => "RST|ACK", - (TcpControl::Psh, Some(_)) => "PSH|ACK", - (TcpControl::None, Some(_)) => "ACK", - _ => "" - }; - net_trace!("{}:{}:{}: sending {}", - self.meta.handle, self.local_endpoint, self.remote_endpoint, - flags); + if repr.control != TcpControl::None || repr.payload.is_empty() { + let flags = match (repr.control, repr.ack_number) { + (TcpControl::Syn, None) => "SYN", + (TcpControl::Syn, Some(_)) => "SYN|ACK", + (TcpControl::Fin, Some(_)) => "FIN|ACK", + (TcpControl::Rst, Some(_)) => "RST|ACK", + (TcpControl::Psh, Some(_)) => "PSH|ACK", + (TcpControl::None, Some(_)) => "ACK", + _ => "", + }; + tcp_trace!("sending {}", flags); } if repr.control == TcpControl::Syn { // Fill the MSS option. See RFC 6691 for an explanation of this calculation. - let mut max_segment_size = caps.max_transmission_unit; - max_segment_size -= ip_repr.buffer_len(); - max_segment_size -= repr.mss_header_len(); + let max_segment_size = cx.ip_mtu() - ip_repr.header_len() - TCP_HEADER_LEN; repr.max_seg_size = Some(max_segment_size as u16); } @@ -1626,39 +2237,64 @@ impl<'a> TcpSocket<'a> { // to not waste time waiting for the retransmit timer on packets that we know // for sure will not be successfully transmitted. ip_repr.set_payload_len(repr.buffer_len()); - emit((ip_repr, repr))?; + emit(cx, (ip_repr, repr))?; // We've sent something, whether useful data or a keep-alive packet, so rewind // the keep-alive timer. - self.timer.rewind_keep_alive(timestamp, self.keep_alive); + self.timer.rewind_keep_alive(cx.now(), self.keep_alive); + + // Reset delayed-ack timer + match self.ack_delay_timer { + AckDelayTimer::Idle => {} + AckDelayTimer::Waiting(_) => { + tcp_trace!("stop delayed ack timer") + } + AckDelayTimer::Immediate => { + tcp_trace!("stop delayed ack timer (was force-expired)") + } + } + self.ack_delay_timer = AckDelayTimer::Idle; // Leave the rest of the state intact if sending a keep-alive packet, since those // carry a fake segment. - if is_keep_alive { return Ok(()) } + if is_keep_alive { + return Ok(()); + } // We've sent a packet successfully, so we can update the internal state now. self.remote_last_seq = repr.seq_number + repr.segment_len(); self.remote_last_ack = repr.ack_number; self.remote_last_win = repr.window_len; - if !self.seq_to_transmit() && repr.segment_len() > 0 { + if repr.segment_len() > 0 { + self.rtte + .on_send(cx.now(), repr.seq_number + repr.segment_len()); + } + + if !self.seq_to_transmit(cx) && repr.segment_len() > 0 { // If we've transmitted all data we could (and there was something at all, // data or flag, to transmit, not just an ACK), wind up the retransmit timer. - self.timer.set_for_retransmit(timestamp); + self.timer + .set_for_retransmit(cx.now(), self.rtte.retransmission_timeout()); } if self.state == State::Closed { // When aborting a connection, forget about it after sending a single RST packet. - self.local_endpoint = IpEndpoint::default(); - self.remote_endpoint = IpEndpoint::default(); + self.tuple = None; + #[cfg(feature = "async")] + { + // Wake tx now so that async users can wait for the RST to be sent + self.tx_waker.wake(); + } } Ok(()) } - pub(crate) fn poll_at(&self) -> PollAt { + #[allow(clippy::if_same_then_else)] + pub(crate) fn poll_at(&self, cx: &mut Context) -> PollAt { // The logic here mirrors the beginning of dispatch() closely. - if !self.remote_endpoint.is_specified() { + if self.tuple.is_none() { // No one to talk to, nothing to transmit. PollAt::Ingress } else if self.remote_last_ts.is_none() { @@ -1667,10 +2303,19 @@ impl<'a> TcpSocket<'a> { } else if self.state == State::Closed { // Socket was aborted, we have an RST packet to transmit. PollAt::Now - } else if self.seq_to_transmit() || self.ack_to_transmit() || self.window_to_update() { + } else if self.seq_to_transmit(cx) { // We have a data or flag packet to transmit. PollAt::Now } else { + let want_ack = self.ack_to_transmit() || self.window_to_update(); + + let delayed_ack_poll_at = match (want_ack, self.ack_delay_timer) { + (false, _) => PollAt::Ingress, + (true, AckDelayTimer::Idle) => PollAt::Now, + (true, AckDelayTimer::Waiting(t)) => PollAt::Time(t), + (true, AckDelayTimer::Immediate) => PollAt::Now, + }; + let timeout_poll_at = match (self.remote_last_ts, self.timeout) { // If we're transmitting or retransmitting data, we need to poll at the moment // when the timeout would expire. @@ -1680,21 +2325,15 @@ impl<'a> TcpSocket<'a> { }; // We wait for the earliest of our timers to fire. - *[self.timer.poll_at(), timeout_poll_at] + *[self.timer.poll_at(), timeout_poll_at, delayed_ack_poll_at] .iter() - .filter(|x| !x.is_ingress()) - .min().unwrap_or(&PollAt::Ingress) + .min() + .unwrap_or(&PollAt::Ingress) } } } -impl<'a, 'b> Into> for TcpSocket<'a> { - fn into(self) -> Socket<'a, 'b> { - Socket::Tcp(self) - } -} - -impl<'a> fmt::Write for TcpSocket<'a> { +impl<'a> fmt::Write for Socket<'a> { fn write_str(&mut self, slice: &str) -> fmt::Result { let slice = slice.as_bytes(); if self.send_slice(slice) == Ok(slice.len()) { @@ -1707,113 +2346,201 @@ impl<'a> fmt::Write for TcpSocket<'a> { #[cfg(test)] mod test { + use super::*; + use crate::wire::IpRepr; use core::i32; + use std::ops::{Deref, DerefMut}; use std::vec::Vec; - use wire::{IpAddress, IpRepr, IpCidr}; - use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3, MOCK_UNSPECIFIED}; - use super::*; // =========================================================================================// // Constants // =========================================================================================// - const LOCAL_PORT: u16 = 80; - const REMOTE_PORT: u16 = 49500; - const LOCAL_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_1, port: LOCAL_PORT }; - const REMOTE_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_2, port: REMOTE_PORT }; - const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000); - const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10000); - - const SEND_IP_TEMPL: IpRepr = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_1, dst_addr: MOCK_IP_ADDR_2, - protocol: IpProtocol::Tcp, payload_len: 20, - hop_limit: 64 + const LOCAL_PORT: u16 = 80; + const REMOTE_PORT: u16 = 49500; + const LISTEN_END: IpListenEndpoint = IpListenEndpoint { + addr: None, + port: LOCAL_PORT, + }; + const LOCAL_END: IpEndpoint = IpEndpoint { + addr: LOCAL_ADDR.into_address(), + port: LOCAL_PORT, + }; + const REMOTE_END: IpEndpoint = IpEndpoint { + addr: REMOTE_ADDR.into_address(), + port: REMOTE_PORT, + }; + const TUPLE: Tuple = Tuple { + local: LOCAL_END, + remote: REMOTE_END, }; + const LOCAL_SEQ: TcpSeqNumber = TcpSeqNumber(10000); + const REMOTE_SEQ: TcpSeqNumber = TcpSeqNumber(-10001); + + cfg_if::cfg_if! { + if #[cfg(feature = "proto-ipv4")] { + use crate::wire::Ipv4Address as IpvXAddress; + use crate::wire::Ipv4Repr as IpvXRepr; + use IpRepr::Ipv4 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 1]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 2]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 3]); + + const BASE_MSS: u16 = 1460; + } else { + use crate::wire::Ipv6Address as IpvXAddress; + use crate::wire::Ipv6Repr as IpvXRepr; + use IpRepr::Ipv6 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ]); + + const BASE_MSS: u16 = 1440; + } + } + + const SEND_IP_TEMPL: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Tcp, + payload_len: 20, + hop_limit: 64, + }); const SEND_TEMPL: TcpRepr<'static> = TcpRepr { - src_port: REMOTE_PORT, dst_port: LOCAL_PORT, + src_port: REMOTE_PORT, + dst_port: LOCAL_PORT, control: TcpControl::None, - seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)), - window_len: 256, window_scale: None, + seq_number: TcpSeqNumber(0), + ack_number: Some(TcpSeqNumber(0)), + window_len: 256, + window_scale: None, max_seg_size: None, sack_permitted: false, sack_ranges: [None, None, None], - payload: &[] - }; - const _RECV_IP_TEMPL: IpRepr = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_1, dst_addr: MOCK_IP_ADDR_2, - protocol: IpProtocol::Tcp, payload_len: 20, - hop_limit: 64 + payload: &[], }; - const RECV_TEMPL: TcpRepr<'static> = TcpRepr { - src_port: LOCAL_PORT, dst_port: REMOTE_PORT, + const _RECV_IP_TEMPL: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Tcp, + payload_len: 20, + hop_limit: 64, + }); + const RECV_TEMPL: TcpRepr<'static> = TcpRepr { + src_port: LOCAL_PORT, + dst_port: REMOTE_PORT, control: TcpControl::None, - seq_number: TcpSeqNumber(0), ack_number: Some(TcpSeqNumber(0)), - window_len: 64, window_scale: None, + seq_number: TcpSeqNumber(0), + ack_number: Some(TcpSeqNumber(0)), + window_len: 64, + window_scale: None, max_seg_size: None, sack_permitted: false, sack_ranges: [None, None, None], - payload: &[] + payload: &[], }; - #[cfg(feature = "proto-ipv6")] - const BASE_MSS: u16 = 1460; - #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] - const BASE_MSS: u16 = 1480; - // =========================================================================================// // Helper functions // =========================================================================================// - fn send(socket: &mut TcpSocket, timestamp: Instant, repr: &TcpRepr) -> - Result>> { - let ip_repr = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_2, - dst_addr: MOCK_IP_ADDR_1, - protocol: IpProtocol::Tcp, + struct TestSocket { + socket: Socket<'static>, + cx: Context, + } + + impl Deref for TestSocket { + type Target = Socket<'static>; + fn deref(&self) -> &Self::Target { + &self.socket + } + } + + impl DerefMut for TestSocket { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.socket + } + } + + fn send( + socket: &mut TestSocket, + timestamp: Instant, + repr: &TcpRepr, + ) -> Option> { + socket.cx.set_now(timestamp); + + let ip_repr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, payload_len: repr.buffer_len(), - hop_limit: 64 - }; + hop_limit: 64, + }); net_trace!("send: {}", repr); - assert!(socket.accepts(&ip_repr, repr)); - match socket.process(timestamp, &ip_repr, repr) { - Ok(Some((_ip_repr, repr))) => { + assert!(socket.socket.accepts(&mut socket.cx, &ip_repr, repr)); + + match socket.socket.process(&mut socket.cx, &ip_repr, repr) { + Some((_ip_repr, repr)) => { net_trace!("recv: {}", repr); - Ok(Some(repr)) + Some(repr) } - Ok(None) => Ok(None), - Err(err) => Err(err) + None => None, } } - fn recv(socket: &mut TcpSocket, timestamp: Instant, mut f: F) - where F: FnMut(Result) { - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1520; - let result = socket.dispatch(timestamp, &caps, |(ip_repr, tcp_repr)| { - let ip_repr = ip_repr.lower(&[IpCidr::new(LOCAL_END.addr, 24)]).unwrap(); - - assert_eq!(ip_repr.protocol(), IpProtocol::Tcp); - assert_eq!(ip_repr.src_addr(), MOCK_IP_ADDR_1); - assert_eq!(ip_repr.dst_addr(), MOCK_IP_ADDR_2); - assert_eq!(ip_repr.payload_len(), tcp_repr.buffer_len()); - - net_trace!("recv: {}", tcp_repr); - Ok(f(Ok(tcp_repr))) - }); + fn recv(socket: &mut TestSocket, timestamp: Instant, mut f: F) + where + F: FnMut(Result), + { + socket.cx.set_now(timestamp); + + let mut sent = 0; + let result = socket + .socket + .dispatch(&mut socket.cx, |_, (ip_repr, tcp_repr)| { + assert_eq!(ip_repr.next_header(), IpProtocol::Tcp); + assert_eq!(ip_repr.src_addr(), LOCAL_ADDR.into()); + assert_eq!(ip_repr.dst_addr(), REMOTE_ADDR.into()); + assert_eq!(ip_repr.payload_len(), tcp_repr.buffer_len()); + + net_trace!("recv: {}", tcp_repr); + sent += 1; + Ok(f(Ok(tcp_repr))) + }); match result { - Ok(()) => (), - Err(e) => f(Err(e)) + Ok(()) => assert_eq!(sent, 1, "Exactly one packet should be sent"), + Err(e) => f(Err(e)), } } + fn recv_nothing(socket: &mut TestSocket, timestamp: Instant) { + socket.cx.set_now(timestamp); + + let result: Result<(), ()> = socket + .socket + .dispatch(&mut socket.cx, |_, (_ip_repr, _tcp_repr)| { + panic!("Should not send a packet") + }); + + assert_eq!(result, Ok(())) + } + macro_rules! send { ($socket:ident, $repr:expr) => (send!($socket, time 0, $repr)); ($socket:ident, $repr:expr, $result:expr) => (send!($socket, time 0, $repr, $result)); ($socket:ident, time $time:expr, $repr:expr) => - (send!($socket, time $time, $repr, Ok(None))); + (send!($socket, time $time, $repr, None)); ($socket:ident, time $time:expr, $repr:expr, $result:expr) => (assert_eq!(send(&mut $socket, Instant::from_millis($time), &$repr), $result)); } @@ -1821,7 +2548,7 @@ mod test { macro_rules! recv { ($socket:ident, [$( $repr:expr ),*]) => ({ $( recv!($socket, Ok($repr)); )* - recv!($socket, Err(Error::Exhausted)) + recv_nothing!($socket) }); ($socket:ident, $result:expr) => (recv!($socket, time 0, $result)); @@ -1838,177 +2565,150 @@ mod test { (recv(&mut $socket, Instant::from_millis($time), |repr| assert_eq!(repr, $result))); } - macro_rules! sanity { - ($socket1:expr, $socket2:expr) => ({ - let (s1, s2) = ($socket1, $socket2); - assert_eq!(s1.state, s2.state, "state"); - assert_eq!(s1.listen_address, s2.listen_address, "listen_address"); - assert_eq!(s1.local_endpoint, s2.local_endpoint, "local_endpoint"); - assert_eq!(s1.remote_endpoint, s2.remote_endpoint, "remote_endpoint"); - assert_eq!(s1.local_seq_no, s2.local_seq_no, "local_seq_no"); - assert_eq!(s1.remote_seq_no, s2.remote_seq_no, "remote_seq_no"); - assert_eq!(s1.remote_last_seq, s2.remote_last_seq, "remote_last_seq"); - assert_eq!(s1.remote_last_ack, s2.remote_last_ack, "remote_last_ack"); - assert_eq!(s1.remote_last_win, s2.remote_last_win, "remote_last_win"); - assert_eq!(s1.remote_win_len, s2.remote_win_len, "remote_win_len"); - assert_eq!(s1.timer, s2.timer, "timer"); - }) - } - - #[cfg(feature = "log")] - fn init_logger() { - extern crate log; - - struct Logger; - static LOGGER: Logger = Logger; - - impl log::Log for Logger { - fn enabled(&self, _metadata: &log::Metadata) -> bool { - true - } - - fn log(&self, record: &log::Record) { - println!("{}", record.args()); - } - - fn flush(&self) { - } - } - - // If it fails, that just means we've already set it to the same value. - let _ = log::set_logger(&LOGGER); - log::set_max_level(log::LevelFilter::Trace); - - println!(""); + macro_rules! recv_nothing { + ($socket:ident) => (recv_nothing!($socket, time 0)); + ($socket:ident, time $time:expr) => (recv_nothing(&mut $socket, Instant::from_millis($time))); } - fn socket() -> TcpSocket<'static> { + macro_rules! sanity { + ($socket1:expr, $socket2:expr) => {{ + let (s1, s2) = ($socket1, $socket2); + assert_eq!(s1.state, s2.state, "state"); + assert_eq!(s1.tuple, s2.tuple, "tuple"); + assert_eq!(s1.local_seq_no, s2.local_seq_no, "local_seq_no"); + assert_eq!(s1.remote_seq_no, s2.remote_seq_no, "remote_seq_no"); + assert_eq!(s1.remote_last_seq, s2.remote_last_seq, "remote_last_seq"); + assert_eq!(s1.remote_last_ack, s2.remote_last_ack, "remote_last_ack"); + assert_eq!(s1.remote_last_win, s2.remote_last_win, "remote_last_win"); + assert_eq!(s1.remote_win_len, s2.remote_win_len, "remote_win_len"); + assert_eq!(s1.timer, s2.timer, "timer"); + }}; + } + + fn socket() -> TestSocket { socket_with_buffer_sizes(64, 64) } - fn socket_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TcpSocket<'static> { - #[cfg(feature = "log")] - init_logger(); - + fn socket_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { let rx_buffer = SocketBuffer::new(vec![0; rx_len]); let tx_buffer = SocketBuffer::new(vec![0; tx_len]); - TcpSocket::new(rx_buffer, tx_buffer) + let mut socket = Socket::new(rx_buffer, tx_buffer); + socket.set_ack_delay(None); + let cx = Context::mock(); + TestSocket { socket, cx } } - fn socket_syn_received_with_buffer_sizes( - tx_len: usize, - rx_len: usize - ) -> TcpSocket<'static> { + fn socket_syn_received_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { let mut s = socket_with_buffer_sizes(tx_len, rx_len); - s.state = State::SynReceived; - s.local_endpoint = LOCAL_END; - s.remote_endpoint = REMOTE_END; - s.local_seq_no = LOCAL_SEQ; - s.remote_seq_no = REMOTE_SEQ + 1; + s.state = State::SynReceived; + s.tuple = Some(TUPLE); + s.local_seq_no = LOCAL_SEQ; + s.remote_seq_no = REMOTE_SEQ + 1; s.remote_last_seq = LOCAL_SEQ; - s.remote_win_len = 256; + s.remote_win_len = 256; s } - fn socket_syn_received() -> TcpSocket<'static> { + fn socket_syn_received() -> TestSocket { socket_syn_received_with_buffer_sizes(64, 64) } - fn socket_syn_sent() -> TcpSocket<'static> { - let mut s = socket(); - s.state = State::SynSent; - s.local_endpoint = IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_PORT); - s.remote_endpoint = REMOTE_END; - s.local_seq_no = LOCAL_SEQ; + fn socket_syn_sent_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { + let mut s = socket_with_buffer_sizes(tx_len, rx_len); + s.state = State::SynSent; + s.tuple = Some(TUPLE); + s.local_seq_no = LOCAL_SEQ; s.remote_last_seq = LOCAL_SEQ; s } - fn socket_syn_sent_with_local_ipendpoint(local: IpEndpoint) -> TcpSocket<'static> { - let mut s = socket(); - s.state = State::SynSent; - s.local_endpoint = local; - s.remote_endpoint = REMOTE_END; - s.local_seq_no = LOCAL_SEQ; - s.remote_last_seq = LOCAL_SEQ; - s + fn socket_syn_sent() -> TestSocket { + socket_syn_sent_with_buffer_sizes(64, 64) } - fn socket_established_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TcpSocket<'static> { + fn socket_established_with_buffer_sizes(tx_len: usize, rx_len: usize) -> TestSocket { let mut s = socket_syn_received_with_buffer_sizes(tx_len, rx_len); - s.state = State::Established; - s.local_seq_no = LOCAL_SEQ + 1; + s.state = State::Established; + s.local_seq_no = LOCAL_SEQ + 1; s.remote_last_seq = LOCAL_SEQ + 1; s.remote_last_ack = Some(REMOTE_SEQ + 1); s.remote_last_win = 64; s } - fn socket_established() -> TcpSocket<'static> { + fn socket_established() -> TestSocket { socket_established_with_buffer_sizes(64, 64) } - fn socket_fin_wait_1() -> TcpSocket<'static> { + fn socket_fin_wait_1() -> TestSocket { let mut s = socket_established(); - s.state = State::FinWait1; + s.state = State::FinWait1; s } - fn socket_fin_wait_2() -> TcpSocket<'static> { + fn socket_fin_wait_2() -> TestSocket { let mut s = socket_fin_wait_1(); - s.state = State::FinWait2; - s.local_seq_no = LOCAL_SEQ + 1 + 1; + s.state = State::FinWait2; + s.local_seq_no = LOCAL_SEQ + 1 + 1; s.remote_last_seq = LOCAL_SEQ + 1 + 1; s } - fn socket_closing() -> TcpSocket<'static> { + fn socket_closing() -> TestSocket { let mut s = socket_fin_wait_1(); - s.state = State::Closing; + s.state = State::Closing; s.remote_last_seq = LOCAL_SEQ + 1 + 1; - s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; s } - fn socket_time_wait(from_closing: bool) -> TcpSocket<'static> { + fn socket_time_wait(from_closing: bool) -> TestSocket { let mut s = socket_fin_wait_2(); - s.state = State::TimeWait; - s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s.state = State::TimeWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; if from_closing { s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); } - s.timer = Timer::Close { expires_at: Instant::from_secs(1) + CLOSE_DELAY }; + s.timer = Timer::Close { + expires_at: Instant::from_secs(1) + CLOSE_DELAY, + }; s } - fn socket_close_wait() -> TcpSocket<'static> { + fn socket_close_wait() -> TestSocket { let mut s = socket_established(); - s.state = State::CloseWait; - s.remote_seq_no = REMOTE_SEQ + 1 + 1; + s.state = State::CloseWait; + s.remote_seq_no = REMOTE_SEQ + 1 + 1; s.remote_last_ack = Some(REMOTE_SEQ + 1 + 1); s } - fn socket_last_ack() -> TcpSocket<'static> { + fn socket_last_ack() -> TestSocket { let mut s = socket_close_wait(); - s.state = State::LastAck; + s.state = State::LastAck; s } - fn socket_recved() -> TcpSocket<'static> { + fn socket_recved() -> TestSocket { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); s } @@ -2017,14 +2717,14 @@ mod test { // =========================================================================================// #[test] fn test_closed_reject() { - let s = socket(); + let mut s = socket(); assert_eq!(s.state, State::Closed); let tcp_repr = TcpRepr { control: TcpControl::Syn, ..SEND_TEMPL }; - assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); } #[test] @@ -2037,7 +2737,7 @@ mod test { control: TcpControl::Syn, ..SEND_TEMPL }; - assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); } #[test] @@ -2050,49 +2750,61 @@ mod test { // =========================================================================================// // Tests for the LISTEN state. // =========================================================================================// - fn socket_listen() -> TcpSocket<'static> { + fn socket_listen() -> TestSocket { let mut s = socket(); - s.state = State::Listen; - s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT); + s.state = State::Listen; + s.listen_endpoint = LISTEN_END; s } #[test] fn test_listen_sack_option() { let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - sack_permitted: false, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + sack_permitted: false, + ..SEND_TEMPL + } + ); assert!(!s.remote_has_sack); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - sack_permitted: true, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + sack_permitted: true, + ..SEND_TEMPL + } + ); assert!(s.remote_has_sack); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - sack_permitted: true, - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + sack_permitted: true, + ..RECV_TEMPL + }] + ); } #[test] @@ -2113,25 +2825,31 @@ mod test { ] { let mut s = socket_with_buffer_sizes(64, *buffer_size); s.state = State::Listen; - s.local_endpoint = IpEndpoint::new(IpAddress::default(), LOCAL_PORT); + s.listen_endpoint = LISTEN_END; assert_eq!(s.remote_win_shift, *shift_amt); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - window_scale: Some(0), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + window_scale: Some(0), + ..SEND_TEMPL + } + ); assert_eq!(s.remote_win_shift, *shift_amt); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - window_scale: Some(*shift_amt), - window_len: cmp::min(*buffer_size >> *shift_amt, 65535) as u16, - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + window_scale: Some(*shift_amt), + window_len: cmp::min(*buffer_size, 65535) as u16, + ..RECV_TEMPL + }] + ); } } @@ -2145,31 +2863,34 @@ mod test { #[test] fn test_listen_validation() { let mut s = socket(); - assert_eq!(s.listen(0), Err(Error::Unaddressable)); + assert_eq!(s.listen(0), Err(ListenError::Unaddressable)); } #[test] fn test_listen_twice() { let mut s = socket(); assert_eq!(s.listen(80), Ok(())); - assert_eq!(s.listen(80), Err(Error::Illegal)); + assert_eq!(s.listen(80), Err(ListenError::InvalidState)); } #[test] fn test_listen_syn() { let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); sanity!(s, socket_syn_received()); } #[test] fn test_listen_syn_reject_ack() { - let s = socket_listen(); + let mut s = socket_listen(); let tcp_repr = TcpRepr { control: TcpControl::Syn, @@ -2177,7 +2898,7 @@ mod test { ack_number: Some(LOCAL_SEQ), ..SEND_TEMPL }; - assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); assert_eq!(s.state, State::Listen); } @@ -2185,12 +2906,16 @@ mod test { #[test] fn test_listen_rst() { let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, - ack_number: None, - ..SEND_TEMPL - }, Err(Error::Dropped)); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Listen); } #[test] @@ -2207,131 +2932,232 @@ mod test { #[test] fn test_syn_received_ack() { let mut s = socket_syn_received(); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Established); sanity!(s, socket_established()); } #[test] - fn test_syn_received_fin() { + fn test_syn_received_ack_too_low() { let mut s = socket_syn_received(); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6 + 1), - window_len: 58, - ..RECV_TEMPL - }]); - assert_eq!(s.state, State::CloseWait); - sanity!(s, TcpSocket { - remote_last_ack: Some(REMOTE_SEQ + 1 + 6 + 1), - remote_last_win: 58, - ..socket_close_wait() - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ), // wrong + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynReceived); } #[test] - fn test_syn_received_rst() { + fn test_syn_received_ack_too_high() { let mut s = socket_syn_received(); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ), - ..SEND_TEMPL - }); - assert_eq!(s.state, State::Listen); - assert_eq!(s.local_endpoint, IpEndpoint::new(IpAddress::Unspecified, LOCAL_END.port)); - assert_eq!(s.remote_endpoint, IpEndpoint::default()); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 2), // wrong + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 2, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynReceived); } #[test] - fn test_syn_received_no_window_scaling() { - let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - ..SEND_TEMPL - }); - assert_eq!(s.state(), State::SynReceived); - assert_eq!(s.local_endpoint(), LOCAL_END); - assert_eq!(s.remote_endpoint(), REMOTE_END); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - window_scale: None, - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - window_scale: None, - ..SEND_TEMPL - }); - assert_eq!(s.remote_win_scale, None); + fn test_syn_received_fin() { + let mut s = socket_syn_received(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 1), + window_len: 58, + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::CloseWait); + + let mut s2 = socket_close_wait(); + s2.remote_last_ack = Some(REMOTE_SEQ + 1 + 6 + 1); + s2.remote_last_win = 58; + sanity!(s, s2); } #[test] - fn test_syn_received_window_scaling() { - for scale in 0..14 { - let mut s = socket_listen(); - send!(s, TcpRepr { + fn test_syn_received_rst() { + let mut s = socket_syn_received(); + s.listen_endpoint = LISTEN_END; + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Listen); + assert_eq!(s.listen_endpoint, LISTEN_END); + assert_eq!(s.tuple, None); + } + + #[test] + fn test_syn_received_no_window_scaling() { + let mut s = socket_listen(); + send!( + s, + TcpRepr { control: TcpControl::Syn, seq_number: REMOTE_SEQ, ack_number: None, - window_scale: Some(scale), ..SEND_TEMPL - }); - assert_eq!(s.state(), State::SynReceived); - assert_eq!(s.local_endpoint(), LOCAL_END); - assert_eq!(s.remote_endpoint(), REMOTE_END); - recv!(s, [TcpRepr { + } + ); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { control: TcpControl::Syn, seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), max_seg_size: Some(BASE_MSS), - window_scale: Some(0), + window_scale: None, ..RECV_TEMPL - }]); - send!(s, TcpRepr { + }] + ); + send!( + s, + TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), window_scale: None, ..SEND_TEMPL - }); + } + ); + assert_eq!(s.remote_win_shift, 0); + assert_eq!(s.remote_win_scale, None); + } + + #[test] + fn test_syn_received_window_scaling() { + for scale in 0..14 { + let mut s = socket_listen(); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + window_scale: Some(scale), + ..SEND_TEMPL + } + ); + assert_eq!(s.state(), State::SynReceived); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_scale: None, + ..SEND_TEMPL + } + ); assert_eq!(s.remote_win_scale, Some(scale)); } } @@ -2350,143 +3176,313 @@ mod test { #[test] fn test_connect_validation() { let mut s = socket(); - assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), - Err(Error::Unaddressable)); - assert_eq!(s.connect(REMOTE_END, (MOCK_UNSPECIFIED, 0)), - Err(Error::Unaddressable)); - assert_eq!(s.connect((MOCK_UNSPECIFIED, 0), LOCAL_END), - Err(Error::Unaddressable)); - assert_eq!(s.connect((IpAddress::Unspecified, 80), LOCAL_END), - Err(Error::Unaddressable)); - s.connect(REMOTE_END, LOCAL_END).expect("Connect failed with valid parameters"); - assert_eq!(s.local_endpoint(), LOCAL_END); - assert_eq!(s.remote_endpoint(), REMOTE_END); + assert_eq!( + s.socket + .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 0)), + Err(ConnectError::Unaddressable) + ); + assert_eq!( + s.socket + .connect(&mut s.cx, REMOTE_END, (IpvXAddress::UNSPECIFIED, 1024)), + Err(ConnectError::Unaddressable) + ); + assert_eq!( + s.socket + .connect(&mut s.cx, (IpvXAddress::UNSPECIFIED, 0), LOCAL_END), + Err(ConnectError::Unaddressable) + ); + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END) + .expect("Connect failed with valid parameters"); + assert_eq!(s.tuple, Some(TUPLE)); } #[test] fn test_connect() { let mut s = socket(); s.local_seq_no = LOCAL_SEQ; - s.connect(REMOTE_END, LOCAL_END.port).unwrap(); - assert_eq!(s.local_endpoint, IpEndpoint::new(MOCK_UNSPECIFIED, LOCAL_END.port)); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: None, - max_seg_size: Some(BASE_MSS), - window_scale: Some(0), - sack_permitted: true, - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: Some(LOCAL_SEQ + 1), - max_seg_size: Some(BASE_MSS - 80), - window_scale: Some(0), - ..SEND_TEMPL - }); - assert_eq!(s.local_endpoint, LOCAL_END); + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END.port) + .unwrap(); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + } + ); + assert_eq!(s.tuple, Some(TUPLE)); } #[test] fn test_connect_unspecified_local() { let mut s = socket(); - assert_eq!(s.connect(REMOTE_END, (MOCK_UNSPECIFIED, 80)), - Ok(())); - s.abort(); - assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), - Ok(())); - s.abort(); + assert_eq!(s.socket.connect(&mut s.cx, REMOTE_END, 80), Ok(())); } #[test] fn test_connect_specified_local() { let mut s = socket(); - assert_eq!(s.connect(REMOTE_END, (MOCK_IP_ADDR_2, 80)), - Ok(())); + assert_eq!( + s.socket.connect(&mut s.cx, REMOTE_END, (REMOTE_ADDR, 80)), + Ok(()) + ); } #[test] fn test_connect_twice() { let mut s = socket(); - assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), - Ok(())); - assert_eq!(s.connect(REMOTE_END, (IpAddress::Unspecified, 80)), - Err(Error::Illegal)); + assert_eq!(s.socket.connect(&mut s.cx, REMOTE_END, 80), Ok(())); + assert_eq!( + s.socket.connect(&mut s.cx, REMOTE_END, 80), + Err(ConnectError::InvalidState) + ); } #[test] fn test_syn_sent_sanity() { let mut s = socket(); - s.local_seq_no = LOCAL_SEQ; - s.connect(REMOTE_END, LOCAL_END).unwrap(); - sanity!(s, socket_syn_sent_with_local_ipendpoint(LOCAL_END)); + s.local_seq_no = LOCAL_SEQ; + s.socket.connect(&mut s.cx, REMOTE_END, LOCAL_END).unwrap(); + sanity!(s, socket_syn_sent()); } #[test] fn test_syn_sent_syn_ack() { let mut s = socket_syn_sent(); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: None, - max_seg_size: Some(BASE_MSS), - window_scale: Some(0), - sack_permitted: true, - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: Some(LOCAL_SEQ + 1), - max_seg_size: Some(BASE_MSS - 80), - window_scale: Some(0), - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - recv!(s, time 1000, Err(Error::Exhausted)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + recv_nothing!(s, time 1000); assert_eq!(s.state, State::Established); sanity!(s, socket_established()); } + #[test] + fn test_syn_sent_syn_ack_not_incremented() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), // WRONG + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(0), + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + assert_eq!(s.state, State::SynSent); + } + #[test] fn test_syn_sent_rst() { let mut s = socket_syn_sent(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closed); } #[test] fn test_syn_sent_rst_no_ack() { let mut s = socket_syn_sent(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, - ack_number: None, - ..SEND_TEMPL - }, Err(Error::Dropped)); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::SynSent); } #[test] fn test_syn_sent_rst_bad_ack() { let mut s = socket_syn_sent(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, - ack_number: Some(TcpSeqNumber(1234)), - ..SEND_TEMPL - }, Err(Error::Dropped)); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, + ack_number: Some(TcpSeqNumber(1234)), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, // Unexpected + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), // Correct + ..SEND_TEMPL + } + ); + + // It should trigger no response and change no state + recv!(s, []); + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack_seq_1() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ), // WRONG + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ, // matching the ack_number of the unexpected ack + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + + // It should trigger a RST, and change no state + assert_eq!(s.state, State::SynSent); + } + + #[test] + fn test_syn_sent_bad_ack_seq_2() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::None, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 123456), // WRONG + ..SEND_TEMPL + }, + Some(TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 123456, // matching the ack_number of the unexpected ack + ack_number: None, + window_len: 0, + ..RECV_TEMPL + }) + ); + + // It should trigger a RST, and change no state assert_eq!(s.state, State::SynSent); } @@ -2514,18 +3510,92 @@ mod test { (1048576, 5), ] { let mut s = socket_with_buffer_sizes(64, *buffer_size); + s.local_seq_no = LOCAL_SEQ; assert_eq!(s.remote_win_shift, *shift_amt); - s.connect(REMOTE_END, LOCAL_END).unwrap(); - recv!(s, [TcpRepr { + s.socket.connect(&mut s.cx, REMOTE_END, LOCAL_END).unwrap(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(*shift_amt), + window_len: cmp::min(*buffer_size, 65535) as u16, + sack_permitted: true, + ..RECV_TEMPL + }] + ); + } + } + + #[test] + fn test_syn_sent_syn_ack_no_window_scaling() { + let mut s = socket_syn_sent_with_buffer_sizes(1048576, 1048576); + recv!( + s, + [TcpRepr { control: TcpControl::Syn, + seq_number: LOCAL_SEQ, ack_number: None, max_seg_size: Some(BASE_MSS), - window_scale: Some(*shift_amt), - window_len: cmp::min(*buffer_size >> *shift_amt, 65535) as u16, + // scaling does NOT apply to the window value in SYN packets + window_len: 65535, + window_scale: Some(5), sack_permitted: true, ..RECV_TEMPL - }]); - } + }] + ); + assert_eq!(s.remote_win_shift, 5); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: None, + window_len: 42, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_shift, 0); + assert_eq!(s.remote_win_scale, None); + assert_eq!(s.remote_win_len, 42); + } + + #[test] + fn test_syn_sent_syn_ack_window_scaling() { + let mut s = socket_syn_sent(); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: None, + max_seg_size: Some(BASE_MSS), + window_scale: Some(0), + sack_permitted: true, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + max_seg_size: Some(BASE_MSS - 80), + window_scale: Some(7), + window_len: 42, + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Established); + assert_eq!(s.remote_win_scale, Some(7)); + // scaling does NOT apply to the window value in SYN packets + assert_eq!(s.remote_win_len, 42); } // =========================================================================================// @@ -2535,22 +3605,28 @@ mod test { #[test] fn test_established_recv() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); assert_eq!(s.rx_buffer.dequeue_many(6), &b"abcdef"[..]); } - fn setup_rfc2018_cases() -> (TcpSocket<'static>, Vec) { + fn setup_rfc2018_cases() -> (TestSocket, Vec) { // This is a utility function used by the tests for RFC 2018 cases. It configures a socket // in a particular way suitable for those cases. // @@ -2563,25 +3639,34 @@ mod test { let mut segment: Vec = Vec::with_capacity(500); // move the last ack to 5000 by sending ten of them - for _ in 0..50 { segment.extend_from_slice(b"abcdefghij") } + for _ in 0..50 { + segment.extend_from_slice(b"abcdefghij") + } for offset in (0..5000).step_by(500) { - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + offset, - ack_number: Some(LOCAL_SEQ + 1), - payload: &segment, - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + offset + 500), - window_len: 3500, - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + offset, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + offset + 500), + window_len: 3500, + ..RECV_TEMPL + }] + ); s.recv(|data| { assert_eq!(data.len(), 500); assert_eq!(data, segment.as_slice()); (500, ()) - }).unwrap(); + }) + .unwrap(); } assert_eq!(s.remote_last_win, 3500); (s, segment) @@ -2614,21 +3699,29 @@ mod test { // 8500 5000 5500 9000 // for offset in (500..3500).step_by(500) { - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + offset + 5000, - ack_number: Some(LOCAL_SEQ + 1), - payload: &segment, - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 5000), - window_len: 4000, - sack_ranges: [ - Some((REMOTE_SEQ.0 as u32 + 1 + 5500, - REMOTE_SEQ.0 as u32 + 1 + 5500 + offset as u32)), - None, None], - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + offset + 5000, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 5000), + window_len: 4000, + sack_ranges: [ + Some(( + REMOTE_SEQ.0 as u32 + 1 + 5500, + REMOTE_SEQ.0 as u32 + 1 + 5500 + offset as u32 + )), + None, + None + ], + ..RECV_TEMPL + }) + ); } } @@ -2638,31 +3731,39 @@ mod test { // Update our scaling parameters for a TCP with a scaled buffer. assert_eq!(s.rx_buffer.len(), 0); s.rx_buffer = SocketBuffer::new(vec![0; 262143]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); + s.assembler = Assembler::new(); s.remote_win_scale = Some(0); s.remote_last_win = 65535; s.remote_win_shift = 2; // Create a TCP segment that will mostly fill an IP frame. let mut segment: Vec = Vec::with_capacity(1400); - for _ in 0..100 { segment.extend_from_slice(b"abcdefghijklmn") } + for _ in 0..100 { + segment.extend_from_slice(b"abcdefghijklmn") + } assert_eq!(segment.len(), 1400); // Send the frame - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &segment, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &segment, + ..SEND_TEMPL + } + ); // Ensure that the received window size is shifted right by 2. - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1400), - window_len: 65185, - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1400), + window_len: 65185, + ..RECV_TEMPL + }] + ); } #[test] @@ -2670,52 +3771,71 @@ mod test { let mut s = socket_established(); // First roundtrip after establishing. s.send_slice(b"abcdef").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); assert_eq!(s.tx_buffer.len(), 6); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); assert_eq!(s.tx_buffer.len(), 0); // Second roundtrip. s.send_slice(b"foobar").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"foobar"[..], - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6), + ..SEND_TEMPL + } + ); assert_eq!(s.tx_buffer.len(), 0); } #[test] fn test_established_send_no_ack_send() { let mut s = socket_established(); + s.set_nagle_enabled(false); s.send_slice(b"abcdef").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); s.send_slice(b"foobar").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"foobar"[..], - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); } #[test] @@ -2728,17 +3848,71 @@ mod test { let mut s = socket_established(); s.remote_win_len = 16; s.send_slice(&data[..]).unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &data[0..16], - ..RECV_TEMPL - }, TcpRepr { - seq_number: LOCAL_SEQ + 1 + 16, - ack_number: Some(REMOTE_SEQ + 1), - payload: &data[16..32], - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &data[0..16], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_established_send_window_shrink() { + let mut s = socket_established(); + + // 6 octets fit on the remote side's window, so we send them. + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + assert_eq!(s.tx_buffer.len(), 6); + + println!( + "local_seq_no={} remote_win_len={} remote_last_seq={}", + s.local_seq_no, s.remote_win_len, s.remote_last_seq + ); + + // - Peer doesn't ack them yet + // - Sends data so we need to reply with an ACK + // - ...AND and sends a window announcement that SHRINKS the window, so data we've + // previously sent is now outside the window. Yes, this is allowed by TCP. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 3, + payload: &b"xyzxyz"[..], + ..SEND_TEMPL + } + ); + assert_eq!(s.tx_buffer.len(), 6); + + println!( + "local_seq_no={} remote_win_len={} remote_last_seq={}", + s.local_seq_no, s.remote_win_len, s.remote_last_seq + ); + + // More data should not get sent since it doesn't fit in the window + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 64 - 6, + ..RECV_TEMPL + }] + ); } #[test] @@ -2759,33 +3933,43 @@ mod test { #[test] fn test_established_no_ack() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: None, - ..SEND_TEMPL - }, Err(Error::Dropped)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + } + ); } #[test] fn test_established_bad_ack() { let mut s = socket_established(); // Already acknowledged data. - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)), - ..SEND_TEMPL - }, Err(Error::Dropped)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(TcpSeqNumber(LOCAL_SEQ.0 - 1)), + ..SEND_TEMPL + } + ); assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); // Data not yet transmitted. - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 10), - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 10), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); } @@ -2793,32 +3977,70 @@ mod test { fn test_established_bad_seq() { let mut s = socket_established(); // Data outside of receive window. - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 256, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); + assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); + + // Challenge ACKs are rate-limited, we don't get a second one immediately. + send!( + s, + time 100, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + + // If we wait a bit, we do get a new one. + send!( + s, + time 2000, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 256, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); } #[test] fn test_established_fin() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); assert_eq!(s.state, State::CloseWait); sanity!(s, socket_close_wait()); } @@ -2826,29 +4048,37 @@ mod test { #[test] fn test_established_fin_after_missing() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1 + 6, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"123456"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); assert_eq!(s.state, State::Established); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6 + 6), - window_len: 52, - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6 + 6), + window_len: 52, + ..RECV_TEMPL + }) + ); assert_eq!(s.state, State::Established); } @@ -2856,42 +4086,54 @@ mod test { fn test_established_send_fin() { let mut s = socket_established(); s.send_slice(b"abcdef").unwrap(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::CloseWait); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); } #[test] fn test_established_rst() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closed); } #[test] fn test_established_rst_no_ack() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1, - ack_number: None, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1, + ack_number: None, + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closed); } @@ -2908,55 +4150,69 @@ mod test { let mut s = socket_established(); s.abort(); assert_eq!(s.state, State::Closed); - recv!(s, [TcpRepr { - control: TcpControl::Rst, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + control: TcpControl::Rst, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); } #[test] fn test_established_rst_bad_seq() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, // Wrong seq - ack_number: None, - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, // Wrong seq + ack_number: None, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); assert_eq!(s.state, State::Established); // Send something to advance seq by 1 - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, // correct seq - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"a"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, // correct seq + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"a"[..], + ..SEND_TEMPL + } + ); // Send wrong rst again, check that the challenge ack is correctly updated // The ack number must be updated even if we don't call dispatch on the socket // See https://github.com/smoltcp-rs/smoltcp/issues/338 - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ, // Wrong seq - ack_number: None, - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 2), // this has changed - window_len: 63, - ..RECV_TEMPL - }))); + send!( + s, + time 2000, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ, // Wrong seq + ack_number: None, + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 2), // this has changed + window_len: 63, + ..RECV_TEMPL + }) + ); } - // =========================================================================================// // Tests for the FIN-WAIT-1 state. // =========================================================================================// @@ -2964,17 +4220,23 @@ mod test { #[test] fn test_fin_wait_1_fin_ack() { let mut s = socket_fin_wait_1(); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait2); sanity!(s, socket_fin_wait_2()); } @@ -2982,18 +4244,24 @@ mod test { #[test] fn test_fin_wait_1_fin_fin() { let mut s = socket_fin_wait_1(); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closing); sanity!(s, socket_closing()); } @@ -3004,34 +4272,44 @@ mod test { s.remote_win_len = 6; s.send_slice(b"abcdef123456").unwrap(); s.close(); - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - })); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6), - ..SEND_TEMPL - }); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait1); } #[test] fn test_fin_wait_1_recv() { let mut s = socket_fin_wait_1(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait1); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); + }) + .unwrap(); } #[test] @@ -3061,17 +4339,29 @@ mod test { #[test] fn test_fin_wait_2_recv() { let mut s = socket_fin_wait_2(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait2); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); + }) + .unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + ..RECV_TEMPL + }] + ); } #[test] @@ -3088,11 +4378,14 @@ mod test { #[test] fn test_closing_ack_fin() { let mut s = socket_closing(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); send!(s, time 1_000, TcpRepr { seq_number: REMOTE_SEQ + 1 + 1, ack_number: Some(LOCAL_SEQ + 1 + 1), @@ -3116,11 +4409,14 @@ mod test { #[test] fn test_time_wait_from_fin_wait_2_ack() { let mut s = socket_time_wait(false); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); } #[test] @@ -3139,34 +4435,45 @@ mod test { #[test] fn test_time_wait_retransmit() { let mut s = socket_time_wait(false); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); send!(s, time 5_000, TcpRepr { control: TcpControl::Fin, seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1 + 1), ..SEND_TEMPL - }, Ok(Some(TcpRepr { + }, Some(TcpRepr { seq_number: LOCAL_SEQ + 1 + 1, ack_number: Some(REMOTE_SEQ + 1 + 1), ..RECV_TEMPL - }))); - assert_eq!(s.timer, Timer::Close { expires_at: Instant::from_secs(5) + CLOSE_DELAY }); + })); + assert_eq!( + s.timer, + Timer::Close { + expires_at: Instant::from_secs(5) + CLOSE_DELAY + } + ); } #[test] fn test_time_wait_timeout() { let mut s = socket_time_wait(false); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); assert_eq!(s.state, State::TimeWait); - recv!(s, time 60_000, Err(Error::Exhausted)); + recv_nothing!(s, time 60_000); assert_eq!(s.state, State::Closed); } @@ -3178,17 +4485,23 @@ mod test { fn test_close_wait_ack() { let mut s = socket_close_wait(); s.send_slice(b"abcdef").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + ..SEND_TEMPL + } + ); } #[test] @@ -3205,18 +4518,61 @@ mod test { #[test] fn test_last_ack_fin_ack() { let mut s = socket_last_ack(); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); assert_eq!(s.state, State::LastAck); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closed); + } + + #[test] + fn test_last_ack_ack_not_of_fin() { + let mut s = socket_last_ack(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::LastAck); + + // ACK received that doesn't ack the FIN: socket should stay in LastAck. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::LastAck); + + // ACK received of fin: socket should change to Closed. + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closed); } @@ -3234,34 +4590,42 @@ mod test { #[test] fn test_listen() { let mut s = socket(); - s.listen(IpEndpoint::new(IpAddress::default(), LOCAL_PORT)).unwrap(); + s.listen(LISTEN_END).unwrap(); assert_eq!(s.state, State::Listen); } #[test] fn test_three_way_handshake() { let mut s = socket_listen(); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + ..SEND_TEMPL + } + ); assert_eq!(s.state(), State::SynReceived); - assert_eq!(s.local_endpoint(), LOCAL_END); - assert_eq!(s.remote_endpoint(), REMOTE_END); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, - ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + assert_eq!(s.tuple, Some(TUPLE)); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state(), State::Established); assert_eq!(s.local_seq_no, LOCAL_SEQ + 1); assert_eq!(s.remote_seq_no, REMOTE_SEQ + 1); @@ -3270,31 +4634,43 @@ mod test { #[test] fn test_remote_close() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::CloseWait); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); s.close(); assert_eq!(s.state, State::LastAck); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closed); } @@ -3303,30 +4679,42 @@ mod test { let mut s = socket_established(); s.close(); assert_eq!(s.state, State::FinWait1); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait2); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::TimeWait); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); } #[test] @@ -3334,30 +4722,43 @@ mod test { let mut s = socket_established(); s.close(); assert_eq!(s.state, State::FinWait1); - recv!(s, [TcpRepr { // due to reordering, this is logically located... - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + // due to reordering, this is logically located... + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closing); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); // ... at this point - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::TimeWait); recv!(s, []); } @@ -3367,59 +4768,163 @@ mod test { let mut s = socket_established(); s.close(); assert_eq!(s.state, State::FinWait1); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::TimeWait); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); } #[test] - fn test_fin_with_data() { + fn test_simultaneous_close_raced() { let mut s = socket_established(); - s.send_slice(b"abcdef").unwrap(); s.close(); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]) + assert_eq!(s.state, State::FinWait1); + + // Socket receives FIN before it has a chance to send its own FIN + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + + // FIN + ack-of-FIN + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::Closing); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); } #[test] - fn test_mutual_close_with_data_1() { + fn test_simultaneous_close_raced_with_data() { let mut s = socket_established(); s.send_slice(b"abcdef").unwrap(); s.close(); assert_eq!(s.state, State::FinWait1); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), - ..SEND_TEMPL - }); + + // Socket receives FIN before it has a chance to send its own data+FIN + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::Closing); + + // data + FIN + ack-of-FIN + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + assert_eq!(s.state, State::Closing); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + assert_eq!(s.state, State::TimeWait); + recv!(s, []); + } + + #[test] + fn test_fin_with_data() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ) + } + + #[test] + fn test_mutual_close_with_data_1() { + let mut s = socket_established(); + s.send_slice(b"abcdef").unwrap(); + s.close(); + assert_eq!(s.state, State::FinWait1); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); } #[test] @@ -3428,30 +4933,42 @@ mod test { s.send_slice(b"abcdef").unwrap(); s.close(); assert_eq!(s.state, State::FinWait1); - recv!(s, [TcpRepr { - control: TcpControl::Fin, - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), - ..SEND_TEMPL - }); + recv!( + s, + [TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::FinWait2); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6 + 1, - ack_number: Some(REMOTE_SEQ + 1 + 1), - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 1), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 1, + ack_number: Some(REMOTE_SEQ + 1 + 1), + ..RECV_TEMPL + }] + ); assert_eq!(s.state, State::TimeWait); } @@ -3463,17 +4980,21 @@ mod test { fn test_duplicate_seq_ack() { let mut s = socket_recved(); // remote retransmission - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); } #[test] @@ -3486,8 +5007,8 @@ mod test { payload: &b"abcdef"[..], ..RECV_TEMPL })); - recv!(s, time 1050, Err(Error::Exhausted)); - recv!(s, time 1100, Ok(TcpRepr { + recv_nothing!(s, time 1050); + recv!(s, time 2000, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], @@ -3498,7 +5019,7 @@ mod test { #[test] fn test_data_retransmit_bursts() { let mut s = socket_established(); - s.remote_win_len = 6; + s.remote_mss = 6; s.send_slice(b"abcdef012345").unwrap(); recv!(s, time 0, Ok(TcpRepr { @@ -3508,7 +5029,6 @@ mod test { payload: &b"abcdef"[..], ..RECV_TEMPL }), exact); - s.remote_win_len = 6; recv!(s, time 0, Ok(TcpRepr { control: TcpControl::Psh, seq_number: LOCAL_SEQ + 1 + 6, @@ -3516,28 +5036,104 @@ mod test { payload: &b"012345"[..], ..RECV_TEMPL }), exact); - s.remote_win_len = 6; - recv!(s, time 0, Err(Error::Exhausted)); + recv_nothing!(s, time 0); - recv!(s, time 50, Err(Error::Exhausted)); + recv_nothing!(s, time 50); - recv!(s, time 100, Ok(TcpRepr { + recv!(s, time 1000, Ok(TcpRepr { control: TcpControl::None, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], ..RECV_TEMPL }), exact); - s.remote_win_len = 6; - recv!(s, time 150, Ok(TcpRepr { + recv!(s, time 1500, Ok(TcpRepr { control: TcpControl::Psh, seq_number: LOCAL_SEQ + 1 + 6, ack_number: Some(REMOTE_SEQ + 1), payload: &b"012345"[..], ..RECV_TEMPL }), exact); - s.remote_win_len = 6; - recv!(s, time 200, Err(Error::Exhausted)); + recv_nothing!(s, time 1550); + } + + #[test] + fn test_data_retransmit_bursts_half_ack() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef012345").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + // Acknowledge the first packet + send!(s, time 5, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + // The second packet should be re-sent. + recv!(s, time 1500, Ok(TcpRepr { + control: TcpControl::Psh, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + + recv_nothing!(s, time 1550); + } + + #[test] + fn test_data_retransmit_bursts_half_ack_close() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef012345").unwrap(); + s.close(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + // Acknowledge the first packet + send!(s, time 5, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6), + window_len: 6, + ..SEND_TEMPL + }); + // The second packet should be re-sent. + recv!(s, time 1500, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"012345"[..], + ..RECV_TEMPL + }), exact); + + recv_nothing!(s, time 1550); } #[test] @@ -3550,26 +5146,32 @@ mod test { max_seg_size: Some(BASE_MSS), ..RECV_TEMPL })); - recv!(s, time 150, Ok(TcpRepr { // retransmit + recv!(s, time 750, Ok(TcpRepr { // retransmit control: TcpControl::Syn, seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), max_seg_size: Some(BASE_MSS), ..RECV_TEMPL })); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); assert_eq!(s.state(), State::Established); s.send_slice(b"abcdef").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }]) + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ) } #[test] @@ -3586,11 +5188,14 @@ mod test { // Retransmit timer is on because all data was sent assert_eq!(s.tx_buffer.len(), 3); // ACK nothing new - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); // Retransmit recv!(s, time 4000, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, @@ -3764,17 +5369,17 @@ mod test { #[test] fn test_fast_retransmit_after_triple_duplicate_ack() { let mut s = socket_established(); + s.remote_mss = 6; // Normal ACK of previously recived segment send!(s, time 0, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); // Send a long string of text divided into several packets - // because of previously recieved "window_len" + // because of previously received "window_len" s.send_slice(b"xxxxxxyyyyyywwwwwwzzzzzz").unwrap(); // This packet is lost recv!(s, time 1000, Ok(TcpRepr { @@ -3806,14 +5411,12 @@ mod test { send!(s, time 1050, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); // Second duplicate ACK send!(s, time 1055, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); // Third duplicate ACK @@ -3821,7 +5424,6 @@ mod test { send!(s, time 1060, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); @@ -3859,7 +5461,7 @@ mod test { _ => false, }); - // ACK all recived segments + // ACK all received segments send!(s, time 1120, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1 + (6 * 4)), @@ -3879,89 +5481,166 @@ mod test { ..RECV_TEMPL })); - // Normal ACK of previously recieved segment - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + // Normal ACK of previously received segment + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); // First duplicate - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); // Second duplicate - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); - assert_eq!(s.local_rx_dup_acks, 2, - "duplicate ACK counter is not set"); + assert_eq!(s.local_rx_dup_acks, 2, "duplicate ACK counter is not set"); // This packet has content, hence should not be detected // as a duplicate ACK and should reset the duplicate ACK count - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"xxxxxx"[..], - ..SEND_TEMPL - }); - - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1 + 3, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"xxxxxx"[..], + ..SEND_TEMPL + } + ); + + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 3, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); - assert_eq!(s.local_rx_dup_acks, 0, - "duplicate ACK counter is not reset when reciving data"); + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving data" + ); } #[test] - fn test_fast_retransmit_duplicate_detection() { + fn test_fast_retransmit_duplicate_detection_with_window_update() { let mut s = socket_established(); - // Normal ACK of previously recived segment - send!(s, time 0, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, - ..SEND_TEMPL - }); - - // First duplicate, should not be counted as there is nothing to resend - send!(s, time 0, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, - ..SEND_TEMPL - }); - - assert_eq!(s.local_rx_dup_acks, 0, - "duplicate ACK counter is set but wound not transmit data"); - - // Send a long string of text divided into several packets - // because of previously recieved "window_len" - s.send_slice(b"xxxxxxyyyyyywwwwwwzzzzzz").unwrap(); - - // This packet is reordered in network + s.send_slice(b"abc").unwrap(); // This is lost recv!(s, time 1000, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), - payload: &b"xxxxxx"[..], - ..RECV_TEMPL - })); - recv!(s, time 1005, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"yyyyyy"[..], + payload: &b"abc"[..], ..RECV_TEMPL })); - recv!(s, time 1010, Ok(TcpRepr { + + // Normal ACK of previously received segment + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // First duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + // Second duplicate + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + + assert_eq!(s.local_rx_dup_acks, 2, "duplicate ACK counter is not set"); + + // This packet has a window update, hence should not be detected + // as a duplicate ACK and should reset the duplicate ACK count + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 400, + ..SEND_TEMPL + } + ); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving a window update" + ); + } + + #[test] + fn test_fast_retransmit_duplicate_detection() { + let mut s = socket_established(); + s.remote_mss = 6; + + // Normal ACK of previously received segment + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + // First duplicate, should not be counted as there is nothing to resend + send!(s, time 0, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is set but wound not transmit data" + ); + + // Send a long string of text divided into several packets + // because of small remote_mss + s.send_slice(b"xxxxxxyyyyyywwwwwwzzzzzz").unwrap(); + + // This packet is reordered in network + recv!(s, time 1000, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"xxxxxx"[..], + ..RECV_TEMPL + })); + recv!(s, time 1005, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyyyy"[..], + ..RECV_TEMPL + })); + recv!(s, time 1010, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1 + (6 * 2), ack_number: Some(REMOTE_SEQ + 1), payload: &b"wwwwww"[..], @@ -3978,28 +5657,27 @@ mod test { send!(s, time 1050, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); // Second duplicate ACK send!(s, time 1055, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 6, ..SEND_TEMPL }); // Reordered packet arrives which should reset duplicate ACK count send!(s, time 1060, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1 + (6 * 3)), - window_len: 6, ..SEND_TEMPL }); - assert_eq!(s.local_rx_dup_acks, 0, - "duplicate ACK counter is not reset when reciving ACK which updates send window"); + assert_eq!( + s.local_rx_dup_acks, 0, + "duplicate ACK counter is not reset when receiving ACK which updates send window" + ); - // ACK all recived segments + // ACK all received segments send!(s, time 1120, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1 + (6 * 4)), @@ -4046,98 +5724,168 @@ mod test { ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - assert_eq!(s.local_rx_dup_acks, u8::max_value(), "duplicate ACK count should not overflow but saturate"); + assert_eq!( + s.local_rx_dup_acks, + u8::max_value(), + "duplicate ACK count should not overflow but saturate" + ); } - // =========================================================================================// - // Tests for window management. - // =========================================================================================// - #[test] - fn test_maximum_segment_size() { - let mut s = socket_listen(); - s.tx_buffer = SocketBuffer::new(vec![0; 32767]); - send!(s, TcpRepr { - control: TcpControl::Syn, - seq_number: REMOTE_SEQ, - ack_number: None, - max_seg_size: Some(1000), + fn test_fast_retransmit_zero_window() { + let mut s = socket_established(); + + send!(s, time 1000, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - recv!(s, [TcpRepr { - control: TcpControl::Syn, - seq_number: LOCAL_SEQ, + + s.send_slice(b"abc").unwrap(); + + recv!(s, time 0, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), - max_seg_size: Some(BASE_MSS), + payload: &b"abc"[..], ..RECV_TEMPL - }]); - send!(s, TcpRepr { + })); + + // 3 dup acks + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 1050, TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }); + send!(s, time 1050, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - window_len: 32767, + window_len: 0, // boom ..SEND_TEMPL }); + + // even though we're in "fast retransmit", we shouldn't + // force-send anything because the remote's window is full. + recv_nothing!(s); + } + + // =========================================================================================// + // Tests for window management. + // =========================================================================================// + + #[test] + fn test_maximum_segment_size() { + let mut s = socket_listen(); + s.tx_buffer = SocketBuffer::new(vec![0; 32767]); + send!( + s, + TcpRepr { + control: TcpControl::Syn, + seq_number: REMOTE_SEQ, + ack_number: None, + max_seg_size: Some(1000), + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + control: TcpControl::Syn, + seq_number: LOCAL_SEQ, + ack_number: Some(REMOTE_SEQ + 1), + max_seg_size: Some(BASE_MSS), + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + window_len: 32767, + ..SEND_TEMPL + } + ); s.send_slice(&[0; 1200][..]).unwrap(); - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &[0; 1000][..], - ..RECV_TEMPL - })); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &[0; 1000][..], + ..RECV_TEMPL + }) + ); } #[test] fn test_close_wait_no_window_update() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &[1,2,3,4], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &[1, 2, 3, 4], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::CloseWait); // we ack the FIN, with the reduced window size. - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 6), - window_len: 60, - ..RECV_TEMPL - })); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 6), + window_len: 60, + ..RECV_TEMPL + }) + ); let rx_buf = &mut [0; 32]; assert_eq!(s.recv_slice(rx_buf), Ok(4)); // check that we do NOT send a window update even if it has changed. - recv!(s, Err(Error::Exhausted)); + recv_nothing!(s); } #[test] fn test_time_wait_no_window_update() { let mut s = socket_fin_wait_2(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 2), - payload: &[1,2,3,4], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 2), + payload: &[1, 2, 3, 4], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::TimeWait); // we ack the FIN, with the reduced window size. - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 2, - ack_number: Some(REMOTE_SEQ + 6), - window_len: 60, - ..RECV_TEMPL - })); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 2, + ack_number: Some(REMOTE_SEQ + 6), + window_len: 60, + ..RECV_TEMPL + }) + ); let rx_buf = &mut [0; 32]; assert_eq!(s.recv_slice(rx_buf), Ok(4)); // check that we do NOT send a window update even if it has changed. - recv!(s, Err(Error::Exhausted)); + recv_nothing!(s); } // =========================================================================================// @@ -4147,7 +5895,7 @@ mod test { #[test] fn test_psh_transmit() { let mut s = socket_established(); - s.remote_win_len = 6; + s.remote_mss = 6; s.send_slice(b"abcdef").unwrap(); s.send_slice(b"123456").unwrap(); recv!(s, time 0, Ok(TcpRepr { @@ -4169,84 +5917,108 @@ mod test { #[test] fn test_psh_receive() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Psh, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }]); + send!( + s, + TcpRepr { + control: TcpControl::Psh, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }] + ); } #[test] fn test_zero_window_ack() { let mut s = socket_established(); s.rx_buffer = SocketBuffer::new(vec![0; 6]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 0, - ..RECV_TEMPL - }]); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 6, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"123456"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 0, - ..RECV_TEMPL - }))); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }] + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"123456"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }) + ); } #[test] fn test_zero_window_ack_on_window_growth() { let mut s = socket_established(); s.rx_buffer = SocketBuffer::new(vec![0; 6]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 0, - ..RECV_TEMPL - }]); - recv!(s, time 0, Err(Error::Exhausted)); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 0, + ..RECV_TEMPL + }] + ); + recv_nothing!(s, time 0); s.recv(|buffer| { assert_eq!(&buffer[..3], b"abc"); (3, ()) - }).unwrap(); + }) + .unwrap(); recv!(s, time 0, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1 + 6), window_len: 3, ..RECV_TEMPL })); - recv!(s, time 0, Err(Error::Exhausted)); + recv_nothing!(s, time 0); s.recv(|buffer| { assert_eq!(buffer, b"def"); (buffer.len(), ()) - }).unwrap(); + }) + .unwrap(); recv!(s, time 0, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1 + 6), @@ -4260,81 +6032,101 @@ mod test { let mut s = socket_established(); s.remote_mss = 6; s.send_slice(b"abcdef123456!@#$%^").unwrap(); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"abcdef"[..], - ..RECV_TEMPL - }, TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"123456"[..], - ..RECV_TEMPL - }, TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"!@#$%^"[..], - ..RECV_TEMPL - }]); + recv!( + s, + [ + TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }, + TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"123456"[..], + ..RECV_TEMPL + }, + TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"!@#$%^"[..], + ..RECV_TEMPL + } + ] + ); } #[test] fn test_announce_window_after_read() { let mut s = socket_established(); s.rx_buffer = SocketBuffer::new(vec![0; 6]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 3), - window_len: 3, - ..RECV_TEMPL - }]); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 3, + ..RECV_TEMPL + }] + ); // Test that `dispatch` updates `remote_last_win` assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); - s.recv(|buffer| { - (buffer.len(), ()) - }).unwrap(); + s.recv(|buffer| (buffer.len(), ())).unwrap(); assert!(s.window_to_update()); - recv!(s, [TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 3), - window_len: 6, - ..RECV_TEMPL - }]); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 6, + ..RECV_TEMPL + }] + ); assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); // Provoke immediate ACK to test that `process` updates `remote_last_win` - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 6, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"def"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 3), - window_len: 6, - ..RECV_TEMPL - }))); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 3, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 9), - window_len: 0, - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 6, + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 9), + window_len: 0, + ..RECV_TEMPL + }) + ); assert_eq!(s.remote_last_win, s.rx_buffer.window() as u16); - s.recv(|buffer| { - (buffer.len(), ()) - }).unwrap(); + s.recv(|buffer| (buffer.len(), ())).unwrap(); assert!(s.window_to_update()); } @@ -4346,14 +6138,16 @@ mod test { fn test_listen_timeout() { let mut s = socket_listen(); s.set_timeout(Some(Duration::from_millis(100))); - assert_eq!(s.poll_at(), PollAt::Ingress); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Ingress); } #[test] fn test_connect_timeout() { let mut s = socket(); s.local_seq_no = LOCAL_SEQ; - s.connect(REMOTE_END, LOCAL_END.port).unwrap(); + s.socket + .connect(&mut s.cx, REMOTE_END, LOCAL_END.port) + .unwrap(); s.set_timeout(Some(Duration::from_millis(100))); recv!(s, time 150, Ok(TcpRepr { control: TcpControl::Syn, @@ -4365,7 +6159,10 @@ mod test { ..RECV_TEMPL })); assert_eq!(s.state, State::SynSent); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(250))); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(250)) + ); recv!(s, time 250, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1, @@ -4379,26 +6176,35 @@ mod test { #[test] fn test_established_timeout() { let mut s = socket_established(); - s.set_timeout(Some(Duration::from_millis(200))); - recv!(s, time 250, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(450))); + s.set_timeout(Some(Duration::from_millis(1000))); + recv_nothing!(s, time 250); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(1250)) + ); s.send_slice(b"abcdef").unwrap(); - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); recv!(s, time 255, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(355))); - recv!(s, time 355, Ok(TcpRepr { + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(955)) + ); + recv!(s, time 955, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), payload: &b"abcdef"[..], ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(455))); - recv!(s, time 500, Ok(TcpRepr { + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(1255)) + ); + recv!(s, time 1255, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1 + 6, ack_number: Some(REMOTE_SEQ + 1), @@ -4418,45 +6224,53 @@ mod test { payload: &[0], ..RECV_TEMPL })); - recv!(s, time 100, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(150))); + recv_nothing!(s, time 100); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(150)) + ); send!(s, time 105, TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(155))); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(155)) + ); recv!(s, time 155, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), payload: &[0], ..RECV_TEMPL })); - recv!(s, time 155, Err(Error::Exhausted)); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(205))); - recv!(s, time 200, Err(Error::Exhausted)); + recv_nothing!(s, time 155); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(205)) + ); + recv_nothing!(s, time 200); recv!(s, time 205, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), ..RECV_TEMPL })); - recv!(s, time 205, Err(Error::Exhausted)); + recv_nothing!(s, time 205); assert_eq!(s.state, State::Closed); } #[test] fn test_fin_wait_1_timeout() { let mut s = socket_fin_wait_1(); - s.set_timeout(Some(Duration::from_millis(200))); + s.set_timeout(Some(Duration::from_millis(1000))); recv!(s, time 100, Ok(TcpRepr { control: TcpControl::Fin, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200))); - recv!(s, time 400, Ok(TcpRepr { + recv!(s, time 1100, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1 + 1, ack_number: Some(REMOTE_SEQ + 1), @@ -4468,15 +6282,14 @@ mod test { #[test] fn test_last_ack_timeout() { let mut s = socket_last_ack(); - s.set_timeout(Some(Duration::from_millis(200))); + s.set_timeout(Some(Duration::from_millis(1000))); recv!(s, time 100, Ok(TcpRepr { control: TcpControl::Fin, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1 + 1), ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200))); - recv!(s, time 400, Ok(TcpRepr { + recv!(s, time 1100, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1 + 1, ack_number: Some(REMOTE_SEQ + 1 + 1), @@ -4491,14 +6304,14 @@ mod test { s.set_timeout(Some(Duration::from_millis(200))); s.remote_last_ts = Some(Instant::from_millis(100)); s.abort(); - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); recv!(s, time 100, Ok(TcpRepr { control: TcpControl::Rst, seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1), ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Ingress); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Ingress); } // =========================================================================================// @@ -4508,15 +6321,19 @@ mod test { #[test] fn test_responds_to_keep_alive() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); } #[test] @@ -4525,7 +6342,7 @@ mod test { s.set_keep_alive(Some(Duration::from_millis(100))); // drain the forced keep-alive packet - assert_eq!(s.poll_at(), PollAt::Now); + assert_eq!(s.socket.poll_at(&mut s.cx), PollAt::Now); recv!(s, time 0, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4533,8 +6350,11 @@ mod test { ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(100))); - recv!(s, time 95, Err(Error::Exhausted)); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(100)) + ); + recv_nothing!(s, time 95); recv!(s, time 100, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4542,8 +6362,11 @@ mod test { ..RECV_TEMPL })); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(200))); - recv!(s, time 195, Err(Error::Exhausted)); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(200)) + ); + recv_nothing!(s, time 195); recv!(s, time 200, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4556,8 +6379,11 @@ mod test { ack_number: Some(LOCAL_SEQ + 1), ..SEND_TEMPL }); - assert_eq!(s.poll_at(), PollAt::Time(Instant::from_millis(350))); - recv!(s, time 345, Err(Error::Exhausted)); + assert_eq!( + s.socket.poll_at(&mut s.cx), + PollAt::Time(Instant::from_millis(350)) + ); + recv_nothing!(s, time 345); recv!(s, time 350, Ok(TcpRepr { seq_number: LOCAL_SEQ, ack_number: Some(REMOTE_SEQ + 1), @@ -4573,14 +6399,20 @@ mod test { #[test] fn test_set_hop_limit() { let mut s = socket_syn_received(); - let mut caps = DeviceCapabilities::default(); - caps.max_transmission_unit = 1520; s.set_hop_limit(Some(0x2a)); - assert_eq!(s.dispatch(Instant::from_millis(0), &caps, |(ip_repr, _)| { - assert_eq!(ip_repr.hop_limit(), 0x2a); + assert_eq!( + s.socket.dispatch(&mut s.cx, |_, (ip_repr, _)| { + assert_eq!(ip_repr.hop_limit(), 0x2a); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); + + // assert that user-configurable settings are kept, + // see https://github.com/smoltcp-rs/smoltcp/issues/601. + s.reset(); + assert_eq!(s.hop_limit(), Some(0x2a)); } #[test] @@ -4597,58 +6429,75 @@ mod test { #[test] fn test_out_of_order() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 3, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"def"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + ..RECV_TEMPL + }) + ); s.recv(|buffer| { assert_eq!(buffer, b""); (buffer.len(), ()) - }).unwrap(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 6), - window_len: 58, - ..RECV_TEMPL - }))); + }) + .unwrap(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abcdef"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); s.recv(|buffer| { assert_eq!(buffer, b"abcdef"); (buffer.len(), ()) - }).unwrap(); + }) + .unwrap(); } #[test] fn test_buffer_wraparound_rx() { let mut s = socket_established(); s.rx_buffer = SocketBuffer::new(vec![0; 6]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + s.assembler = Assembler::new(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); s.recv(|buffer| { assert_eq!(buffer, b"abc"); (buffer.len(), ()) - }).unwrap(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 3, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"defghi"[..], - ..SEND_TEMPL - }); + }) + .unwrap(); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"defghi"[..], + ..SEND_TEMPL + } + ); let mut data = [0; 6]; assert_eq!(s.recv_slice(&mut data[..]), Ok(6)); assert_eq!(data, &b"defghi"[..]); @@ -4657,6 +6506,8 @@ mod test { #[test] fn test_buffer_wraparound_tx() { let mut s = socket_established(); + s.set_nagle_enabled(false); + s.tx_buffer = SocketBuffer::new(vec![b'.'; 9]); assert_eq!(s.send_slice(b"xxxyyy"), Ok(6)); assert_eq!(s.tx_buffer.dequeue_many(3), &b"xxx"[..]); @@ -4664,18 +6515,24 @@ mod test { // "abcdef" not contiguous in tx buffer assert_eq!(s.send_slice(b"abcdef"), Ok(6)); - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"yyyabc"[..], - ..RECV_TEMPL - })); - recv!(s, Ok(TcpRepr { - seq_number: LOCAL_SEQ + 1 + 6, - ack_number: Some(REMOTE_SEQ + 1), - payload: &b"def"[..], - ..RECV_TEMPL - })); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"yyyabc"[..], + ..RECV_TEMPL + }) + ); + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"def"[..], + ..RECV_TEMPL + }) + ); } // =========================================================================================// @@ -4685,151 +6542,465 @@ mod test { #[test] fn test_rx_close_fin() { let mut s = socket_established(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); - assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished)); + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); } #[test] fn test_rx_close_fin_in_fin_wait_1() { let mut s = socket_fin_wait_1(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::Closing); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); - assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished)); + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); } #[test] fn test_rx_close_fin_in_fin_wait_2() { let mut s = socket_fin_wait_2(); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1 + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); assert_eq!(s.state, State::TimeWait); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); - assert_eq!(s.recv(|_| (0, ())), Err(Error::Finished)); + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::Finished)); } - - #[test] fn test_rx_close_fin_with_hole() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); - send!(s, TcpRepr { - control: TcpControl::Fin, - seq_number: REMOTE_SEQ + 1 + 6, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"ghi"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { - seq_number: LOCAL_SEQ + 1, - ack_number: Some(REMOTE_SEQ + 1 + 3), - window_len: 61, - ..RECV_TEMPL - }))); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + control: TcpControl::Fin, + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 61, + ..RECV_TEMPL + }) + ); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); + }) + .unwrap(); s.recv(|data| { assert_eq!(data, b""); (0, ()) - }).unwrap(); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1 + 9, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + }) + .unwrap(); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 9, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); // Error must be `Illegal` even if we've received a FIN, // because we are missing data. - assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal)); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); } #[test] fn test_rx_close_rst() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1 + 3, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); - assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal)); + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); } #[test] fn test_rx_close_rst_with_hole() { let mut s = socket_established(); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abc"[..], - ..SEND_TEMPL - }); - send!(s, TcpRepr { - seq_number: REMOTE_SEQ + 1 + 6, - ack_number: Some(LOCAL_SEQ + 1), - payload: &b"ghi"[..], - ..SEND_TEMPL - }, Ok(Some(TcpRepr { + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + }, + Some(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + window_len: 61, + ..RECV_TEMPL + }) + ); + send!( + s, + TcpRepr { + control: TcpControl::Rst, + seq_number: REMOTE_SEQ + 1 + 9, + ack_number: Some(LOCAL_SEQ + 1), + ..SEND_TEMPL + } + ); + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + assert_eq!(s.recv(|_| (0, ())), Err(RecvError::InvalidState)); + } + + // =========================================================================================// + // Tests for delayed ACK + // =========================================================================================// + + #[test] + fn test_delayed_ack() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + // After 10ms, it is sent. + recv!(s, time 11, Ok(TcpRepr { seq_number: LOCAL_SEQ + 1, ack_number: Some(REMOTE_SEQ + 1 + 3), window_len: 61, ..RECV_TEMPL - }))); - send!(s, TcpRepr { - control: TcpControl::Rst, - seq_number: REMOTE_SEQ + 1 + 9, - ack_number: Some(LOCAL_SEQ + 1), - ..SEND_TEMPL - }); + })); + } + + #[test] + fn test_delayed_ack_win() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // Reading the data off the buffer should cause a window update. s.recv(|data| { assert_eq!(data, b"abc"); (3, ()) - }).unwrap(); - assert_eq!(s.recv(|_| (0, ())), Err(Error::Illegal)); + }) + .unwrap(); + + // However, no ACK or window update is immediately sent. + recv_nothing!(s); + + // After 10ms, it is sent. + recv!(s, time 11, Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + ..RECV_TEMPL + })); + } + + #[test] + fn test_delayed_ack_reply() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + s.recv(|data| { + assert_eq!(data, b"abc"); + (3, ()) + }) + .unwrap(); + + s.send_slice(&b"xyz"[..]).unwrap(); + + // Writing data to the socket causes ACK to not be delayed, + // because it is immediately sent with the data. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 3), + payload: &b"xyz"[..], + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_delayed_ack_every_second_packet() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + } + ); + + // Every 2nd packet, ACK is sent without delay. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 6), + window_len: 58, + ..RECV_TEMPL + }) + ); + } + + #[test] + fn test_delayed_ack_three_packets() { + let mut s = socket_established(); + s.set_ack_delay(Some(ACK_DELAY_DEFAULT)); + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"abc"[..], + ..SEND_TEMPL + } + ); + + // No ACK is immediately sent. + recv_nothing!(s); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 3, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"def"[..], + ..SEND_TEMPL + } + ); + + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1 + 6, + ack_number: Some(LOCAL_SEQ + 1), + payload: &b"ghi"[..], + ..SEND_TEMPL + } + ); + + // Every 2nd (or more) packet, ACK is sent without delay. + recv!( + s, + Ok(TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1 + 9), + window_len: 55, + ..RECV_TEMPL + }) + ); + } + + // =========================================================================================// + // Tests for Nagle's Algorithm + // =========================================================================================// + + #[test] + fn test_nagle() { + let mut s = socket_established(); + s.remote_mss = 6; + + s.send_slice(b"abcdef").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }] + ); + + // If there's data in flight, full segments get sent. + s.send_slice(b"foobar").unwrap(); + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"foobar"[..], + ..RECV_TEMPL + }] + ); + + s.send_slice(b"aaabbbccc").unwrap(); + // If there's data in flight, not-full segments don't get sent. + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"aaabbb"[..], + ..RECV_TEMPL + }] + ); + + // Data gets ACKd, so there's no longer data in flight + send!( + s, + TcpRepr { + seq_number: REMOTE_SEQ + 1, + ack_number: Some(LOCAL_SEQ + 1 + 6 + 6 + 6), + ..SEND_TEMPL + } + ); + + // Now non-full segment gets sent. + recv!( + s, + [TcpRepr { + seq_number: LOCAL_SEQ + 1 + 6 + 6 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"ccc"[..], + ..RECV_TEMPL + }] + ); + } + + #[test] + fn test_final_packet_in_stream_doesnt_wait_for_nagle() { + let mut s = socket_established(); + s.remote_mss = 6; + s.send_slice(b"abcdef0").unwrap(); + s.socket.close(); + + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::None, + seq_number: LOCAL_SEQ + 1, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"abcdef"[..], + ..RECV_TEMPL + }), exact); + recv!(s, time 0, Ok(TcpRepr { + control: TcpControl::Fin, + seq_number: LOCAL_SEQ + 1 + 6, + ack_number: Some(REMOTE_SEQ + 1), + payload: &b"0"[..], + ..RECV_TEMPL + }), exact); } // =========================================================================================// @@ -4840,62 +7011,62 @@ mod test { fn test_doesnt_accept_wrong_port() { let mut s = socket_established(); s.rx_buffer = SocketBuffer::new(vec![0; 6]); - s.assembler = Assembler::new(s.rx_buffer.capacity()); + s.assembler = Assembler::new(); let tcp_repr = TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - dst_port: LOCAL_PORT + 1, + dst_port: LOCAL_PORT + 1, ..SEND_TEMPL }; - assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); let tcp_repr = TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - src_port: REMOTE_PORT + 1, + src_port: REMOTE_PORT + 1, ..SEND_TEMPL }; - assert!(!s.accepts(&SEND_IP_TEMPL, &tcp_repr)); + assert!(!s.socket.accepts(&mut s.cx, &SEND_IP_TEMPL, &tcp_repr)); } #[test] fn test_doesnt_accept_wrong_ip() { - let s = socket_established(); + let mut s = socket_established(); let tcp_repr = TcpRepr { seq_number: REMOTE_SEQ + 1, ack_number: Some(LOCAL_SEQ + 1), - payload: &b"abcdef"[..], + payload: &b"abcdef"[..], ..SEND_TEMPL }; - let ip_repr = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_2, - dst_addr: MOCK_IP_ADDR_1, - protocol: IpProtocol::Tcp, + let ip_repr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, payload_len: tcp_repr.buffer_len(), - hop_limit: 64 - }; - assert!(s.accepts(&ip_repr, &tcp_repr)); + hop_limit: 64, + }); + assert!(s.socket.accepts(&mut s.cx, &ip_repr, &tcp_repr)); - let ip_repr_wrong_src = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_3, - dst_addr: MOCK_IP_ADDR_1, - protocol: IpProtocol::Tcp, + let ip_repr_wrong_src = IpReprIpvX(IpvXRepr { + src_addr: OTHER_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Tcp, payload_len: tcp_repr.buffer_len(), - hop_limit: 64 - }; - assert!(!s.accepts(&ip_repr_wrong_src, &tcp_repr)); + hop_limit: 64, + }); + assert!(!s.socket.accepts(&mut s.cx, &ip_repr_wrong_src, &tcp_repr)); - let ip_repr_wrong_dst = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_2, - dst_addr: MOCK_IP_ADDR_3, - protocol: IpProtocol::Tcp, + let ip_repr_wrong_dst = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: OTHER_ADDR, + next_header: IpProtocol::Tcp, payload_len: tcp_repr.buffer_len(), - hop_limit: 64 - }; - assert!(!s.accepts(&ip_repr_wrong_dst, &tcp_repr)); + hop_limit: 64, + }); + assert!(!s.socket.accepts(&mut s.cx, &ip_repr_wrong_dst, &tcp_repr)); } // =========================================================================================// @@ -4904,19 +7075,40 @@ mod test { #[test] fn test_timer_retransmit() { - let mut r = Timer::default(); + const RTO: Duration = Duration::from_millis(100); + let mut r = Timer::new(); assert_eq!(r.should_retransmit(Instant::from_secs(1)), None); - r.set_for_retransmit(Instant::from_millis(1000)); + r.set_for_retransmit(Instant::from_millis(1000), RTO); assert_eq!(r.should_retransmit(Instant::from_millis(1000)), None); assert_eq!(r.should_retransmit(Instant::from_millis(1050)), None); - assert_eq!(r.should_retransmit(Instant::from_millis(1101)), Some(Duration::from_millis(101))); - r.set_for_retransmit(Instant::from_millis(1101)); + assert_eq!( + r.should_retransmit(Instant::from_millis(1101)), + Some(Duration::from_millis(101)) + ); + r.set_for_retransmit(Instant::from_millis(1101), RTO); assert_eq!(r.should_retransmit(Instant::from_millis(1101)), None); assert_eq!(r.should_retransmit(Instant::from_millis(1150)), None); assert_eq!(r.should_retransmit(Instant::from_millis(1200)), None); - assert_eq!(r.should_retransmit(Instant::from_millis(1301)), Some(Duration::from_millis(300))); + assert_eq!( + r.should_retransmit(Instant::from_millis(1301)), + Some(Duration::from_millis(300)) + ); r.set_for_idle(Instant::from_millis(1301), None); assert_eq!(r.should_retransmit(Instant::from_millis(1350)), None); } + #[test] + fn test_rtt_estimator() { + let mut r = RttEstimator::default(); + + let rtos = &[ + 751, 766, 755, 731, 697, 656, 613, 567, 523, 484, 445, 411, 378, 350, 322, 299, 280, + 261, 243, 229, 215, 206, 197, 188, + ]; + + for &rto in rtos { + r.sample(100); + assert_eq!(r.retransmission_timeout(), Duration::from_millis(rto)); + } + } } diff --git a/src/socket/udp.rs b/src/socket/udp.rs index 7818c0fcd..39172dc8e 100644 --- a/src/socket/udp.rs +++ b/src/socket/udp.rs @@ -1,52 +1,176 @@ use core::cmp::min; +#[cfg(feature = "async")] +use core::task::Waker; + +use crate::iface::Context; +use crate::phy::PacketMeta; +use crate::socket::PollAt; +#[cfg(feature = "async")] +use crate::socket::WakerRegistration; +use crate::storage::Empty; +use crate::wire::{IpEndpoint, IpListenEndpoint, IpProtocol, IpRepr, UdpRepr}; + +/// Metadata for a sent or received UDP packet. +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub struct UdpMetadata { + pub endpoint: IpEndpoint, + pub meta: PacketMeta, +} + +impl> From for UdpMetadata { + fn from(value: T) -> Self { + Self { + endpoint: value.into(), + meta: PacketMeta::default(), + } + } +} + +impl core::fmt::Display for UdpMetadata { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + #[cfg(feature = "packetmeta-id")] + return write!(f, "{}, PacketID: {:?}", self.endpoint, self.meta); -use {Error, Result}; -use socket::{Socket, SocketMeta, SocketHandle, PollAt}; -use storage::{PacketBuffer, PacketMetadata}; -use wire::{IpProtocol, IpRepr, IpEndpoint, UdpRepr}; + #[cfg(not(feature = "packetmeta-id"))] + write!(f, "{}", self.endpoint) + } +} /// A UDP packet metadata. -pub type UdpPacketMetadata = PacketMetadata; +pub type PacketMetadata = crate::storage::PacketMetadata; /// A UDP packet ring buffer. -pub type UdpSocketBuffer<'a, 'b> = PacketBuffer<'a, 'b, IpEndpoint>; +pub type PacketBuffer<'a> = crate::storage::PacketBuffer<'a, UdpMetadata>; + +/// Error returned by [`Socket::bind`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum BindError { + InvalidState, + Unaddressable, +} + +impl core::fmt::Display for BindError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + BindError::InvalidState => write!(f, "invalid state"), + BindError::Unaddressable => write!(f, "unaddressable"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for BindError {} + +/// Error returned by [`Socket::send`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SendError { + Unaddressable, + BufferFull, +} + +impl core::fmt::Display for SendError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + SendError::Unaddressable => write!(f, "unaddressable"), + SendError::BufferFull => write!(f, "buffer full"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for SendError {} + +/// Error returned by [`Socket::recv`] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecvError { + Exhausted, +} + +impl core::fmt::Display for RecvError { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + RecvError::Exhausted => write!(f, "exhausted"), + } + } +} + +#[cfg(feature = "std")] +impl std::error::Error for RecvError {} /// A User Datagram Protocol socket. /// /// A UDP socket is bound to a specific endpoint, and owns transmit and receive /// packet buffers. #[derive(Debug)] -pub struct UdpSocket<'a, 'b: 'a> { - pub(crate) meta: SocketMeta, - endpoint: IpEndpoint, - rx_buffer: UdpSocketBuffer<'a, 'b>, - tx_buffer: UdpSocketBuffer<'a, 'b>, +pub struct Socket<'a> { + endpoint: IpListenEndpoint, + rx_buffer: PacketBuffer<'a>, + tx_buffer: PacketBuffer<'a>, /// The time-to-live (IPv4) or hop limit (IPv6) value used in outgoing packets. - hop_limit: Option + hop_limit: Option, + #[cfg(feature = "async")] + rx_waker: WakerRegistration, + #[cfg(feature = "async")] + tx_waker: WakerRegistration, } -impl<'a, 'b> UdpSocket<'a, 'b> { +impl<'a> Socket<'a> { /// Create an UDP socket with the given buffers. - pub fn new(rx_buffer: UdpSocketBuffer<'a, 'b>, - tx_buffer: UdpSocketBuffer<'a, 'b>) -> UdpSocket<'a, 'b> { - UdpSocket { - meta: SocketMeta::default(), - endpoint: IpEndpoint::default(), - rx_buffer: rx_buffer, - tx_buffer: tx_buffer, - hop_limit: None + pub fn new(rx_buffer: PacketBuffer<'a>, tx_buffer: PacketBuffer<'a>) -> Socket<'a> { + Socket { + endpoint: IpListenEndpoint::default(), + rx_buffer, + tx_buffer, + hop_limit: None, + #[cfg(feature = "async")] + rx_waker: WakerRegistration::new(), + #[cfg(feature = "async")] + tx_waker: WakerRegistration::new(), } } - /// Return the socket handle. - #[inline] - pub fn handle(&self) -> SocketHandle { - self.meta.handle + /// Register a waker for receive operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `recv` method calls, such as receiving data, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `recv` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_recv_waker(&mut self, waker: &Waker) { + self.rx_waker.register(waker) + } + + /// Register a waker for send operations. + /// + /// The waker is woken on state changes that might affect the return value + /// of `send` method calls, such as space becoming available in the transmit + /// buffer, or the socket closing. + /// + /// Notes: + /// + /// - Only one waker can be registered at a time. If another waker was previously registered, + /// it is overwritten and will no longer be woken. + /// - The Waker is woken only once. Once woken, you must register it again to receive more wakes. + /// - "Spurious wakes" are allowed: a wake doesn't guarantee the result of `send` has + /// necessarily changed. + #[cfg(feature = "async")] + pub fn register_send_waker(&mut self, waker: &Waker) { + self.tx_waker.register(waker) } /// Return the bound endpoint. #[inline] - pub fn endpoint(&self) -> IpEndpoint { + pub fn endpoint(&self) -> IpListenEndpoint { self.endpoint } @@ -82,16 +206,43 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// This function returns `Err(Error::Illegal)` if the socket was open /// (see [is_open](#method.is_open)), and `Err(Error::Unaddressable)` /// if the port in the given endpoint is zero. - pub fn bind>(&mut self, endpoint: T) -> Result<()> { + pub fn bind>(&mut self, endpoint: T) -> Result<(), BindError> { let endpoint = endpoint.into(); - if endpoint.port == 0 { return Err(Error::Unaddressable) } + if endpoint.port == 0 { + return Err(BindError::Unaddressable); + } - if self.is_open() { return Err(Error::Illegal) } + if self.is_open() { + return Err(BindError::InvalidState); + } self.endpoint = endpoint; + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + Ok(()) } + /// Close the socket. + pub fn close(&mut self) { + // Clear the bound endpoint of the socket. + self.endpoint = IpListenEndpoint::default(); + + // Reset the RX and TX buffers of the socket. + self.tx_buffer.reset(); + self.rx_buffer.reset(); + + #[cfg(feature = "async")] + { + self.rx_waker.wake(); + self.tx_waker.wake(); + } + } + /// Check whether the socket is open. #[inline] pub fn is_open(&self) -> bool { @@ -141,22 +292,84 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// `Err(Error::Unaddressable)` if local or remote port, or remote address are unspecified, /// and `Err(Error::Truncated)` if there is not enough transmit buffer capacity /// to ever send this packet. - pub fn send(&mut self, size: usize, endpoint: IpEndpoint) -> Result<&mut [u8]> { - if self.endpoint.port == 0 { return Err(Error::Unaddressable) } - if !endpoint.is_specified() { return Err(Error::Unaddressable) } - - let payload_buf = self.tx_buffer.enqueue(size, endpoint)?; + pub fn send( + &mut self, + size: usize, + meta: impl Into, + ) -> Result<&mut [u8], SendError> { + let meta = meta.into(); + if self.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + if meta.endpoint.addr.is_unspecified() { + return Err(SendError::Unaddressable); + } + if meta.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } - net_trace!("{}:{}:{}: buffer to send {} octets", - self.meta.handle, self.endpoint, endpoint, size); + let payload_buf = self + .tx_buffer + .enqueue(size, meta) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "udp:{}:{}: buffer to send {} octets", + self.endpoint, + meta.endpoint, + size + ); Ok(payload_buf) } + /// Enqueue a packet to be send to a given remote endpoint and pass the buffer + /// to the provided closure. The closure then returns the size of the data written + /// into the buffer. + /// + /// Also see [send](#method.send). + pub fn send_with( + &mut self, + max_size: usize, + meta: impl Into, + f: F, + ) -> Result + where + F: FnOnce(&mut [u8]) -> usize, + { + let meta = meta.into(); + if self.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + if meta.endpoint.addr.is_unspecified() { + return Err(SendError::Unaddressable); + } + if meta.endpoint.port == 0 { + return Err(SendError::Unaddressable); + } + + let size = self + .tx_buffer + .enqueue_with_infallible(max_size, meta, f) + .map_err(|_| SendError::BufferFull)?; + + net_trace!( + "udp:{}:{}: buffer to send {} octets", + self.endpoint, + meta.endpoint, + size + ); + Ok(size) + } + /// Enqueue a packet to be sent to a given remote endpoint, and fill it from a slice. /// /// See also [send](#method.send). - pub fn send_slice(&mut self, data: &[u8], endpoint: IpEndpoint) -> Result<()> { - self.send(data.len(), endpoint)?.copy_from_slice(data); + pub fn send_slice( + &mut self, + data: &[u8], + meta: impl Into, + ) -> Result<(), SendError> { + self.send(data.len(), meta)?.copy_from_slice(data); Ok(()) } @@ -164,21 +377,25 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// as a pointer to the payload. /// /// This function returns `Err(Error::Exhausted)` if the receive buffer is empty. - pub fn recv(&mut self) -> Result<(&[u8], IpEndpoint)> { - let (endpoint, payload_buf) = self.rx_buffer.dequeue()?; + pub fn recv(&mut self) -> Result<(&[u8], UdpMetadata), RecvError> { + let (remote_endpoint, payload_buf) = + self.rx_buffer.dequeue().map_err(|_| RecvError::Exhausted)?; - net_trace!("{}:{}:{}: receive {} buffered octets", - self.meta.handle, self.endpoint, - endpoint, payload_buf.len()); - Ok((payload_buf, endpoint)) + net_trace!( + "udp:{}:{}: receive {} buffered octets", + self.endpoint, + remote_endpoint.endpoint, + payload_buf.len() + ); + Ok((payload_buf, remote_endpoint)) } /// Dequeue a packet received from a remote endpoint, copy the payload into the given slice, /// and return the amount of octets copied as well as the endpoint. /// /// See also [recv](#method.recv). - pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, IpEndpoint)> { - let (buffer, endpoint) = self.recv()?; + pub fn recv_slice(&mut self, data: &mut [u8]) -> Result<(usize, UdpMetadata), RecvError> { + let (buffer, endpoint) = self.recv().map_err(|_| RecvError::Exhausted)?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) @@ -189,15 +406,19 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// This function otherwise behaves identically to [recv](#method.recv). /// /// It returns `Err(Error::Exhausted)` if the receive buffer is empty. - pub fn peek(&mut self) -> Result<(&[u8], &IpEndpoint)> { - let handle = self.meta.handle; + pub fn peek(&mut self) -> Result<(&[u8], &UdpMetadata), RecvError> { let endpoint = self.endpoint; - self.rx_buffer.peek().map(|(remote_endpoint, payload_buf)| { - net_trace!("{}:{}:{}: peek {} buffered octets", - handle, endpoint, - remote_endpoint, payload_buf.len()); - (payload_buf, remote_endpoint) - }) + self.rx_buffer.peek().map_err(|_| RecvError::Exhausted).map( + |(remote_endpoint, payload_buf)| { + net_trace!( + "udp:{}:{}: peek {} buffered octets", + endpoint, + remote_endpoint.endpoint, + payload_buf.len() + ); + (payload_buf, remote_endpoint) + }, + ) } /// Peek at a packet received from a remote endpoint, copy the payload into the given slice, @@ -206,65 +427,126 @@ impl<'a, 'b> UdpSocket<'a, 'b> { /// This function otherwise behaves identically to [recv_slice](#method.recv_slice). /// /// See also [peek](#method.peek). - pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &IpEndpoint)> { + pub fn peek_slice(&mut self, data: &mut [u8]) -> Result<(usize, &UdpMetadata), RecvError> { let (buffer, endpoint) = self.peek()?; let length = min(data.len(), buffer.len()); data[..length].copy_from_slice(&buffer[..length]); Ok((length, endpoint)) } - pub(crate) fn accepts(&self, ip_repr: &IpRepr, repr: &UdpRepr) -> bool { - if self.endpoint.port != repr.dst_port { return false } - if !self.endpoint.addr.is_unspecified() && - self.endpoint.addr != ip_repr.dst_addr() && - !ip_repr.dst_addr().is_broadcast() && - !ip_repr.dst_addr().is_multicast() { return false } + pub(crate) fn accepts(&self, cx: &mut Context, ip_repr: &IpRepr, repr: &UdpRepr) -> bool { + if self.endpoint.port != repr.dst_port { + return false; + } + if self.endpoint.addr.is_some() + && self.endpoint.addr != Some(ip_repr.dst_addr()) + && !cx.is_broadcast(&ip_repr.dst_addr()) + && !ip_repr.dst_addr().is_multicast() + { + return false; + } true } - pub(crate) fn process(&mut self, ip_repr: &IpRepr, repr: &UdpRepr) -> Result<()> { - debug_assert!(self.accepts(ip_repr, repr)); + pub(crate) fn process( + &mut self, + cx: &mut Context, + meta: PacketMeta, + ip_repr: &IpRepr, + repr: &UdpRepr, + payload: &[u8], + ) { + debug_assert!(self.accepts(cx, ip_repr, repr)); - let size = repr.payload.len(); + let size = payload.len(); - let endpoint = IpEndpoint { addr: ip_repr.src_addr(), port: repr.src_port }; - self.rx_buffer.enqueue(size, endpoint)?.copy_from_slice(repr.payload); + let remote_endpoint = IpEndpoint { + addr: ip_repr.src_addr(), + port: repr.src_port, + }; - net_trace!("{}:{}:{}: receiving {} octets", - self.meta.handle, self.endpoint, - endpoint, size); - Ok(()) + net_trace!( + "udp:{}:{}: receiving {} octets", + self.endpoint, + remote_endpoint, + size + ); + + let metadata = UdpMetadata { + endpoint: remote_endpoint, + meta, + }; + + match self.rx_buffer.enqueue(size, metadata) { + Ok(buf) => buf.copy_from_slice(payload), + Err(_) => net_trace!( + "udp:{}:{}: buffer full, dropped incoming packet", + self.endpoint, + remote_endpoint + ), + } + + #[cfg(feature = "async")] + self.rx_waker.wake(); } - pub(crate) fn dispatch(&mut self, emit: F) -> Result<()> - where F: FnOnce((IpRepr, UdpRepr)) -> Result<()> { - let handle = self.handle(); - let endpoint = self.endpoint; + pub(crate) fn dispatch(&mut self, cx: &mut Context, emit: F) -> Result<(), E> + where + F: FnOnce(&mut Context, PacketMeta, (IpRepr, UdpRepr, &[u8])) -> Result<(), E>, + { + let endpoint = self.endpoint; let hop_limit = self.hop_limit.unwrap_or(64); - self.tx_buffer.dequeue_with(|remote_endpoint, payload_buf| { - net_trace!("{}:{}:{}: sending {} octets", - handle, endpoint, - endpoint, payload_buf.len()); + let res = self.tx_buffer.dequeue_with(|packet_meta, payload_buf| { + let src_addr = match endpoint.addr { + Some(addr) => addr, + None => match cx.get_source_address(packet_meta.endpoint.addr) { + Some(addr) => addr, + None => { + net_trace!( + "udp:{}:{}: cannot find suitable source address, dropping.", + endpoint, + packet_meta.endpoint + ); + return Ok(()); + } + }, + }; + + net_trace!( + "udp:{}:{}: sending {} octets", + endpoint, + packet_meta.endpoint, + payload_buf.len() + ); let repr = UdpRepr { src_port: endpoint.port, - dst_port: remote_endpoint.port, - payload: payload_buf, + dst_port: packet_meta.endpoint.port, }; - let ip_repr = IpRepr::Unspecified { - src_addr: endpoint.addr, - dst_addr: remote_endpoint.addr, - protocol: IpProtocol::Udp, - payload_len: repr.buffer_len(), - hop_limit: hop_limit, - }; - emit((ip_repr, repr)) - }) + let ip_repr = IpRepr::new( + src_addr, + packet_meta.endpoint.addr, + IpProtocol::Udp, + repr.header_len() + payload_buf.len(), + hop_limit, + ); + + emit(cx, packet_meta.meta, (ip_repr, repr, payload_buf)) + }); + match res { + Err(Empty) => Ok(()), + Ok(Err(e)) => Err(e), + Ok(Ok(())) => { + #[cfg(feature = "async")] + self.tx_waker.wake(); + Ok(()) + } + } } - pub(crate) fn poll_at(&self) -> PollAt { + pub(crate) fn poll_at(&self, _cx: &mut Context) -> PollAt { if self.tx_buffer.is_empty() { PollAt::Ingress } else { @@ -273,91 +555,112 @@ impl<'a, 'b> UdpSocket<'a, 'b> { } } -impl<'a, 'b> Into> for UdpSocket<'a, 'b> { - fn into(self) -> Socket<'a, 'b> { - Socket::Udp(self) - } -} - #[cfg(test)] mod test { - use wire::{IpAddress, IpRepr, UdpRepr}; - #[cfg(feature = "proto-ipv4")] - use wire::Ipv4Repr; - #[cfg(feature = "proto-ipv6")] - use wire::Ipv6Repr; - use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2, MOCK_IP_ADDR_3}; use super::*; + use crate::wire::{IpRepr, UdpRepr}; - fn buffer(packets: usize) -> UdpSocketBuffer<'static, 'static> { - UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY; packets], vec![0; 16 * packets]) + fn buffer(packets: usize) -> PacketBuffer<'static> { + PacketBuffer::new( + (0..packets) + .map(|_| PacketMetadata::EMPTY) + .collect::>(), + vec![0; 16 * packets], + ) } - fn socket(rx_buffer: UdpSocketBuffer<'static, 'static>, - tx_buffer: UdpSocketBuffer<'static, 'static>) - -> UdpSocket<'static, 'static> { - UdpSocket::new(rx_buffer, tx_buffer) + fn socket( + rx_buffer: PacketBuffer<'static>, + tx_buffer: PacketBuffer<'static>, + ) -> Socket<'static> { + Socket::new(rx_buffer, tx_buffer) + } + + const LOCAL_PORT: u16 = 53; + const REMOTE_PORT: u16 = 49500; + + cfg_if::cfg_if! { + if #[cfg(feature = "proto-ipv4")] { + use crate::wire::Ipv4Address as IpvXAddress; + use crate::wire::Ipv4Repr as IpvXRepr; + use IpRepr::Ipv4 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 1]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 2]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([192, 168, 1, 3]); + } else { + use crate::wire::Ipv6Address as IpvXAddress; + use crate::wire::Ipv6Repr as IpvXRepr; + use IpRepr::Ipv6 as IpReprIpvX; + + const LOCAL_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ]); + const REMOTE_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ]); + const OTHER_ADDR: IpvXAddress = IpvXAddress([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ]); + } } - const LOCAL_PORT: u16 = 53; - const REMOTE_PORT: u16 = 49500; + pub const LOCAL_END: IpEndpoint = IpEndpoint { + addr: LOCAL_ADDR.into_address(), + port: LOCAL_PORT, + }; + pub const REMOTE_END: IpEndpoint = IpEndpoint { + addr: REMOTE_ADDR.into_address(), + port: REMOTE_PORT, + }; - pub const LOCAL_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_1, port: LOCAL_PORT }; - pub const REMOTE_END: IpEndpoint = IpEndpoint { addr: MOCK_IP_ADDR_2, port: REMOTE_PORT }; + pub const LOCAL_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 64, + }); - pub const LOCAL_IP_REPR: IpRepr = IpRepr::Unspecified { - src_addr: MOCK_IP_ADDR_1, - dst_addr: MOCK_IP_ADDR_2, - protocol: IpProtocol::Udp, + pub const REMOTE_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: LOCAL_ADDR, + next_header: IpProtocol::Udp, payload_len: 8 + 6, hop_limit: 64, - }; + }); + + pub const BAD_IP_REPR: IpRepr = IpReprIpvX(IpvXRepr { + src_addr: REMOTE_ADDR, + dst_addr: OTHER_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 64, + }); const LOCAL_UDP_REPR: UdpRepr = UdpRepr { src_port: LOCAL_PORT, dst_port: REMOTE_PORT, - payload: b"abcdef" }; const REMOTE_UDP_REPR: UdpRepr = UdpRepr { src_port: REMOTE_PORT, dst_port: LOCAL_PORT, - payload: b"abcdef" }; - fn remote_ip_repr() -> IpRepr { - match (MOCK_IP_ADDR_2, MOCK_IP_ADDR_1) { - #[cfg(feature = "proto-ipv4")] - (IpAddress::Ipv4(src), IpAddress::Ipv4(dst)) => IpRepr::Ipv4(Ipv4Repr { - src_addr: src, - dst_addr: dst, - protocol: IpProtocol::Udp, - payload_len: 8 + 6, - hop_limit: 64 - }), - #[cfg(feature = "proto-ipv6")] - (IpAddress::Ipv6(src), IpAddress::Ipv6(dst)) => IpRepr::Ipv6(Ipv6Repr { - src_addr: src, - dst_addr: dst, - next_header: IpProtocol::Udp, - payload_len: 8 + 6, - hop_limit: 64 - }), - _ => unreachable!() - } - } + const PAYLOAD: &[u8] = b"abcdef"; #[test] fn test_bind_unaddressable() { let mut socket = socket(buffer(0), buffer(0)); - assert_eq!(socket.bind(0), Err(Error::Unaddressable)); + assert_eq!(socket.bind(0), Err(BindError::Unaddressable)); } #[test] fn test_bind_twice() { let mut socket = socket(buffer(0), buffer(0)); assert_eq!(socket.bind(1), Ok(())); - assert_eq!(socket.bind(2), Err(Error::Illegal)); + assert_eq!(socket.bind(2), Err(BindError::InvalidState)); } #[test] @@ -370,170 +673,236 @@ mod test { #[test] fn test_send_unaddressable() { let mut socket = socket(buffer(0), buffer(1)); - assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Err(Error::Unaddressable)); + + assert_eq!( + socket.send_slice(b"abcdef", REMOTE_END), + Err(SendError::Unaddressable) + ); assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert_eq!(socket.send_slice(b"abcdef", - IpEndpoint { addr: IpAddress::Unspecified, ..REMOTE_END }), - Err(Error::Unaddressable)); - assert_eq!(socket.send_slice(b"abcdef", - IpEndpoint { port: 0, ..REMOTE_END }), - Err(Error::Unaddressable)); + assert_eq!( + socket.send_slice( + b"abcdef", + IpEndpoint { + addr: IpvXAddress::UNSPECIFIED.into(), + ..REMOTE_END + } + ), + Err(SendError::Unaddressable) + ); + assert_eq!( + socket.send_slice( + b"abcdef", + IpEndpoint { + port: 0, + ..REMOTE_END + } + ), + Err(SendError::Unaddressable) + ); assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(())); } #[test] fn test_send_dispatch() { let mut socket = socket(buffer(0), buffer(1)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_END), Ok(())); assert!(socket.can_send()); - assert_eq!(socket.dispatch(|_| unreachable!()), - Err(Error::Exhausted)); + assert_eq!( + socket.dispatch(&mut cx, |_, _, _| unreachable!()), + Ok::<_, ()>(()) + ); assert_eq!(socket.send_slice(b"abcdef", REMOTE_END), Ok(())); - assert_eq!(socket.send_slice(b"123456", REMOTE_END), Err(Error::Exhausted)); + assert_eq!( + socket.send_slice(b"123456", REMOTE_END), + Err(SendError::BufferFull) + ); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, udp_repr)| { - assert_eq!(ip_repr, LOCAL_IP_REPR); - assert_eq!(udp_repr, LOCAL_UDP_REPR); - Err(Error::Unaddressable) - }), Err(Error::Unaddressable)); + assert_eq!( + socket.dispatch(&mut cx, |_, _, (ip_repr, udp_repr, payload)| { + assert_eq!(ip_repr, LOCAL_IP_REPR); + assert_eq!(udp_repr, LOCAL_UDP_REPR); + assert_eq!(payload, PAYLOAD); + Err(()) + }), + Err(()) + ); assert!(!socket.can_send()); - assert_eq!(socket.dispatch(|(ip_repr, udp_repr)| { - assert_eq!(ip_repr, LOCAL_IP_REPR); - assert_eq!(udp_repr, LOCAL_UDP_REPR); + assert_eq!( + socket.dispatch(&mut cx, |_, _, (ip_repr, udp_repr, payload)| { + assert_eq!(ip_repr, LOCAL_IP_REPR); + assert_eq!(udp_repr, LOCAL_UDP_REPR); + assert_eq!(payload, PAYLOAD); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); assert!(socket.can_send()); } #[test] fn test_recv_process() { let mut socket = socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); assert!(!socket.can_recv()); - assert_eq!(socket.recv(), Err(Error::Exhausted)); - - assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR), - Ok(())); + assert_eq!(socket.recv(), Err(RecvError::Exhausted)); + + assert!(socket.accepts(&mut cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + &mut cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); assert!(socket.can_recv()); - assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR), - Err(Error::Exhausted)); - assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END))); + assert!(socket.accepts(&mut cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + &mut cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + + assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into()))); assert!(!socket.can_recv()); } #[test] fn test_peek_process() { let mut socket = socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert_eq!(socket.peek(), Err(Error::Exhausted)); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR), - Ok(())); - assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END))); - assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END))); - assert_eq!(socket.peek(), Err(Error::Exhausted)); + socket.process( + &mut cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); + assert_eq!(socket.peek(), Ok((&b"abcdef"[..], &REMOTE_END.into(),))); + assert_eq!(socket.recv(), Ok((&b"abcdef"[..], REMOTE_END.into(),))); + assert_eq!(socket.peek(), Err(RecvError::Exhausted)); } #[test] fn test_recv_truncated_slice() { let mut socket = socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert!(socket.accepts(&remote_ip_repr(), &REMOTE_UDP_REPR)); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR), - Ok(())); + assert!(socket.accepts(&mut cx, &REMOTE_IP_REPR, &REMOTE_UDP_REPR)); + socket.process( + &mut cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); let mut slice = [0; 4]; - assert_eq!(socket.recv_slice(&mut slice[..]), Ok((4, REMOTE_END))); + assert_eq!( + socket.recv_slice(&mut slice[..]), + Ok((4, REMOTE_END.into())) + ); assert_eq!(&slice, b"abcd"); } #[test] fn test_peek_truncated_slice() { let mut socket = socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); - assert_eq!(socket.process(&remote_ip_repr(), &REMOTE_UDP_REPR), - Ok(())); + socket.process( + &mut cx, + PacketMeta::default(), + &REMOTE_IP_REPR, + &REMOTE_UDP_REPR, + PAYLOAD, + ); let mut slice = [0; 4]; - assert_eq!(socket.peek_slice(&mut slice[..]), Ok((4, &REMOTE_END))); + assert_eq!( + socket.peek_slice(&mut slice[..]), + Ok((4, &REMOTE_END.into())) + ); assert_eq!(&slice, b"abcd"); - assert_eq!(socket.recv_slice(&mut slice[..]), Ok((4, REMOTE_END))); + assert_eq!( + socket.recv_slice(&mut slice[..]), + Ok((4, REMOTE_END.into())) + ); assert_eq!(&slice, b"abcd"); - assert_eq!(socket.peek_slice(&mut slice[..]), Err(Error::Exhausted)); + assert_eq!(socket.peek_slice(&mut slice[..]), Err(RecvError::Exhausted)); } #[test] fn test_set_hop_limit() { let mut s = socket(buffer(0), buffer(1)); + let mut cx = Context::mock(); + assert_eq!(s.bind(LOCAL_END), Ok(())); s.set_hop_limit(Some(0x2a)); assert_eq!(s.send_slice(b"abcdef", REMOTE_END), Ok(())); - assert_eq!(s.dispatch(|(ip_repr, _)| { - assert_eq!(ip_repr, IpRepr::Unspecified{ - src_addr: MOCK_IP_ADDR_1, - dst_addr: MOCK_IP_ADDR_2, - protocol: IpProtocol::Udp, - payload_len: 8 + 6, - hop_limit: 0x2a, - }); + assert_eq!( + s.dispatch(&mut cx, |_, _, (ip_repr, _, _)| { + assert_eq!( + ip_repr, + IpReprIpvX(IpvXRepr { + src_addr: LOCAL_ADDR, + dst_addr: REMOTE_ADDR, + next_header: IpProtocol::Udp, + payload_len: 8 + 6, + hop_limit: 0x2a, + }) + ); + Ok::<_, ()>(()) + }), Ok(()) - }), Ok(())); + ); } #[test] fn test_doesnt_accept_wrong_port() { let mut socket = socket(buffer(1), buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); let mut udp_repr = REMOTE_UDP_REPR; - assert!(socket.accepts(&remote_ip_repr(), &udp_repr)); + assert!(socket.accepts(&mut cx, &REMOTE_IP_REPR, &udp_repr)); udp_repr.dst_port += 1; - assert!(!socket.accepts(&remote_ip_repr(), &udp_repr)); + assert!(!socket.accepts(&mut cx, &REMOTE_IP_REPR, &udp_repr)); } #[test] fn test_doesnt_accept_wrong_ip() { - fn generate_bad_repr() -> IpRepr { - match (MOCK_IP_ADDR_2, MOCK_IP_ADDR_3) { - #[cfg(feature = "proto-ipv4")] - (IpAddress::Ipv4(src), IpAddress::Ipv4(dst)) => IpRepr::Ipv4(Ipv4Repr { - src_addr: src, - dst_addr: dst, - protocol: IpProtocol::Udp, - payload_len: 8 + 6, - hop_limit: 64 - }), - #[cfg(feature = "proto-ipv6")] - (IpAddress::Ipv6(src), IpAddress::Ipv6(dst)) => IpRepr::Ipv6(Ipv6Repr { - src_addr: src, - dst_addr: dst, - next_header: IpProtocol::Udp, - payload_len: 8 + 6, - hop_limit: 64 - }), - _ => unreachable!() - } - } + let mut cx = Context::mock(); let mut port_bound_socket = socket(buffer(1), buffer(0)); assert_eq!(port_bound_socket.bind(LOCAL_PORT), Ok(())); - assert!(port_bound_socket.accepts(&generate_bad_repr(), &REMOTE_UDP_REPR)); + assert!(port_bound_socket.accepts(&mut cx, &BAD_IP_REPR, &REMOTE_UDP_REPR)); let mut ip_bound_socket = socket(buffer(1), buffer(0)); assert_eq!(ip_bound_socket.bind(LOCAL_END), Ok(())); - assert!(!ip_bound_socket.accepts(&generate_bad_repr(), &REMOTE_UDP_REPR)); + assert!(!ip_bound_socket.accepts(&mut cx, &BAD_IP_REPR, &REMOTE_UDP_REPR)); } #[test] @@ -543,22 +912,39 @@ mod test { assert_eq!(socket.bind(LOCAL_END), Ok(())); let too_large = b"0123456789abcdef0123456789abcdef0123456789abcdef0123456789abcdefx"; - assert_eq!(socket.send_slice(too_large, REMOTE_END), Err(Error::Truncated)); - assert_eq!(socket.send_slice(&too_large[..16*4], REMOTE_END), Ok(())); + assert_eq!( + socket.send_slice(too_large, REMOTE_END), + Err(SendError::BufferFull) + ); + assert_eq!(socket.send_slice(&too_large[..16 * 4], REMOTE_END), Ok(())); } #[test] fn test_process_empty_payload() { - let recv_buffer = UdpSocketBuffer::new(vec![UdpPacketMetadata::EMPTY; 1], vec![]); + let meta = Box::leak(Box::new([PacketMetadata::EMPTY])); + let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]); let mut socket = socket(recv_buffer, buffer(0)); + let mut cx = Context::mock(); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); let repr = UdpRepr { src_port: REMOTE_PORT, dst_port: LOCAL_PORT, - payload: &[] }; - assert_eq!(socket.process(&remote_ip_repr(), &repr), Ok(())); - assert_eq!(socket.recv(), Ok((&[][..], REMOTE_END))); + socket.process(&mut cx, PacketMeta::default(), &REMOTE_IP_REPR, &repr, &[]); + assert_eq!(socket.recv(), Ok((&[][..], REMOTE_END.into()))); + } + + #[test] + fn test_closing() { + let meta = Box::leak(Box::new([PacketMetadata::EMPTY])); + let recv_buffer = PacketBuffer::new(&mut meta[..], vec![]); + let mut socket = socket(recv_buffer, buffer(0)); + assert_eq!(socket.bind(LOCAL_PORT), Ok(())); + + assert!(socket.is_open()); + socket.close(); + assert!(!socket.is_open()); } } diff --git a/src/socket/waker.rs b/src/socket/waker.rs new file mode 100644 index 000000000..4f4219788 --- /dev/null +++ b/src/socket/waker.rs @@ -0,0 +1,33 @@ +use core::task::Waker; + +/// Utility struct to register and wake a waker. +#[derive(Debug)] +pub struct WakerRegistration { + waker: Option, +} + +impl WakerRegistration { + pub const fn new() -> Self { + Self { waker: None } + } + + /// Register a waker. Overwrites the previous waker, if any. + pub fn register(&mut self, w: &Waker) { + match self.waker { + // Optimization: If both the old and new Wakers wake the same task, we can simply + // keep the old waker, skipping the clone. (In most executor implementations, + // cloning a waker is somewhat expensive, comparable to cloning an Arc). + Some(ref w2) if (w2.will_wake(w)) => {} + // In all other cases + // - we have no waker registered + // - we have a waker registered but it's for a different task. + // then clone the new waker and store it + _ => self.waker = Some(w.clone()), + } + } + + /// Wake the registered waker, if any. + pub fn wake(&mut self) { + self.waker.take().map(|w| w.wake()); + } +} diff --git a/src/storage/assembler.rs b/src/storage/assembler.rs index d5685a760..1577d1392 100644 --- a/src/storage/assembler.rs +++ b/src/storage/assembler.rs @@ -1,32 +1,69 @@ use core::fmt; +use crate::config::ASSEMBLER_MAX_SEGMENT_COUNT; + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub struct TooManyHolesError; + +impl fmt::Display for TooManyHolesError { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "too many holes") + } +} + +#[cfg(feature = "std")] +impl std::error::Error for TooManyHolesError {} + /// A contiguous chunk of absent data, followed by a contiguous chunk of present data. #[derive(Debug, Clone, Copy, PartialEq, Eq)] struct Contig { hole_size: usize, - data_size: usize + data_size: usize, } impl fmt::Display for Contig { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - if self.has_hole() { write!(f, "({})", self.hole_size)?; } - if self.has_hole() && self.has_data() { write!(f, " ")?; } - if self.has_data() { write!(f, "{}", self.data_size)?; } + if self.has_hole() { + write!(f, "({})", self.hole_size)?; + } + if self.has_hole() && self.has_data() { + write!(f, " ")?; + } + if self.has_data() { + write!(f, "{}", self.data_size)?; + } Ok(()) } } -impl Contig { - fn empty() -> Contig { - Contig { hole_size: 0, data_size: 0 } +#[cfg(feature = "defmt")] +impl defmt::Format for Contig { + fn format(&self, fmt: defmt::Formatter) { + if self.has_hole() { + defmt::write!(fmt, "({})", self.hole_size); + } + if self.has_hole() && self.has_data() { + defmt::write!(fmt, " "); + } + if self.has_data() { + defmt::write!(fmt, "{}", self.data_size); + } } +} - fn hole(size: usize) -> Contig { - Contig { hole_size: size, data_size: 0 } +impl Contig { + const fn empty() -> Contig { + Contig { + hole_size: 0, + data_size: 0, + } } fn hole_and_data(hole_size: usize, data_size: usize) -> Contig { - Contig { hole_size, data_size } + Contig { + hole_size, + data_size, + } } fn has_hole(&self) -> bool { @@ -41,14 +78,6 @@ impl Contig { self.hole_size + self.data_size } - fn is_empty(&self) -> bool { - self.total_size() == 0 - } - - fn expand_data_by(&mut self, size: usize) { - self.data_size += size; - } - fn shrink_hole_by(&mut self, size: usize) { self.hole_size -= size; } @@ -62,64 +91,73 @@ impl Contig { } } -#[cfg(feature = "std")] -use std::boxed::Box; -#[cfg(all(feature = "alloc", not(feature = "std")))] -use alloc::boxed::Box; -#[cfg(any(feature = "std", feature = "alloc"))] -const CONTIG_COUNT: usize = 32; - -#[cfg(not(any(feature = "std", feature = "alloc")))] -const CONTIG_COUNT: usize = 4; - /// A buffer (re)assembler. /// /// Currently, up to a hardcoded limit of 4 or 32 holes can be tracked in the buffer. -#[derive(Debug)] -#[cfg_attr(test, derive(PartialEq, Eq, Clone))] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Assembler { - #[cfg(not(any(feature = "std", feature = "alloc")))] - contigs: [Contig; CONTIG_COUNT], - #[cfg(any(feature = "std", feature = "alloc"))] - contigs: Box<[Contig; CONTIG_COUNT]>, + contigs: [Contig; ASSEMBLER_MAX_SEGMENT_COUNT], } impl fmt::Display for Assembler { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "[ ")?; for contig in self.contigs.iter() { - if contig.is_empty() { break } - write!(f, "{} ", contig)?; + if !contig.has_data() { + break; + } + write!(f, "{contig} ")?; } write!(f, "]")?; Ok(()) } } +#[cfg(feature = "defmt")] +impl defmt::Format for Assembler { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "[ "); + for contig in self.contigs.iter() { + if !contig.has_data() { + break; + } + defmt::write!(fmt, "{} ", contig); + } + defmt::write!(fmt, "]"); + } +} + +// Invariant on Assembler::contigs: +// - There's an index `i` where all contigs before have data, and all contigs after don't (are unused). +// - All contigs with data must have hole_size != 0, except the first. + impl Assembler { - /// Create a new buffer assembler for buffers of the given size. - pub fn new(size: usize) -> Assembler { - #[cfg(not(any(feature = "std", feature = "alloc")))] - let mut contigs = [Contig::empty(); CONTIG_COUNT]; - #[cfg(any(feature = "std", feature = "alloc"))] - let mut contigs = Box::new([Contig::empty(); CONTIG_COUNT]); - contigs[0] = Contig::hole(size); - Assembler { contigs } + /// Create a new buffer assembler. + pub const fn new() -> Assembler { + const EMPTY: Contig = Contig::empty(); + Assembler { + contigs: [EMPTY; ASSEMBLER_MAX_SEGMENT_COUNT], + } } - /// FIXME(whitequark): remove this once I'm certain enough that the assembler works well. - #[allow(dead_code)] - pub(crate) fn total_size(&self) -> usize { - self.contigs - .iter() - .map(|contig| contig.total_size()) - .sum() + pub fn clear(&mut self) { + self.contigs.fill(Contig::empty()); } fn front(&self) -> Contig { self.contigs[0] } + /// Return length of the front contiguous range without removing it from the assembler + pub fn peek_front(&self) -> usize { + let front = self.front(); + if front.has_hole() { + 0 + } else { + front.data_size + } + } + fn back(&self) -> Contig { self.contigs[self.contigs.len() - 1] } @@ -129,29 +167,26 @@ impl Assembler { !self.front().has_data() } - /// Remove a contig at the given index, and return a pointer to the first contig - /// without data. - fn remove_contig_at(&mut self, at: usize) -> &mut Contig { - debug_assert!(!self.contigs[at].is_empty()); + /// Remove a contig at the given index. + fn remove_contig_at(&mut self, at: usize) { + debug_assert!(self.contigs[at].has_data()); for i in at..self.contigs.len() - 1 { - self.contigs[i] = self.contigs[i + 1]; if !self.contigs[i].has_data() { - self.contigs[i + 1] = Contig::empty(); - return &mut self.contigs[i] + return; } + self.contigs[i] = self.contigs[i + 1]; } // Removing the last one. - self.contigs[at] = Contig::empty(); - &mut self.contigs[at] + self.contigs[self.contigs.len() - 1] = Contig::empty(); } /// Add a contig at the given index, and return a pointer to it. - fn add_contig_at(&mut self, at: usize) -> Result<&mut Contig, ()> { - debug_assert!(!self.contigs[at].is_empty()); - - if !self.back().is_empty() { return Err(()) } + fn add_contig_at(&mut self, at: usize) -> Result<&mut Contig, TooManyHolesError> { + if self.back().has_data() { + return Err(TooManyHolesError); + } for i in (at + 1..self.contigs.len()).rev() { self.contigs[i] = self.contigs[i - 1]; @@ -161,77 +196,126 @@ impl Assembler { Ok(&mut self.contigs[at]) } - /// Add a new contiguous range to the assembler, and return `Ok(())`, - /// or return `Err(())` if too many discontiguities are already recorded. - pub fn add(&mut self, mut offset: usize, mut size: usize) -> Result<(), ()> { - let mut index = 0; - while index != self.contigs.len() && size != 0 { - let contig = self.contigs[index]; - - if offset >= contig.total_size() { - // The range being added does not cover this contig, skip it. - index += 1; - } else if offset == 0 && size >= contig.hole_size && index > 0 { - // The range being added covers the entire hole in this contig, merge it - // into the previous config. - self.contigs[index - 1].expand_data_by(contig.total_size()); - self.remove_contig_at(index); - index += 0; - } else if offset == 0 && size < contig.hole_size && index > 0 { - // The range being added covers a part of the hole in this contig starting - // at the beginning, shrink the hole in this contig and expand data in - // the previous contig. - self.contigs[index - 1].expand_data_by(size); - self.contigs[index].shrink_hole_by(size); - index += 1; - } else if offset <= contig.hole_size && offset + size >= contig.hole_size { - // The range being added covers both a part of the hole and a part of the data - // in this contig, shrink the hole in this contig. - self.contigs[index].shrink_hole_to(offset); - index += 1; - } else if offset + size >= contig.hole_size { - // The range being added covers only a part of the data in this contig, skip it. - index += 1; - } else if offset + size < contig.hole_size { - // The range being added covers a part of the hole but not of the data - // in this contig, add a new contig containing the range. - { - let inserted = self.add_contig_at(index)?; - *inserted = Contig::hole_and_data(offset, size); - } + /// Add a new contiguous range to the assembler, + /// or return `Err(TooManyHolesError)` if too many discontiguities are already recorded. + pub fn add(&mut self, mut offset: usize, size: usize) -> Result<(), TooManyHolesError> { + if size == 0 { + return Ok(()); + } + + let mut i = 0; + + // Find index of the contig containing the start of the range. + loop { + if i == self.contigs.len() { + // The new range is after all the previous ranges, but there/s no space to add it. + return Err(TooManyHolesError); + } + let contig = &mut self.contigs[i]; + if !contig.has_data() { + // The new range is after all the previous ranges. Add it. + *contig = Contig::hole_and_data(offset, size); + return Ok(()); + } + if offset <= contig.total_size() { + break; + } + offset -= contig.total_size(); + i += 1; + } + + let contig = &mut self.contigs[i]; + if offset < contig.hole_size { + // Range starts within the hole. + + if offset + size < contig.hole_size { + // Range also ends within the hole. + let new_contig = self.add_contig_at(i)?; + new_contig.hole_size = offset; + new_contig.data_size = size; + // Previous contigs[index] got moved to contigs[index+1] - self.contigs[index+1].shrink_hole_by(offset + size); - index += 2; - } else { - unreachable!() + self.contigs[i + 1].shrink_hole_by(offset + size); + return Ok(()); } - // Skip the portion of the range covered by this contig. - if offset >= contig.total_size() { - offset = offset.saturating_sub(contig.total_size()); - } else { - size = (offset + size).saturating_sub(contig.total_size()); - offset = 0; + // The range being added covers both a part of the hole and a part of the data + // in this contig, shrink the hole in this contig. + contig.shrink_hole_to(offset); + } + + // coalesce contigs to the right. + let mut j = i + 1; + while j < self.contigs.len() + && self.contigs[j].has_data() + && offset + size >= self.contigs[i].total_size() + self.contigs[j].hole_size + { + self.contigs[i].data_size += self.contigs[j].total_size(); + j += 1; + } + let shift = j - i - 1; + if shift != 0 { + for x in i + 1..self.contigs.len() { + if !self.contigs[x].has_data() { + break; + } + + self.contigs[x] = self + .contigs + .get(x + shift) + .copied() + .unwrap_or_else(Contig::empty); + } + } + + if offset + size > self.contigs[i].total_size() { + // The added range still extends beyond the current contig. Increase data size. + let left = offset + size - self.contigs[i].total_size(); + self.contigs[i].data_size += left; + + // Decrease hole size of the next, if any. + if i + 1 < self.contigs.len() && self.contigs[i + 1].has_data() { + self.contigs[i + 1].hole_size -= left; } } - debug_assert!(size == 0); Ok(()) } - /// Remove a contiguous range from the front of the assembler and `Some(data_size)`, - /// or return `None` if there is no such range. - pub fn remove_front(&mut self) -> Option { + /// Remove a contiguous range from the front of the assembler. + /// If no such range, return 0. + pub fn remove_front(&mut self) -> usize { let front = self.front(); - if front.has_hole() { - None + if front.has_hole() || !front.has_data() { + 0 } else { - let last_hole = self.remove_contig_at(0); - last_hole.hole_size += front.data_size; - + self.remove_contig_at(0); debug_assert!(front.data_size > 0); - Some(front.data_size) + front.data_size + } + } + + /// Add a segment, then remove_front. + /// + /// This is equivalent to calling `add` then `remove_front` individually, + /// except it's guaranteed to not fail when offset = 0. + /// This is required for TCP: we must never drop the next expected segment, or + /// the protocol might get stuck. + pub fn add_then_remove_front( + &mut self, + offset: usize, + size: usize, + ) -> Result { + // This is the only case where a segment at offset=0 would cause the + // total amount of contigs to rise (and therefore can potentially cause + // a TooManyHolesError). Handle it in a way that is guaranteed to succeed. + if offset == 0 && size < self.contigs[0].hole_size { + self.contigs[0].hole_size -= size; + return Ok(size); } + + self.add(offset, size)?; + Ok(self.remove_front()) } /// Iterate over all of the contiguous data ranges. @@ -243,7 +327,7 @@ impl Assembler { /// |--- 100 ---|--- 200 ---|--- 100 ---| /// /// An offset of 1500 would return the ranges: ``(1500, 1600), (1800, 1900)`` - pub fn iter_data<'a>(&'a self, first_offset: usize) -> AssemblerIter<'a> { + pub fn iter_data(&self, first_offset: usize) -> AssemblerIter { AssemblerIter::new(self, first_offset) } } @@ -253,17 +337,17 @@ pub struct AssemblerIter<'a> { offset: usize, index: usize, left: usize, - right: usize + right: usize, } impl<'a> AssemblerIter<'a> { fn new(assembler: &'a Assembler, offset: usize) -> AssemblerIter<'a> { AssemblerIter { - assembler: assembler, - offset: offset, + assembler, + offset, index: 0, left: 0, - right: 0 + right: 0, } } } @@ -275,7 +359,7 @@ impl<'a> Iterator for AssemblerIter<'a> { let mut data_range = None; while data_range.is_none() && self.index < self.assembler.contigs.len() { let contig = self.assembler.contigs[self.index]; - self.left = self.left + contig.hole_size; + self.left += contig.hole_size; self.right = self.left + contig.data_size; data_range = if self.left < self.right { let data_range = (self.left + self.offset, self.right + self.offset); @@ -292,17 +376,19 @@ impl<'a> Iterator for AssemblerIter<'a> { #[cfg(test)] mod test { - use std::vec::Vec; use super::*; + use std::vec::Vec; impl From> for Assembler { fn from(vec: Vec<(usize, usize)>) -> Assembler { - #[cfg(not(any(feature = "std", feature = "alloc")))] - let mut contigs = [Contig::empty(); CONTIG_COUNT]; - #[cfg(any(feature = "std", feature = "alloc"))] - let mut contigs = Box::new([Contig::empty(); CONTIG_COUNT]); + const EMPTY: Contig = Contig::empty(); + + let mut contigs = [EMPTY; ASSEMBLER_MAX_SEGMENT_COUNT]; for (i, &(hole_size, data_size)) in vec.iter().enumerate() { - contigs[i] = Contig { hole_size, data_size }; + contigs[i] = Contig { + hole_size, + data_size, + }; } Assembler { contigs } } @@ -316,138 +402,184 @@ mod test { #[test] fn test_new() { - let assr = Assembler::new(16); - assert_eq!(assr.total_size(), 16); - assert_eq!(assr, contigs![(16, 0)]); + let assr = Assembler::new(); + assert_eq!(assr, contigs![]); } #[test] fn test_empty_add_full() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(0, 16), Ok(())); assert_eq!(assr, contigs![(0, 16)]); } #[test] fn test_empty_add_front() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(0, 4), Ok(())); - assert_eq!(assr, contigs![(0, 4), (12, 0)]); + assert_eq!(assr, contigs![(0, 4)]); } #[test] fn test_empty_add_back() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(12, 4), Ok(())); assert_eq!(assr, contigs![(12, 4)]); } #[test] fn test_empty_add_mid() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(4, 8), Ok(())); - assert_eq!(assr, contigs![(4, 8), (4, 0)]); + assert_eq!(assr, contigs![(4, 8)]); } #[test] fn test_partial_add_front() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(0, 4), Ok(())); - assert_eq!(assr, contigs![(0, 12), (4, 0)]); + assert_eq!(assr, contigs![(0, 12)]); } #[test] fn test_partial_add_back() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(12, 4), Ok(())); assert_eq!(assr, contigs![(4, 12)]); } #[test] fn test_partial_add_front_overlap() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(0, 8), Ok(())); - assert_eq!(assr, contigs![(0, 12), (4, 0)]); + assert_eq!(assr, contigs![(0, 12)]); } #[test] fn test_partial_add_front_overlap_split() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(2, 6), Ok(())); - assert_eq!(assr, contigs![(2, 10), (4, 0)]); + assert_eq!(assr, contigs![(2, 10)]); } #[test] fn test_partial_add_back_overlap() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(8, 8), Ok(())); assert_eq!(assr, contigs![(4, 12)]); } #[test] fn test_partial_add_back_overlap_split() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(10, 4), Ok(())); - assert_eq!(assr, contigs![(4, 10), (2, 0)]); + assert_eq!(assr, contigs![(4, 10)]); } #[test] fn test_partial_add_both_overlap() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(0, 16), Ok(())); assert_eq!(assr, contigs![(0, 16)]); } #[test] fn test_partial_add_both_overlap_split() { - let mut assr = contigs![(4, 8), (4, 0)]; + let mut assr = contigs![(4, 8)]; assert_eq!(assr.add(2, 12), Ok(())); - assert_eq!(assr, contigs![(2, 12), (2, 0)]); + assert_eq!(assr, contigs![(2, 12)]); } #[test] fn test_rejected_add_keeps_state() { - let mut assr = Assembler::new(CONTIG_COUNT*20); - for c in 1..=CONTIG_COUNT-1 { - assert_eq!(assr.add(c*10, 3), Ok(())); + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); } // Maximum of allowed holes is reached let assr_before = assr.clone(); - assert_eq!(assr.add(1, 3), Err(())); + assert_eq!(assr.add(1, 3), Err(TooManyHolesError)); assert_eq!(assr_before, assr); } #[test] fn test_empty_remove_front() { - let mut assr = contigs![(12, 0)]; - assert_eq!(assr.remove_front(), None); + let mut assr = contigs![]; + assert_eq!(assr.remove_front(), 0); } #[test] fn test_trailing_hole_remove_front() { - let mut assr = contigs![(0, 4), (8, 0)]; - assert_eq!(assr.remove_front(), Some(4)); - assert_eq!(assr, contigs![(12, 0)]); + let mut assr = contigs![(0, 4)]; + assert_eq!(assr.remove_front(), 4); + assert_eq!(assr, contigs![]); } #[test] fn test_trailing_data_remove_front() { let mut assr = contigs![(0, 4), (4, 4)]; - assert_eq!(assr.remove_front(), Some(4)); - assert_eq!(assr, contigs![(4, 4), (4, 0)]); + assert_eq!(assr.remove_front(), 4); + assert_eq!(assr, contigs![(4, 4)]); + } + + #[test] + fn test_boundary_case_remove_front() { + let mut vec = vec![(1, 1); ASSEMBLER_MAX_SEGMENT_COUNT]; + vec[0] = (0, 2); + let mut assr = Assembler::from(vec); + assert_eq!(assr.remove_front(), 2); + let mut vec = vec![(1, 1); ASSEMBLER_MAX_SEGMENT_COUNT]; + vec[ASSEMBLER_MAX_SEGMENT_COUNT - 1] = (0, 0); + let exp_assr = Assembler::from(vec); + assert_eq!(assr, exp_assr); + } + #[test] + fn test_shrink_next_hole() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(100, 10), Ok(())); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(40, 30), Ok(())); + assert_eq!(assr, contigs![(40, 30), (30, 10)]); + } + + #[test] + fn test_join_two() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(15, 40), Ok(())); + assert_eq!(assr, contigs![(10, 50)]); + } + + #[test] + fn test_join_two_reversed() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(15, 40), Ok(())); + assert_eq!(assr, contigs![(10, 50)]); + } + + #[test] + fn test_join_two_overlong() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add(10, 10), Ok(())); + assert_eq!(assr.add(15, 60), Ok(())); + assert_eq!(assr, contigs![(10, 65)]); } #[test] fn test_iter_empty() { - let assr = Assembler::new(16); + let assr = Assembler::new(); let segments: Vec<_> = assr.iter_data(10).collect(); assert_eq!(segments, vec![]); } #[test] fn test_iter_full() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(0, 16), Ok(())); let segments: Vec<_> = assr.iter_data(10).collect(); assert_eq!(segments, vec![(10, 26)]); @@ -455,7 +587,7 @@ mod test { #[test] fn test_iter_offset() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(0, 16), Ok(())); let segments: Vec<_> = assr.iter_data(100).collect(); assert_eq!(segments, vec![(100, 116)]); @@ -463,7 +595,7 @@ mod test { #[test] fn test_iter_one_front() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(0, 4), Ok(())); let segments: Vec<_> = assr.iter_data(10).collect(); assert_eq!(segments, vec![(10, 14)]); @@ -471,7 +603,7 @@ mod test { #[test] fn test_iter_one_back() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(12, 4), Ok(())); let segments: Vec<_> = assr.iter_data(10).collect(); assert_eq!(segments, vec![(22, 26)]); @@ -479,7 +611,7 @@ mod test { #[test] fn test_iter_one_mid() { - let mut assr = Assembler::new(16); + let mut assr = Assembler::new(); assert_eq!(assr.add(4, 8), Ok(())); let segments: Vec<_> = assr.iter_data(10).collect(); assert_eq!(segments, vec![(14, 22)]); @@ -487,22 +619,132 @@ mod test { #[test] fn test_iter_one_trailing_gap() { - let assr = contigs![(4, 8), (4, 0)]; + let assr = contigs![(4, 8)]; let segments: Vec<_> = assr.iter_data(100).collect(); assert_eq!(segments, vec![(104, 112)]); } #[test] fn test_iter_two_split() { - let assr = contigs![(2, 6), (4, 1), (1, 0)]; + let assr = contigs![(2, 6), (4, 1)]; let segments: Vec<_> = assr.iter_data(100).collect(); assert_eq!(segments, vec![(102, 108), (112, 113)]); } #[test] fn test_iter_three_split() { - let assr = contigs![(2, 6), (2, 1), (2, 2), (1, 0)]; + let assr = contigs![(2, 6), (2, 1), (2, 2)]; let segments: Vec<_> = assr.iter_data(100).collect(); assert_eq!(segments, vec![(102, 108), (110, 111), (113, 115)]); } + + #[test] + fn test_issue_694() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(0, 1), Ok(())); + assert_eq!(assr.add(2, 1), Ok(())); + assert_eq!(assr.add(1, 1), Ok(())); + } + + #[test] + fn test_add_then_remove_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(10, 10), Ok(0)); + assert_eq!(assr, contigs![(10, 10), (30, 10)]); + } + + #[test] + fn test_add_then_remove_front_at_front() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(0, 10), Ok(10)); + assert_eq!(assr, contigs![(40, 10)]); + } + + #[test] + fn test_add_then_remove_front_at_front_touch() { + let mut assr = Assembler::new(); + assert_eq!(assr.add(50, 10), Ok(())); + assert_eq!(assr.add_then_remove_front(0, 50), Ok(60)); + assert_eq!(assr, contigs![]); + } + + #[test] + fn test_add_then_remove_front_at_front_full() { + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); + } + // Maximum of allowed holes is reached + let assr_before = assr.clone(); + assert_eq!(assr.add_then_remove_front(1, 3), Err(TooManyHolesError)); + assert_eq!(assr_before, assr); + } + + #[test] + fn test_add_then_remove_front_at_front_full_offset_0() { + let mut assr = Assembler::new(); + for c in 1..=ASSEMBLER_MAX_SEGMENT_COUNT { + assert_eq!(assr.add(c * 10, 3), Ok(())); + } + assert_eq!(assr.add_then_remove_front(0, 3), Ok(3)); + } + + // Test against an obviously-correct but inefficient bitmap impl. + #[test] + fn test_random() { + use rand::Rng; + + const MAX_INDEX: usize = 256; + + for max_size in [2, 5, 10, 100] { + for _ in 0..300 { + //println!("==="); + let mut assr = Assembler::new(); + let mut map = [false; MAX_INDEX]; + + for _ in 0..60 { + let offset = rand::thread_rng().gen_range(0..MAX_INDEX - max_size - 1); + let size = rand::thread_rng().gen_range(1..=max_size); + + //println!("add {}..{} {}", offset, offset + size, size); + // Real impl + let res = assr.add(offset, size); + + // Bitmap impl + let mut map2 = map; + map2[offset..][..size].fill(true); + + let mut contigs = vec![]; + let mut hole: usize = 0; + let mut data: usize = 0; + for b in map2 { + if b { + data += 1; + } else { + if data != 0 { + contigs.push((hole, data)); + hole = 0; + data = 0; + } + hole += 1; + } + } + + // Compare. + let wanted_res = if contigs.len() > ASSEMBLER_MAX_SEGMENT_COUNT { + Err(TooManyHolesError) + } else { + Ok(()) + }; + assert_eq!(res, wanted_res); + if res.is_ok() { + map = map2; + assert_eq!(assr, Assembler::from(contigs)); + } + } + } + } + } } diff --git a/src/storage/mod.rs b/src/storage/mod.rs index 8b9fbe68e..b03de7124 100644 --- a/src/storage/mod.rs +++ b/src/storage/mod.rs @@ -6,12 +6,12 @@ or `alloc` crates being available, and heap-allocated memory. */ mod assembler; -mod ring_buffer; mod packet_buffer; +mod ring_buffer; pub use self::assembler::Assembler; -pub use self::ring_buffer::RingBuffer; pub use self::packet_buffer::{PacketBuffer, PacketMetadata}; +pub use self::ring_buffer::RingBuffer; /// A trait for setting a value to a known state. /// @@ -19,3 +19,13 @@ pub use self::packet_buffer::{PacketBuffer, PacketMetadata}; pub trait Resettable { fn reset(&mut self); } + +/// Error returned when enqueuing into a full buffer. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Full; + +/// Error returned when dequeuing from an empty buffer. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Empty; diff --git a/src/storage/packet_buffer.rs b/src/storage/packet_buffer.rs index 3bec8f1ac..1447e8245 100644 --- a/src/storage/packet_buffer.rs +++ b/src/storage/packet_buffer.rs @@ -1,30 +1,35 @@ use managed::ManagedSlice; -use {Error, Result}; -use super::RingBuffer; +use crate::storage::{Full, RingBuffer}; + +use super::Empty; /// Size and header of a packet. #[derive(Debug, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct PacketMetadata { - size: usize, - header: Option + size: usize, + header: Option, } impl PacketMetadata { /// Empty packet description. - pub const EMPTY: PacketMetadata = PacketMetadata { size: 0, header: None }; + pub const EMPTY: PacketMetadata = PacketMetadata { + size: 0, + header: None, + }; fn padding(size: usize) -> PacketMetadata { PacketMetadata { - size: size, - header: None + size: size, + header: None, } } fn packet(size: usize, header: H) -> PacketMetadata { PacketMetadata { - size: size, - header: Some(header) + size: size, + header: Some(header), } } @@ -35,23 +40,24 @@ impl PacketMetadata { /// An UDP packet ring buffer. #[derive(Debug)] -pub struct PacketBuffer<'a, 'b, H: 'a> { +pub struct PacketBuffer<'a, H: 'a> { metadata_ring: RingBuffer<'a, PacketMetadata>, - payload_ring: RingBuffer<'b, u8>, + payload_ring: RingBuffer<'a, u8>, } -impl<'a, 'b, H> PacketBuffer<'a, 'b, H> { +impl<'a, H> PacketBuffer<'a, H> { /// Create a new packet buffer with the provided metadata and payload storage. /// /// Metadata storage limits the maximum _number_ of packets in the buffer and payload /// storage limits the maximum _total size_ of packets. - pub fn new(metadata_storage: MS, payload_storage: PS) -> PacketBuffer<'a, 'b, H> - where MS: Into>>, - PS: Into>, + pub fn new(metadata_storage: MS, payload_storage: PS) -> PacketBuffer<'a, H> + where + MS: Into>>, + PS: Into>, { PacketBuffer { metadata_ring: RingBuffer::new(metadata_storage), - payload_ring: RingBuffer::new(payload_storage), + payload_ring: RingBuffer::new(payload_storage), } } @@ -69,35 +75,38 @@ impl<'a, 'b, H> PacketBuffer<'a, 'b, H> { // in case of failure. /// Enqueue a single packet with the given header into the buffer, and - /// return a reference to its payload, or return `Err(Error::Exhausted)` - /// if the buffer is full, or return `Err(Error::Truncated)` if the buffer - /// does not have enough spare payload space. - pub fn enqueue(&mut self, size: usize, header: H) -> Result<&mut [u8]> { - if self.payload_ring.capacity() < size { - return Err(Error::Truncated) + /// return a reference to its payload, or return `Err(Full)` + /// if the buffer is full. + pub fn enqueue(&mut self, size: usize, header: H) -> Result<&mut [u8], Full> { + if self.payload_ring.capacity() < size || self.metadata_ring.is_full() { + return Err(Full); } - if self.metadata_ring.is_full() { - return Err(Error::Exhausted) + // Ring is currently empty. Clear it (resetting `read_at`) to maximize + // for contiguous space. + if self.payload_ring.is_empty() { + self.payload_ring.clear(); } let window = self.payload_ring.window(); let contig_window = self.payload_ring.contiguous_window(); if window < size { - return Err(Error::Exhausted) + return Err(Full); } else if contig_window < size { if window - contig_window < size { // The buffer length is larger than the current contiguous window // and is larger than the contiguous window will be after adding // the padding necessary to circle around to the beginning of the // ring buffer. - return Err(Error::Exhausted) + return Err(Full); } else { // Add padding to the end of the ring buffer so that the // contiguous window is at the beginning of the ring buffer. *self.metadata_ring.enqueue_one()? = PacketMetadata::padding(contig_window); - self.payload_ring.enqueue_many(contig_window); + // note(discard): function does not write to the result + // enqueued padding buffer location + let _buf_enqueued = self.payload_ring.enqueue_many(contig_window); } } @@ -108,64 +117,115 @@ impl<'a, 'b, H> PacketBuffer<'a, 'b, H> { Ok(payload_buf) } - fn dequeue_padding(&mut self) { - let Self { ref mut metadata_ring, ref mut payload_ring } = *self; + /// Call `f` with a packet from the buffer large enough to fit `max_size` bytes. The packet + /// is shrunk to the size returned from `f` and enqueued into the buffer. + pub fn enqueue_with_infallible<'b, F>( + &'b mut self, + max_size: usize, + header: H, + f: F, + ) -> Result + where + F: FnOnce(&'b mut [u8]) -> usize, + { + if self.payload_ring.capacity() < max_size || self.metadata_ring.is_full() { + return Err(Full); + } + + let window = self.payload_ring.window(); + let contig_window = self.payload_ring.contiguous_window(); + + if window < max_size { + return Err(Full); + } else if contig_window < max_size { + if window - contig_window < max_size { + // The buffer length is larger than the current contiguous window + // and is larger than the contiguous window will be after adding + // the padding necessary to circle around to the beginning of the + // ring buffer. + return Err(Full); + } else { + // Add padding to the end of the ring buffer so that the + // contiguous window is at the beginning of the ring buffer. + *self.metadata_ring.enqueue_one()? = PacketMetadata::padding(contig_window); + // note(discard): function does not write to the result + // enqueued padding buffer location + let _buf_enqueued = self.payload_ring.enqueue_many(contig_window); + } + } + + let (size, _) = self + .payload_ring + .enqueue_many_with(|data| (f(&mut data[..max_size]), ())); - let _ = metadata_ring.dequeue_one_with(|metadata| { + *self.metadata_ring.enqueue_one()? = PacketMetadata::packet(size, header); + + Ok(size) + } + + fn dequeue_padding(&mut self) { + let _ = self.metadata_ring.dequeue_one_with(|metadata| { if metadata.is_padding() { - payload_ring.dequeue_many(metadata.size); + // note(discard): function does not use value of dequeued padding bytes + let _buf_dequeued = self.payload_ring.dequeue_many(metadata.size); Ok(()) // dequeue metadata } else { - Err(Error::Exhausted) // don't dequeue metadata + Err(()) // don't dequeue metadata } }); } /// Call `f` with a single packet from the buffer, and dequeue the packet if `f` - /// returns successfully, or return `Err(Error::Exhausted)` if the buffer is empty. - pub fn dequeue_with<'c, R, F>(&'c mut self, f: F) -> Result - where F: FnOnce(&mut H, &'c mut [u8]) -> Result { + /// returns successfully, or return `Err(EmptyError)` if the buffer is empty. + pub fn dequeue_with<'c, R, E, F>(&'c mut self, f: F) -> Result, Empty> + where + F: FnOnce(&mut H, &'c mut [u8]) -> Result, + { self.dequeue_padding(); - let Self { ref mut metadata_ring, ref mut payload_ring } = *self; - - metadata_ring.dequeue_one_with(move |metadata| { - let PacketMetadata { ref mut header, size } = *metadata; - - payload_ring.dequeue_many_with(|payload_buf| { - debug_assert!(payload_buf.len() >= size); - - match f(header.as_mut().unwrap(), &mut payload_buf[..size]) { - Ok(val) => (size, Ok(val)), - Err(err) => (0, Err(err)), - } - }).1 + self.metadata_ring.dequeue_one_with(|metadata| { + self.payload_ring + .dequeue_many_with(|payload_buf| { + debug_assert!(payload_buf.len() >= metadata.size); + + match f( + metadata.header.as_mut().unwrap(), + &mut payload_buf[..metadata.size], + ) { + Ok(val) => (metadata.size, Ok(val)), + Err(err) => (0, Err(err)), + } + }) + .1 }) } /// Dequeue a single packet from the buffer, and return a reference to its payload /// as well as its header, or return `Err(Error::Exhausted)` if the buffer is empty. - pub fn dequeue(&mut self) -> Result<(H, &mut [u8])> { + pub fn dequeue(&mut self) -> Result<(H, &mut [u8]), Empty> { self.dequeue_padding(); - let PacketMetadata { ref mut header, size } = *self.metadata_ring.dequeue_one()?; + let meta = self.metadata_ring.dequeue_one()?; - let payload_buf = self.payload_ring.dequeue_many(size); - debug_assert!(payload_buf.len() == size); - Ok((header.take().unwrap(), payload_buf)) + let payload_buf = self.payload_ring.dequeue_many(meta.size); + debug_assert!(payload_buf.len() == meta.size); + Ok((meta.header.take().unwrap(), payload_buf)) } /// Peek at a single packet from the buffer without removing it, and return a reference to - /// its payload as well as its header, or return `Err(Error:Exhaused)` if the buffer is empty. + /// its payload as well as its header, or return `Err(Error:Exhausted)` if the buffer is empty. /// /// This function otherwise behaves identically to [dequeue](#method.dequeue). - pub fn peek(&mut self) -> Result<(&H, &[u8])> { + pub fn peek(&mut self) -> Result<(&H, &[u8]), Empty> { self.dequeue_padding(); if let Some(metadata) = self.metadata_ring.get_allocated(0, 1).first() { - Ok((metadata.header.as_ref().unwrap(), self.payload_ring.get_allocated(0, metadata.size))) + Ok(( + metadata.header.as_ref().unwrap(), + self.payload_ring.get_allocated(0, metadata.size), + )) } else { - Err(Error::Exhausted) + Err(Empty) } } @@ -178,36 +238,42 @@ impl<'a, 'b, H> PacketBuffer<'a, 'b, H> { pub fn payload_capacity(&self) -> usize { self.payload_ring.capacity() } + + /// Reset the packet buffer and clear any staged. + #[allow(unused)] + pub(crate) fn reset(&mut self) { + self.payload_ring.clear(); + self.metadata_ring.clear(); + } } #[cfg(test)] mod test { use super::*; - fn buffer() -> PacketBuffer<'static, 'static, ()> { - PacketBuffer::new(vec![PacketMetadata::EMPTY; 4], - vec![0u8; 16]) + fn buffer() -> PacketBuffer<'static, ()> { + PacketBuffer::new(vec![PacketMetadata::EMPTY; 4], vec![0u8; 16]) } #[test] fn test_simple() { let mut buffer = buffer(); buffer.enqueue(6, ()).unwrap().copy_from_slice(b"abcdef"); - assert_eq!(buffer.enqueue(16, ()), Err(Error::Exhausted)); + assert_eq!(buffer.enqueue(16, ()), Err(Full)); assert_eq!(buffer.metadata_ring.len(), 1); assert_eq!(buffer.dequeue().unwrap().1, &b"abcdef"[..]); - assert_eq!(buffer.dequeue(), Err(Error::Exhausted)); + assert_eq!(buffer.dequeue(), Err(Empty)); } #[test] fn test_peek() { let mut buffer = buffer(); - assert_eq!(buffer.peek(), Err(Error::Exhausted)); + assert_eq!(buffer.peek(), Err(Empty)); buffer.enqueue(6, ()).unwrap().copy_from_slice(b"abcdef"); assert_eq!(buffer.metadata_ring.len(), 1); assert_eq!(buffer.peek().unwrap().1, &b"abcdef"[..]); assert_eq!(buffer.dequeue().unwrap().1, &b"abcdef"[..]); - assert_eq!(buffer.peek(), Err(Error::Exhausted)); + assert_eq!(buffer.peek(), Err(Empty)); } #[test] @@ -229,7 +295,10 @@ mod test { let mut buffer = buffer(); assert!(buffer.enqueue(12, ()).is_ok()); assert!(buffer.dequeue().is_ok()); - buffer.enqueue(12, ()).unwrap().copy_from_slice(b"abcdefghijkl"); + buffer + .enqueue(12, ()) + .unwrap() + .copy_from_slice(b"abcdefghijkl"); } #[test] @@ -242,32 +311,37 @@ mod test { assert_eq!(buffer.metadata_ring.len(), 3); assert!(buffer.dequeue().is_ok()); - assert!(buffer.dequeue_with(|_, _| Err(Error::Unaddressable) as Result<()>).is_err()); + assert!(matches!( + buffer.dequeue_with(|_, _| Result::<(), u32>::Err(123)), + Ok(Err(_)) + )); assert_eq!(buffer.metadata_ring.len(), 1); - assert!(buffer.dequeue_with(|&mut (), payload| { - assert_eq!(payload, &b"abcd"[..]); - Ok(()) - }).is_ok()); + assert!(buffer + .dequeue_with(|&mut (), payload| { + assert_eq!(payload, &b"abcd"[..]); + Result::<(), ()>::Ok(()) + }) + .is_ok()); assert_eq!(buffer.metadata_ring.len(), 0); } #[test] fn test_metadata_full_empty() { let mut buffer = buffer(); - assert_eq!(buffer.is_empty(), true); - assert_eq!(buffer.is_full(), false); + assert!(buffer.is_empty()); + assert!(!buffer.is_full()); assert!(buffer.enqueue(1, ()).is_ok()); - assert_eq!(buffer.is_empty(), false); + assert!(!buffer.is_empty()); assert!(buffer.enqueue(1, ()).is_ok()); assert!(buffer.enqueue(1, ()).is_ok()); - assert_eq!(buffer.is_full(), false); - assert_eq!(buffer.is_empty(), false); + assert!(!buffer.is_full()); + assert!(!buffer.is_empty()); assert!(buffer.enqueue(1, ()).is_ok()); - assert_eq!(buffer.is_full(), true); - assert_eq!(buffer.is_empty(), false); + assert!(buffer.is_full()); + assert!(!buffer.is_empty()); assert_eq!(buffer.metadata_ring.len(), 4); - assert_eq!(buffer.enqueue(1, ()), Err(Error::Exhausted)); + assert_eq!(buffer.enqueue(1, ()), Err(Full)); } #[test] @@ -276,7 +350,7 @@ mod test { assert!(buffer.enqueue(4, ()).is_ok()); assert!(buffer.enqueue(8, ()).is_ok()); assert!(buffer.dequeue().is_ok()); - assert_eq!(buffer.enqueue(16, ()), Err(Error::Exhausted)); + assert_eq!(buffer.enqueue(16, ()), Err(Full)); assert_eq!(buffer.metadata_ring.len(), 1); } @@ -286,14 +360,22 @@ mod test { assert!(buffer.enqueue(4, ()).is_ok()); assert!(buffer.enqueue(8, ()).is_ok()); assert!(buffer.dequeue().is_ok()); - assert_eq!(buffer.enqueue(8, ()), Err(Error::Exhausted)); + assert_eq!(buffer.enqueue(8, ()), Err(Full)); assert_eq!(buffer.metadata_ring.len(), 1); } + #[test] + fn test_contiguous_window_wrap() { + let mut buffer = buffer(); + assert!(buffer.enqueue(15, ()).is_ok()); + assert!(buffer.dequeue().is_ok()); + assert!(buffer.enqueue(16, ()).is_ok()); + } + #[test] fn test_capacity_too_small() { let mut buffer = buffer(); - assert_eq!(buffer.enqueue(32, ()), Err(Error::Truncated)); + assert_eq!(buffer.enqueue(32, ()), Err(Full)); } #[test] @@ -303,4 +385,18 @@ mod test { assert!(buffer.dequeue().is_ok()); assert!(buffer.enqueue(5, ()).is_ok()); } + + #[test] + fn clear() { + let mut buffer = buffer(); + + // Ensure enqueuing data in teh buffer fills it somewhat. + assert!(buffer.is_empty()); + assert!(buffer.enqueue(6, ()).is_ok()); + + // Ensure that resetting the buffer causes it to be empty. + assert!(!buffer.is_empty()); + buffer.reset(); + assert!(buffer.is_empty()); + } } diff --git a/src/storage/ring_buffer.rs b/src/storage/ring_buffer.rs index cf8244ffc..7d461b68c 100644 --- a/src/storage/ring_buffer.rs +++ b/src/storage/ring_buffer.rs @@ -1,11 +1,13 @@ -// Uncomment the #[must_use]s here once [RFC 1940] hits stable. +// Some of the functions in ring buffer is marked as #[must_use]. It notes that +// these functions may have side effects, and it's implemented by [RFC 1940]. // [RFC 1940]: https://github.com/rust-lang/rust/issues/43302 use core::cmp; use managed::ManagedSlice; -use {Error, Result}; -use super::Resettable; +use crate::storage::Resettable; + +use super::{Empty, Full}; /// A ring buffer. /// @@ -25,7 +27,7 @@ use super::Resettable; pub struct RingBuffer<'a, T: 'a> { storage: ManagedSlice<'a, T>, read_at: usize, - length: usize, + length: usize, } impl<'a, T: 'a> RingBuffer<'a, T> { @@ -33,19 +35,20 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// /// During creation, every element in `storage` is reset. pub fn new(storage: S) -> RingBuffer<'a, T> - where S: Into>, + where + S: Into>, { RingBuffer { storage: storage.into(), read_at: 0, - length: 0, + length: 0, } } /// Clear the ring buffer. pub fn clear(&mut self) { self.read_at = 0; - self.length = 0; + self.length = 0; } /// Return the maximum number of elements in the ring buffer. @@ -55,7 +58,9 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// Clear the ring buffer, and reset every element. pub fn reset(&mut self) - where T: Resettable { + where + T: Resettable, + { self.clear(); for elem in self.storage.iter_mut() { elem.reset(); @@ -110,52 +115,57 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// and boundary conditions (empty/full) are errors. impl<'a, T: 'a> RingBuffer<'a, T> { /// Call `f` with a single buffer element, and enqueue the element if `f` - /// returns successfully, or return `Err(Error::Exhausted)` if the buffer is full. - pub fn enqueue_one_with<'b, R, F>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut T) -> Result { - if self.is_full() { return Err(Error::Exhausted) } + /// returns successfully, or return `Err(Full)` if the buffer is full. + pub fn enqueue_one_with<'b, R, E, F>(&'b mut self, f: F) -> Result, Full> + where + F: FnOnce(&'b mut T) -> Result, + { + if self.is_full() { + return Err(Full); + } let index = self.get_idx_unchecked(self.length); - match f(&mut self.storage[index]) { - Ok(result) => { - self.length += 1; - Ok(result) - } - Err(error) => Err(error) + let res = f(&mut self.storage[index]); + if res.is_ok() { + self.length += 1; } + Ok(res) } /// Enqueue a single element into the buffer, and return a reference to it, - /// or return `Err(Error::Exhausted)` if the buffer is full. + /// or return `Err(Full)` if the buffer is full. /// /// This function is a shortcut for `ring_buf.enqueue_one_with(Ok)`. - pub fn enqueue_one<'b>(&'b mut self) -> Result<&'b mut T> { - self.enqueue_one_with(Ok) + pub fn enqueue_one(&mut self) -> Result<&mut T, Full> { + self.enqueue_one_with(Ok)? } /// Call `f` with a single buffer element, and dequeue the element if `f` - /// returns successfully, or return `Err(Error::Exhausted)` if the buffer is empty. - pub fn dequeue_one_with<'b, R, F>(&'b mut self, f: F) -> Result - where F: FnOnce(&'b mut T) -> Result { - if self.is_empty() { return Err(Error::Exhausted) } + /// returns successfully, or return `Err(Empty)` if the buffer is empty. + pub fn dequeue_one_with<'b, R, E, F>(&'b mut self, f: F) -> Result, Empty> + where + F: FnOnce(&'b mut T) -> Result, + { + if self.is_empty() { + return Err(Empty); + } let next_at = self.get_idx_unchecked(1); - match f(&mut self.storage[self.read_at]) { - Ok(result) => { - self.length -= 1; - self.read_at = next_at; - Ok(result) - } - Err(error) => Err(error) + let res = f(&mut self.storage[self.read_at]); + + if res.is_ok() { + self.length -= 1; + self.read_at = next_at; } + Ok(res) } /// Dequeue an element from the buffer, and return a reference to it, - /// or return `Err(Error::Exhausted)` if the buffer is empty. + /// or return `Err(Empty)` if the buffer is empty. /// /// This function is a shortcut for `ring_buf.dequeue_one_with(Ok)`. - pub fn dequeue_one(&mut self) -> Result<&mut T> { - self.dequeue_one_with(Ok) + pub fn dequeue_one(&mut self) -> Result<&mut T, Empty> { + self.dequeue_one_with(Ok)? } } @@ -169,7 +179,9 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// This function panics if the amount of elements returned by `f` is larger /// than the size of the slice passed into it. pub fn enqueue_many_with<'b, R, F>(&'b mut self, f: F) -> (usize, R) - where F: FnOnce(&'b mut [T]) -> (usize, R) { + where + F: FnOnce(&'b mut [T]) -> (usize, R), + { if self.length == 0 { // Ring is currently empty. Reset `read_at` to optimize // for contiguous space. @@ -189,19 +201,22 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// /// This function may return a slice smaller than the given size /// if the free space in the buffer is not contiguous. - // #[must_use] - pub fn enqueue_many<'b>(&'b mut self, size: usize) -> &'b mut [T] { + #[must_use] + pub fn enqueue_many(&mut self, size: usize) -> &mut [T] { self.enqueue_many_with(|buf| { let size = cmp::min(size, buf.len()); (size, &mut buf[..size]) - }).1 + }) + .1 } /// Enqueue as many elements from the given slice into the buffer as possible, /// and return the amount of elements that could fit. - // #[must_use] + #[must_use] pub fn enqueue_slice(&mut self, data: &[T]) -> usize - where T: Copy { + where + T: Copy, + { let (size_1, data) = self.enqueue_many_with(|buf| { let size = cmp::min(buf.len(), data.len()); buf[..size].copy_from_slice(&data[..size]); @@ -222,7 +237,9 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// This function panics if the amount of elements returned by `f` is larger /// than the size of the slice passed into it. pub fn dequeue_many_with<'b, R, F>(&'b mut self, f: F) -> (usize, R) - where F: FnOnce(&'b mut [T]) -> (usize, R) { + where + F: FnOnce(&'b mut [T]) -> (usize, R), + { let capacity = self.capacity(); let max_size = cmp::min(self.len(), capacity - self.read_at); let (size, result) = f(&mut self.storage[self.read_at..self.read_at + max_size]); @@ -241,19 +258,22 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// /// This function may return a slice smaller than the given size /// if the allocated space in the buffer is not contiguous. - // #[must_use] - pub fn dequeue_many<'b>(&'b mut self, size: usize) -> &'b mut [T] { + #[must_use] + pub fn dequeue_many(&mut self, size: usize) -> &mut [T] { self.dequeue_many_with(|buf| { let size = cmp::min(size, buf.len()); (size, &mut buf[..size]) - }).1 + }) + .1 } /// Dequeue as many elements from the buffer into the given slice as possible, /// and return the amount of elements that could fit. - // #[must_use] + #[must_use] pub fn dequeue_slice(&mut self, data: &mut [T]) -> usize - where T: Copy { + where + T: Copy, + { let (size_1, data) = self.dequeue_many_with(|buf| { let size = cmp::min(buf.len(), data.len()); data[..size].copy_from_slice(&buf[..size]); @@ -273,17 +293,23 @@ impl<'a, T: 'a> RingBuffer<'a, T> { impl<'a, T: 'a> RingBuffer<'a, T> { /// Return the largest contiguous slice of unallocated buffer elements starting /// at the given offset past the last allocated element, and up to the given size. - // #[must_use] + #[must_use] pub fn get_unallocated(&mut self, offset: usize, mut size: usize) -> &mut [T] { let start_at = self.get_idx(self.length + offset); // We can't access past the end of unallocated data. - if offset > self.window() { return &mut [] } + if offset > self.window() { + return &mut []; + } // We can't enqueue more than there is free space. let clamped_window = self.window() - offset; - if size > clamped_window { size = clamped_window } + if size > clamped_window { + size = clamped_window + } // We can't contiguously enqueue past the end of the storage. let until_end = self.capacity() - start_at; - if size > until_end { size = until_end } + if size > until_end { + size = until_end + } &mut self.storage[start_at..start_at + size] } @@ -291,9 +317,11 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// Write as many elements from the given slice into unallocated buffer elements /// starting at the given offset past the last allocated element, and return /// the amount written. - // #[must_use] + #[must_use] pub fn write_unallocated(&mut self, offset: usize, data: &[T]) -> usize - where T: Copy { + where + T: Copy, + { let (size_1, offset, data) = { let slice = self.get_unallocated(offset, data.len()); let slice_len = slice.len(); @@ -320,17 +348,23 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// Return the largest contiguous slice of allocated buffer elements starting /// at the given offset past the first allocated element, and up to the given size. - // #[must_use] + #[must_use] pub fn get_allocated(&self, offset: usize, mut size: usize) -> &[T] { let start_at = self.get_idx(offset); // We can't read past the end of the allocated data. - if offset > self.length { return &mut [] } + if offset > self.length { + return &mut []; + } // We can't read more than we have allocated. let clamped_length = self.length - offset; - if size > clamped_length { size = clamped_length } + if size > clamped_length { + size = clamped_length + } // We can't contiguously dequeue past the end of the storage. let until_end = self.capacity() - start_at; - if size > until_end { size = until_end } + if size > until_end { + size = until_end + } &self.storage[start_at..start_at + size] } @@ -338,9 +372,11 @@ impl<'a, T: 'a> RingBuffer<'a, T> { /// Read as many elements from allocated buffer elements into the given slice /// starting at the given offset past the first allocated element, and return /// the amount read. - // #[must_use] + #[must_use] pub fn read_allocated(&mut self, offset: usize, data: &mut [T]) -> usize - where T: Copy { + where + T: Copy, + { let (size_1, offset, data) = { let slice = self.get_allocated(offset, data.len()); data[..slice.len()].copy_from_slice(slice); @@ -402,34 +438,45 @@ mod test { #[test] fn test_buffer_enqueue_dequeue_one_with() { let mut ring = RingBuffer::new(vec![0; 5]); - assert_eq!(ring.dequeue_one_with(|_| unreachable!()) as Result<()>, - Err(Error::Exhausted)); + assert_eq!( + ring.dequeue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Empty) + ); - ring.enqueue_one_with(|e| Ok(e)).unwrap(); + ring.enqueue_one_with(Ok::<_, ()>).unwrap().unwrap(); assert!(!ring.is_empty()); assert!(!ring.is_full()); for i in 1..5 { - ring.enqueue_one_with(|e| Ok(*e = i)).unwrap(); + ring.enqueue_one_with(|e| Ok::<_, ()>(*e = i)) + .unwrap() + .unwrap(); assert!(!ring.is_empty()); } assert!(ring.is_full()); - assert_eq!(ring.enqueue_one_with(|_| unreachable!()) as Result<()>, - Err(Error::Exhausted)); + assert_eq!( + ring.enqueue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Full) + ); for i in 0..5 { - assert_eq!(ring.dequeue_one_with(|e| Ok(*e)).unwrap(), i); + assert_eq!( + ring.dequeue_one_with(|e| Ok::<_, ()>(*e)).unwrap().unwrap(), + i + ); assert!(!ring.is_full()); } - assert_eq!(ring.dequeue_one_with(|_| unreachable!()) as Result<()>, - Err(Error::Exhausted)); + assert_eq!( + ring.dequeue_one_with(|_| -> Result::<(), ()> { unreachable!() }), + Err(Empty) + ); assert!(ring.is_empty()); } #[test] fn test_buffer_enqueue_dequeue_one() { let mut ring = RingBuffer::new(vec![0; 5]); - assert_eq!(ring.dequeue_one(), Err(Error::Exhausted)); + assert_eq!(ring.dequeue_one(), Err(Empty)); ring.enqueue_one().unwrap(); assert!(!ring.is_empty()); @@ -440,13 +487,13 @@ mod test { assert!(!ring.is_empty()); } assert!(ring.is_full()); - assert_eq!(ring.enqueue_one(), Err(Error::Exhausted)); + assert_eq!(ring.enqueue_one(), Err(Full)); for i in 0..5 { assert_eq!(*ring.dequeue_one().unwrap(), i); assert!(!ring.is_full()); } - assert_eq!(ring.dequeue_one(), Err(Error::Exhausted)); + assert_eq!(ring.dequeue_one(), Err(Empty)); assert!(ring.is_empty()); } @@ -454,11 +501,14 @@ mod test { fn test_buffer_enqueue_many_with() { let mut ring = RingBuffer::new(vec![b'.'; 12]); - assert_eq!(ring.enqueue_many_with(|buf| { - assert_eq!(buf.len(), 12); - buf[0..2].copy_from_slice(b"ab"); + assert_eq!( + ring.enqueue_many_with(|buf| { + assert_eq!(buf.len(), 12); + buf[0..2].copy_from_slice(b"ab"); + (2, true) + }), (2, true) - }), (2, true)); + ); assert_eq!(ring.len(), 2); assert_eq!(&ring.storage[..], b"ab.........."); @@ -545,12 +595,15 @@ mod test { assert_eq!(ring.enqueue_slice(b"abcdefghijkl"), 12); - assert_eq!(ring.dequeue_many_with(|buf| { - assert_eq!(buf.len(), 12); - assert_eq!(buf, b"abcdefghijkl"); - buf[..4].copy_from_slice(b"...."); + assert_eq!( + ring.dequeue_many_with(|buf| { + assert_eq!(buf.len(), 12); + assert_eq!(buf, b"abcdefghijkl"); + buf[..4].copy_from_slice(b"...."); + (4, true) + }), (4, true) - }), (4, true)); + ); assert_eq!(ring.len(), 8); assert_eq!(&ring.storage[..], b"....efghijkl"); @@ -637,7 +690,8 @@ mod test { } assert_eq!(&ring.storage[..], b"abcd........"); - ring.enqueue_many(4); + let buf_enqueued = ring.enqueue_many(4); + assert_eq!(buf_enqueued.len(), 4); assert_eq!(ring.len(), 4); { @@ -679,17 +733,20 @@ mod test { let mut ring = RingBuffer::new(vec![b'.'; 12]); assert_eq!(ring.get_allocated(16, 4), b""); - assert_eq!(ring.get_allocated(0, 4), b""); + assert_eq!(ring.get_allocated(0, 4), b""); - ring.enqueue_slice(b"abcd"); + let len_enqueued = ring.enqueue_slice(b"abcd"); assert_eq!(ring.get_allocated(0, 8), b"abcd"); + assert_eq!(len_enqueued, 4); - ring.enqueue_slice(b"efghijkl"); + let len_enqueued = ring.enqueue_slice(b"efghijkl"); ring.dequeue_many(4).copy_from_slice(b"...."); assert_eq!(ring.get_allocated(4, 8), b"ijkl"); + assert_eq!(len_enqueued, 8); - ring.enqueue_slice(b"abcd"); + let len_enqueued = ring.enqueue_slice(b"abcd"); assert_eq!(ring.get_allocated(4, 8), b"ijkl"); + assert_eq!(len_enqueued, 4); } #[test] @@ -711,7 +768,6 @@ mod test { let mut data = [0; 6]; assert_eq!(ring.read_allocated(6, &mut data[..]), 3); assert_eq!(&data[..], b"mno\x00\x00\x00"); - } #[test] @@ -724,7 +780,7 @@ mod test { assert_eq!(no_capacity.get_allocated(0, 0), &[]); no_capacity.dequeue_allocated(0); assert_eq!(no_capacity.enqueue_many(0), &[]); - assert_eq!(no_capacity.enqueue_one(), Err(Error::Exhausted)); + assert_eq!(no_capacity.enqueue_one(), Err(Full)); assert_eq!(no_capacity.contiguous_window(), 0); } @@ -734,10 +790,11 @@ mod test { #[test] fn test_buffer_write_wholly() { let mut ring = RingBuffer::new(vec![b'.'; 8]); - ring.enqueue_many(2).copy_from_slice(b"xx"); - ring.enqueue_many(2).copy_from_slice(b"xx"); + ring.enqueue_many(2).copy_from_slice(b"ab"); + ring.enqueue_many(2).copy_from_slice(b"cd"); assert_eq!(ring.len(), 4); - ring.dequeue_many(4); + let buf_dequeued = ring.dequeue_many(4); + assert_eq!(buf_dequeued, b"abcd"); assert_eq!(ring.len(), 0); let large = ring.enqueue_many(8); diff --git a/src/time.rs b/src/time.rs index 221e148ee..318dacacb 100644 --- a/src/time.rs +++ b/src/time.rs @@ -4,18 +4,18 @@ The `time` module contains structures used to represent both absolute and relative time. - [Instant] is used to represent absolute time. - - [Duration] is used to represet relative time. + - [Duration] is used to represent relative time. [Instant]: struct.Instant.html [Duration]: struct.Duration.html */ -use core::{ops, fmt}; +use core::{fmt, ops}; /// A representation of an absolute time value. /// /// The `Instant` type is a wrapper around a `i64` value that -/// represents a number of milliseconds, monotonically increasing +/// represents a number of microseconds, monotonically increasing /// since an arbitrary moment in time, such as system startup. /// /// * A value of `0` is inherently arbitrary. @@ -23,18 +23,42 @@ use core::{ops, fmt}; /// point. #[derive(Debug, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Instant { - pub millis: i64, + micros: i64, } impl Instant { + pub const ZERO: Instant = Instant::from_micros_const(0); + + /// Create a new `Instant` from a number of microseconds. + pub fn from_micros>(micros: T) -> Instant { + Instant { + micros: micros.into(), + } + } + + pub const fn from_micros_const(micros: i64) -> Instant { + Instant { micros } + } + /// Create a new `Instant` from a number of milliseconds. pub fn from_millis>(millis: T) -> Instant { - Instant { millis: millis.into() } + Instant { + micros: millis.into() * 1000, + } + } + + /// Create a new `Instant` from a number of milliseconds. + pub const fn from_millis_const(millis: i64) -> Instant { + Instant { + micros: millis * 1000, + } } /// Create a new `Instant` from a number of seconds. pub fn from_secs>(secs: T) -> Instant { - Instant { millis: secs.into() * 1000 } + Instant { + micros: secs.into() * 1000000, + } } /// Create a new `Instant` from the current [std::time::SystemTime]. @@ -50,20 +74,31 @@ impl Instant { /// The fractional number of milliseconds that have passed /// since the beginning of time. - pub fn millis(&self) -> i64 { - self.millis % 1000 + pub const fn millis(&self) -> i64 { + self.micros % 1000000 / 1000 + } + + /// The fractional number of microseconds that have passed + /// since the beginning of time. + pub const fn micros(&self) -> i64 { + self.micros % 1000000 } /// The number of whole seconds that have passed since the /// beginning of time. - pub fn secs(&self) -> i64 { - self.millis / 1000 + pub const fn secs(&self) -> i64 { + self.micros / 1000000 } /// The total number of milliseconds that have passed since - /// the biginning of time. - pub fn total_millis(&self) -> i64 { - self.millis + /// the beginning of time. + pub const fn total_millis(&self) -> i64 { + self.micros / 1000 + } + /// The total number of milliseconds that have passed since + /// the beginning of time. + pub const fn total_micros(&self) -> i64 { + self.micros } } @@ -71,29 +106,37 @@ impl Instant { impl From<::std::time::Instant> for Instant { fn from(other: ::std::time::Instant) -> Instant { let elapsed = other.elapsed(); - Instant::from_millis((elapsed.as_secs() * 1_000) as i64 + (elapsed.subsec_nanos() / 1_000_000) as i64) + Instant::from_micros((elapsed.as_secs() * 1_000000) as i64 + elapsed.subsec_micros() as i64) } } #[cfg(feature = "std")] impl From<::std::time::SystemTime> for Instant { fn from(other: ::std::time::SystemTime) -> Instant { - let n = other.duration_since(::std::time::UNIX_EPOCH) + let n = other + .duration_since(::std::time::UNIX_EPOCH) .expect("start time must not be before the unix epoch"); - Self::from_millis(n.as_secs() as i64 * 1000 + (n.subsec_nanos() / 1000000) as i64) + Self::from_micros(n.as_secs() as i64 * 1000000 + n.subsec_micros() as i64) } } #[cfg(feature = "std")] -impl Into<::std::time::SystemTime> for Instant { - fn into(self) -> ::std::time::SystemTime { - ::std::time::UNIX_EPOCH + ::std::time::Duration::from_millis(self.millis as u64) +impl From for ::std::time::SystemTime { + fn from(val: Instant) -> Self { + ::std::time::UNIX_EPOCH + ::std::time::Duration::from_micros(val.micros as u64) } } impl fmt::Display for Instant { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "{}.{}s", self.secs(), self.millis()) + write!(f, "{}.{:0>3}s", self.secs(), self.millis()) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Instant { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}.{:03}s", self.secs(), self.millis()); } } @@ -101,13 +144,13 @@ impl ops::Add for Instant { type Output = Instant; fn add(self, rhs: Duration) -> Instant { - Instant::from_millis(self.millis + rhs.total_millis() as i64) + Instant::from_micros(self.micros + rhs.total_micros() as i64) } } impl ops::AddAssign for Instant { fn add_assign(&mut self, rhs: Duration) { - self.millis += rhs.total_millis() as i64; + self.micros += rhs.total_micros() as i64; } } @@ -115,13 +158,13 @@ impl ops::Sub for Instant { type Output = Instant; fn sub(self, rhs: Duration) -> Instant { - Instant::from_millis(self.millis - rhs.total_millis() as i64) + Instant::from_micros(self.micros - rhs.total_micros() as i64) } } impl ops::SubAssign for Instant { fn sub_assign(&mut self, rhs: Duration) { - self.millis -= rhs.total_millis() as i64; + self.micros -= rhs.total_micros() as i64; } } @@ -129,40 +172,60 @@ impl ops::Sub for Instant { type Output = Duration; fn sub(self, rhs: Instant) -> Duration { - Duration::from_millis((self.millis - rhs.millis).abs() as u64) + Duration::from_micros((self.micros - rhs.micros).unsigned_abs()) } } /// A relative amount of time. #[derive(Debug, Default, Copy, Clone, PartialEq, Eq, PartialOrd, Ord)] pub struct Duration { - pub millis: u64, + micros: u64, } impl Duration { + pub const ZERO: Duration = Duration::from_micros(0); + /// Create a new `Duration` from a number of microseconds. + pub const fn from_micros(micros: u64) -> Duration { + Duration { micros } + } + /// Create a new `Duration` from a number of milliseconds. - pub fn from_millis(millis: u64) -> Duration { - Duration { millis } + pub const fn from_millis(millis: u64) -> Duration { + Duration { + micros: millis * 1000, + } } /// Create a new `Instant` from a number of seconds. - pub fn from_secs(secs: u64) -> Duration { - Duration { millis: secs * 1000 } + pub const fn from_secs(secs: u64) -> Duration { + Duration { + micros: secs * 1000000, + } } /// The fractional number of milliseconds in this `Duration`. - pub fn millis(&self) -> u64 { - self.millis % 1000 + pub const fn millis(&self) -> u64 { + self.micros / 1000 % 1000 + } + + /// The fractional number of milliseconds in this `Duration`. + pub const fn micros(&self) -> u64 { + self.micros % 1000000 } /// The number of whole seconds in this `Duration`. - pub fn secs(&self) -> u64 { - self.millis / 1000 + pub const fn secs(&self) -> u64 { + self.micros / 1000000 } /// The total number of milliseconds in this `Duration`. - pub fn total_millis(&self) -> u64 { - self.millis + pub const fn total_millis(&self) -> u64 { + self.micros / 1000 + } + + /// The total number of microseconds in this `Duration`. + pub const fn total_micros(&self) -> u64 { + self.micros } } @@ -172,17 +235,24 @@ impl fmt::Display for Duration { } } +#[cfg(feature = "defmt")] +impl defmt::Format for Duration { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}.{:03}s", self.secs(), self.millis()); + } +} + impl ops::Add for Duration { type Output = Duration; fn add(self, rhs: Duration) -> Duration { - Duration::from_millis(self.millis + rhs.total_millis()) + Duration::from_micros(self.micros + rhs.total_micros()) } } impl ops::AddAssign for Duration { fn add_assign(&mut self, rhs: Duration) { - self.millis += rhs.total_millis(); + self.micros += rhs.total_micros(); } } @@ -190,15 +260,20 @@ impl ops::Sub for Duration { type Output = Duration; fn sub(self, rhs: Duration) -> Duration { - Duration::from_millis( - self.millis.checked_sub(rhs.total_millis()).expect("overflow when subtracting durations")) + Duration::from_micros( + self.micros + .checked_sub(rhs.total_micros()) + .expect("overflow when subtracting durations"), + ) } } impl ops::SubAssign for Duration { fn sub_assign(&mut self, rhs: Duration) { - self.millis = self.millis.checked_sub( - rhs.total_millis()).expect("overflow when subtracting durations"); + self.micros = self + .micros + .checked_sub(rhs.total_micros()) + .expect("overflow when subtracting durations"); } } @@ -206,13 +281,13 @@ impl ops::Mul for Duration { type Output = Duration; fn mul(self, rhs: u32) -> Duration { - Duration::from_millis(self.millis * rhs as u64) + Duration::from_micros(self.micros * rhs as u64) } } impl ops::MulAssign for Duration { fn mul_assign(&mut self, rhs: u32) { - self.millis *= rhs as u64; + self.micros *= rhs as u64; } } @@ -220,29 +295,53 @@ impl ops::Div for Duration { type Output = Duration; fn div(self, rhs: u32) -> Duration { - Duration::from_millis(self.millis / rhs as u64) + Duration::from_micros(self.micros / rhs as u64) } } impl ops::DivAssign for Duration { fn div_assign(&mut self, rhs: u32) { - self.millis /= rhs as u64; + self.micros /= rhs as u64; + } +} + +impl ops::Shl for Duration { + type Output = Duration; + + fn shl(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros << rhs) + } +} + +impl ops::ShlAssign for Duration { + fn shl_assign(&mut self, rhs: u32) { + self.micros <<= rhs; + } +} + +impl ops::Shr for Duration { + type Output = Duration; + + fn shr(self, rhs: u32) -> Duration { + Duration::from_micros(self.micros >> rhs) + } +} + +impl ops::ShrAssign for Duration { + fn shr_assign(&mut self, rhs: u32) { + self.micros >>= rhs; } } impl From<::core::time::Duration> for Duration { fn from(other: ::core::time::Duration) -> Duration { - Duration::from_millis( - other.as_secs() * 1000 + (other.subsec_nanos() / 1_000_000) as u64 - ) + Duration::from_micros(other.as_secs() * 1000000 + other.subsec_micros() as u64) } } -impl Into<::core::time::Duration> for Duration { - fn into(self) -> ::core::time::Duration { - ::core::time::Duration::from_millis( - self.total_millis() - ) +impl From for ::core::time::Duration { + fn from(val: Duration) -> Self { + ::core::time::Duration::from_micros(val.total_micros()) } } @@ -253,9 +352,15 @@ mod test { #[test] fn test_instant_ops() { // std::ops::Add - assert_eq!(Instant::from_millis(4) + Duration::from_millis(6), Instant::from_millis(10)); + assert_eq!( + Instant::from_millis(4) + Duration::from_millis(6), + Instant::from_millis(10) + ); // std::ops::Sub - assert_eq!(Instant::from_millis(7) - Duration::from_millis(5), Instant::from_millis(2)); + assert_eq!( + Instant::from_millis(7) - Duration::from_millis(5), + Instant::from_millis(2) + ); } #[test] @@ -268,31 +373,43 @@ mod test { #[test] fn test_instant_display() { + assert_eq!(format!("{}", Instant::from_millis(74)), "0.074s"); assert_eq!(format!("{}", Instant::from_millis(5674)), "5.674s"); - assert_eq!(format!("{}", Instant::from_millis(5000)), "5.0s"); + assert_eq!(format!("{}", Instant::from_millis(5000)), "5.000s"); } #[test] #[cfg(feature = "std")] fn test_instant_conversions() { let mut epoc: ::std::time::SystemTime = Instant::from_millis(0).into(); - assert_eq!(Instant::from(::std::time::UNIX_EPOCH), - Instant::from_millis(0)); + assert_eq!( + Instant::from(::std::time::UNIX_EPOCH), + Instant::from_millis(0) + ); assert_eq!(epoc, ::std::time::UNIX_EPOCH); epoc = Instant::from_millis(2085955200i64 * 1000).into(); - assert_eq!(epoc, ::std::time::UNIX_EPOCH + ::std::time::Duration::from_secs(2085955200)); + assert_eq!( + epoc, + ::std::time::UNIX_EPOCH + ::std::time::Duration::from_secs(2085955200) + ); } #[test] fn test_duration_ops() { // std::ops::Add - assert_eq!(Duration::from_millis(40) + Duration::from_millis(2), Duration::from_millis(42)); + assert_eq!( + Duration::from_millis(40) + Duration::from_millis(2), + Duration::from_millis(42) + ); // std::ops::Sub - assert_eq!(Duration::from_millis(555) - Duration::from_millis(42), Duration::from_millis(513)); + assert_eq!( + Duration::from_millis(555) - Duration::from_millis(42), + Duration::from_millis(513) + ); // std::ops::Mul assert_eq!(Duration::from_millis(13) * 22, Duration::from_millis(286)); // std::ops::Div - assert_eq!(Duration::from_millis(53) / 4, Duration::from_millis(13)); + assert_eq!(Duration::from_millis(53) / 4, Duration::from_micros(13250)); } #[test] @@ -305,7 +422,7 @@ mod test { duration *= 4; assert_eq!(duration, Duration::from_millis(20936)); duration /= 5; - assert_eq!(duration, Duration::from_millis(4187)); + assert_eq!(duration, Duration::from_micros(4187200)); } #[test] diff --git a/src/wire/arp.rs b/src/wire/arp.rs index 7e66d1d44..bb0df3a0e 100644 --- a/src/wire/arp.rs +++ b/src/wire/arp.rs @@ -1,7 +1,7 @@ -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; -use {Error, Result}; +use super::{Error, Result}; pub use super::EthernetProtocol as Protocol; @@ -21,42 +21,43 @@ enum_with_unknown! { } /// A read/write wrapper around an Address Resolution Protocol packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; pub const HTYPE: Field = 0..2; pub const PTYPE: Field = 2..4; - pub const HLEN: usize = 4; - pub const PLEN: usize = 5; - pub const OPER: Field = 6..8; + pub const HLEN: usize = 4; + pub const PLEN: usize = 5; + pub const OPER: Field = 6..8; #[inline] - pub fn SHA(hardware_len: u8, _protocol_len: u8) -> Field { + pub const fn SHA(hardware_len: u8, _protocol_len: u8) -> Field { let start = OPER.end; start..(start + hardware_len as usize) } #[inline] - pub fn SPA(hardware_len: u8, protocol_len: u8) -> Field { + pub const fn SPA(hardware_len: u8, protocol_len: u8) -> Field { let start = SHA(hardware_len, protocol_len).end; start..(start + protocol_len as usize) } #[inline] - pub fn THA(hardware_len: u8, protocol_len: u8) -> Field { + pub const fn THA(hardware_len: u8, protocol_len: u8) -> Field { let start = SPA(hardware_len, protocol_len).end; start..(start + hardware_len as usize) } #[inline] - pub fn TPA(hardware_len: u8, protocol_len: u8) -> Field { + pub const fn TPA(hardware_len: u8, protocol_len: u8) -> Field { let start = THA(hardware_len, protocol_len).end; start..(start + protocol_len as usize) } @@ -64,7 +65,7 @@ mod field { impl> Packet { /// Imbue a raw octet buffer with ARP packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -79,19 +80,20 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_hardware_len] or /// [set_protocol_len]. /// /// [set_hardware_len]: #method.set_hardware_len /// [set_protocol_len]: #method.set_protocol_len + #[allow(clippy::if_same_then_else)] pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::OPER.end { - Err(Error::Truncated) + Err(Error) } else if len < field::TPA(self.hardware_len(), self.protocol_len()).end { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -248,10 +250,12 @@ impl> AsRef<[u8]> for Packet { } } -use super::{EthernetAddress, Ipv4Address}; +use crate::wire::{EthernetAddress, Ipv4Address}; /// A high-level representation of an Address Resolution Protocol packet. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub enum Repr { /// An Ethernet and IPv4 Address Resolution Protocol packet. EthernetIpv4 { @@ -259,50 +263,47 @@ pub enum Repr { source_hardware_addr: EthernetAddress, source_protocol_addr: Ipv4Address, target_hardware_addr: EthernetAddress, - target_protocol_addr: Ipv4Address + target_protocol_addr: Ipv4Address, }, - #[doc(hidden)] - __Nonexhaustive } impl Repr { /// Parse an Address Resolution Protocol packet and return a high-level representation, - /// or return `Err(Error::Unrecognized)` if the packet is not recognized. + /// or return `Err(Error)` if the packet is not recognized. pub fn parse>(packet: &Packet) -> Result { - match (packet.hardware_type(), packet.protocol_type(), - packet.hardware_len(), packet.protocol_len()) { - (Hardware::Ethernet, Protocol::Ipv4, 6, 4) => { - Ok(Repr::EthernetIpv4 { - operation: packet.operation(), - source_hardware_addr: - EthernetAddress::from_bytes(packet.source_hardware_addr()), - source_protocol_addr: - Ipv4Address::from_bytes(packet.source_protocol_addr()), - target_hardware_addr: - EthernetAddress::from_bytes(packet.target_hardware_addr()), - target_protocol_addr: - Ipv4Address::from_bytes(packet.target_protocol_addr()) - }) - }, - _ => Err(Error::Unrecognized) + match ( + packet.hardware_type(), + packet.protocol_type(), + packet.hardware_len(), + packet.protocol_len(), + ) { + (Hardware::Ethernet, Protocol::Ipv4, 6, 4) => Ok(Repr::EthernetIpv4 { + operation: packet.operation(), + source_hardware_addr: EthernetAddress::from_bytes(packet.source_hardware_addr()), + source_protocol_addr: Ipv4Address::from_bytes(packet.source_protocol_addr()), + target_hardware_addr: EthernetAddress::from_bytes(packet.target_hardware_addr()), + target_protocol_addr: Ipv4Address::from_bytes(packet.target_protocol_addr()), + }), + _ => Err(Error), } } /// Return the length of a packet that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { - match self { - &Repr::EthernetIpv4 { .. } => field::TPA(6, 4).end, - &Repr::__Nonexhaustive => unreachable!() + pub const fn buffer_len(&self) -> usize { + match *self { + Repr::EthernetIpv4 { .. } => field::TPA(6, 4).end, } } /// Emit a high-level representation into an Address Resolution Protocol packet. pub fn emit + AsMut<[u8]>>(&self, packet: &mut Packet) { - match self { - &Repr::EthernetIpv4 { + match *self { + Repr::EthernetIpv4 { operation, - source_hardware_addr, source_protocol_addr, - target_hardware_addr, target_protocol_addr + source_hardware_addr, + source_protocol_addr, + target_hardware_addr, + target_protocol_addr, } => { packet.set_hardware_type(Hardware::Ethernet); packet.set_protocol_type(Protocol::Ipv4); @@ -313,8 +314,7 @@ impl Repr { packet.set_source_protocol_addr(source_protocol_addr.as_bytes()); packet.set_target_hardware_addr(target_hardware_addr.as_bytes()); packet.set_target_protocol_addr(target_protocol_addr.as_bytes()); - }, - &Repr::__Nonexhaustive => unreachable!() + } } } } @@ -322,16 +322,26 @@ impl Repr { impl> fmt::Display for Packet { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), _ => { write!(f, "ARP (unrecognized)")?; - write!(f, " htype={:?} ptype={:?} hlen={:?} plen={:?} op={:?}", - self.hardware_type(), self.protocol_type(), - self.hardware_len(), self.protocol_len(), - self.operation())?; - write!(f, " sha={:?} spa={:?} tha={:?} tpa={:?}", - self.source_hardware_addr(), self.source_protocol_addr(), - self.target_hardware_addr(), self.target_protocol_addr())?; + write!( + f, + " htype={:?} ptype={:?} hlen={:?} plen={:?} op={:?}", + self.hardware_type(), + self.protocol_type(), + self.hardware_len(), + self.protocol_len(), + self.operation() + )?; + write!( + f, + " sha={:?} spa={:?} tha={:?} tpa={:?}", + self.source_hardware_addr(), + self.source_protocol_addr(), + self.target_hardware_addr(), + self.target_protocol_addr() + )?; Ok(()) } } @@ -340,30 +350,34 @@ impl> fmt::Display for Packet { impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Repr::EthernetIpv4 { + match *self { + Repr::EthernetIpv4 { operation, - source_hardware_addr, source_protocol_addr, - target_hardware_addr, target_protocol_addr + source_hardware_addr, + source_protocol_addr, + target_hardware_addr, + target_protocol_addr, } => { - write!(f, "ARP type=Ethernet+IPv4 src={}/{} tgt={}/{} op={:?}", - source_hardware_addr, source_protocol_addr, - target_hardware_addr, target_protocol_addr, - operation) - }, - &Repr::__Nonexhaustive => unreachable!() + write!( + f, + "ARP type=Ethernet+IPv4 src={source_hardware_addr}/{source_protocol_addr} tgt={target_hardware_addr}/{target_protocol_addr} op={operation:?}" + ) + } } } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { match Packet::new_checked(buffer) { - Err(err) => write!(f, "{}({})", indent, err), - Ok(packet) => write!(f, "{}{}", indent, packet) + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), } } } @@ -372,16 +386,10 @@ impl> PrettyPrint for Packet { mod test { use super::*; - static PACKET_BYTES: [u8; 28] = - [0x00, 0x01, - 0x08, 0x00, - 0x06, - 0x04, - 0x00, 0x01, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, - 0x21, 0x22, 0x23, 0x24, - 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, - 0x41, 0x42, 0x43, 0x44]; + static PACKET_BYTES: [u8; 28] = [ + 0x00, 0x01, 0x08, 0x00, 0x06, 0x04, 0x00, 0x01, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x21, + 0x22, 0x23, 0x24, 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, 0x41, 0x42, 0x43, 0x44, + ]; #[test] fn test_deconstruct() { @@ -391,9 +399,15 @@ mod test { assert_eq!(packet.hardware_len(), 6); assert_eq!(packet.protocol_len(), 4); assert_eq!(packet.operation(), Operation::Request); - assert_eq!(packet.source_hardware_addr(), &[0x11, 0x12, 0x13, 0x14, 0x15, 0x16]); + assert_eq!( + packet.source_hardware_addr(), + &[0x11, 0x12, 0x13, 0x14, 0x15, 0x16] + ); assert_eq!(packet.source_protocol_addr(), &[0x21, 0x22, 0x23, 0x24]); - assert_eq!(packet.target_hardware_addr(), &[0x31, 0x32, 0x33, 0x34, 0x35, 0x36]); + assert_eq!( + packet.target_hardware_addr(), + &[0x31, 0x32, 0x33, 0x34, 0x35, 0x36] + ); assert_eq!(packet.target_protocol_addr(), &[0x41, 0x42, 0x43, 0x44]); } @@ -410,20 +424,20 @@ mod test { packet.set_source_protocol_addr(&[0x21, 0x22, 0x23, 0x24]); packet.set_target_hardware_addr(&[0x31, 0x32, 0x33, 0x34, 0x35, 0x36]); packet.set_target_protocol_addr(&[0x41, 0x42, 0x43, 0x44]); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } fn packet_repr() -> Repr { Repr::EthernetIpv4 { operation: Operation::Request, - source_hardware_addr: - EthernetAddress::from_bytes(&[0x11, 0x12, 0x13, 0x14, 0x15, 0x16]), - source_protocol_addr: - Ipv4Address::from_bytes(&[0x21, 0x22, 0x23, 0x24]), - target_hardware_addr: - EthernetAddress::from_bytes(&[0x31, 0x32, 0x33, 0x34, 0x35, 0x36]), - target_protocol_addr: - Ipv4Address::from_bytes(&[0x41, 0x42, 0x43, 0x44]) + source_hardware_addr: EthernetAddress::from_bytes(&[ + 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, + ]), + source_protocol_addr: Ipv4Address::from_bytes(&[0x21, 0x22, 0x23, 0x24]), + target_hardware_addr: EthernetAddress::from_bytes(&[ + 0x31, 0x32, 0x33, 0x34, 0x35, 0x36, + ]), + target_protocol_addr: Ipv4Address::from_bytes(&[0x41, 0x42, 0x43, 0x44]), } } @@ -439,6 +453,6 @@ mod test { let mut bytes = vec![0xa5; 28]; let mut packet = Packet::new_unchecked(&mut bytes); packet_repr().emit(&mut packet); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } } diff --git a/src/wire/dhcpv4.rs b/src/wire/dhcpv4.rs index 7062c9c54..40a03c15a 100644 --- a/src/wire/dhcpv4.rs +++ b/src/wire/dhcpv4.rs @@ -1,10 +1,17 @@ // See https://tools.ietf.org/html/rfc2131 for the DHCP specification. +use bitflags::bitflags; use byteorder::{ByteOrder, NetworkEndian}; +use core::iter; +use heapless::Vec; -use {Error, Result}; -use super::{EthernetAddress, Ipv4Address}; -use super::arp::Hardware; +use super::{Error, Result}; +use crate::wire::arp::Hardware; +use crate::wire::{EthernetAddress, Ipv4Address}; + +pub const SERVER_PORT: u16 = 67; +pub const CLIENT_PORT: u16 = 68; +pub const MAX_DNS_SERVER_COUNT: usize = 3; const DHCP_MAGIC_NUMBER: u32 = 0x63825363; @@ -30,176 +37,91 @@ enum_with_unknown! { } } +bitflags! { + pub struct Flags: u16 { + const BROADCAST = 0b1000_0000_0000_0000; + } +} + impl MessageType { - fn opcode(&self) -> OpCode { + const fn opcode(&self) -> OpCode { match *self { - MessageType::Discover | MessageType::Inform | MessageType::Request | - MessageType::Decline | MessageType::Release => OpCode::Request, + MessageType::Discover + | MessageType::Inform + | MessageType::Request + | MessageType::Decline + | MessageType::Release => OpCode::Request, MessageType::Offer | MessageType::Ack | MessageType::Nak => OpCode::Reply, MessageType::Unknown(_) => OpCode::Unknown(0), } } } -/// A representation of a single DHCP option. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub enum DhcpOption<'a> { - EndOfList, - Pad, - MessageType(MessageType), - RequestedIp(Ipv4Address), - ClientIdentifier(EthernetAddress), - ServerIdentifier(Ipv4Address), - Router(Ipv4Address), - SubnetMask(Ipv4Address), - MaximumDhcpMessageSize(u16), - Other { kind: u8, data: &'a [u8] } +/// A buffer for DHCP options. +#[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct DhcpOptionWriter<'a> { + /// The underlying buffer, directly from the DHCP packet representation. + buffer: &'a mut [u8], } -impl<'a> DhcpOption<'a> { - pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], DhcpOption<'a>)> { - // See https://tools.ietf.org/html/rfc2132 for all possible DHCP options. +impl<'a> DhcpOptionWriter<'a> { + pub fn new(buffer: &'a mut [u8]) -> Self { + Self { buffer } + } - let (skip_len, option); - match *buffer.get(0).ok_or(Error::Truncated)? { - field::OPT_END => { - skip_len = 1; - option = DhcpOption::EndOfList; - } - field::OPT_PAD => { - skip_len = 1; - option = DhcpOption::Pad; - } - kind => { - let length = *buffer.get(1).ok_or(Error::Truncated)? as usize; - skip_len = length + 2; - let data = buffer.get(2..skip_len).ok_or(Error::Truncated)?; - match (kind, length) { - (field::OPT_END, _) | - (field::OPT_PAD, _) => - unreachable!(), - (field::OPT_DHCP_MESSAGE_TYPE, 1) => { - option = DhcpOption::MessageType(MessageType::from(data[0])); - }, - (field::OPT_REQUESTED_IP, 4) => { - option = DhcpOption::RequestedIp(Ipv4Address::from_bytes(data)); - } - (field::OPT_CLIENT_ID, 7) => { - let hardware_type = Hardware::from(u16::from(data[0])); - if hardware_type != Hardware::Ethernet { - return Err(Error::Unrecognized); - } - option = DhcpOption::ClientIdentifier(EthernetAddress::from_bytes(&data[1..])); - } - (field::OPT_SERVER_IDENTIFIER, 4) => { - option = DhcpOption::ServerIdentifier(Ipv4Address::from_bytes(data)); - } - (field::OPT_ROUTER, 4) => { - option = DhcpOption::Router(Ipv4Address::from_bytes(data)); - } - (field::OPT_SUBNET_MASK, 4) => { - option = DhcpOption::SubnetMask(Ipv4Address::from_bytes(data)); - } - (field::OPT_MAX_DHCP_MESSAGE_SIZE, 2) => { - option = DhcpOption::MaximumDhcpMessageSize(u16::from_be_bytes([data[0], data[1]])); - } - (_, _) => { - option = DhcpOption::Other { kind: kind, data: data }; - } - } - } + /// Emit a [`DhcpOption`] into a [`DhcpOptionWriter`]. + pub fn emit(&mut self, option: DhcpOption<'_>) -> Result<()> { + if option.data.len() > u8::MAX as _ { + return Err(Error); } - Ok((&buffer[skip_len..], option)) - } - pub fn buffer_len(&self) -> usize { - match self { - &DhcpOption::EndOfList => 1, - &DhcpOption::Pad => 1, - &DhcpOption::MessageType(_) => 3, - &DhcpOption::ClientIdentifier(eth_addr) => { - 3 + eth_addr.as_bytes().len() - } - &DhcpOption::RequestedIp(ip) | - &DhcpOption::ServerIdentifier(ip) | - &DhcpOption::Router(ip) | - &DhcpOption::SubnetMask(ip) => { - 2 + ip.as_bytes().len() - }, - &DhcpOption::MaximumDhcpMessageSize(_) => { - 4 - } - &DhcpOption::Other { data, .. } => 2 + data.len() + let total_len = 2 + option.data.len(); + if self.buffer.len() < total_len { + return Err(Error); } + + let (buf, rest) = core::mem::take(&mut self.buffer).split_at_mut(total_len); + self.buffer = rest; + + buf[0] = option.kind; + buf[1] = option.data.len() as _; + buf[2..].copy_from_slice(option.data); + + Ok(()) } - pub fn emit<'b>(&self, buffer: &'b mut [u8]) -> &'b mut [u8] { - let skip_length; - match self { - &DhcpOption::EndOfList => { - skip_length = 1; - buffer[0] = field::OPT_END; - } - &DhcpOption::Pad => { - skip_length = 1; - buffer[0] = field::OPT_PAD; - } - _ => { - skip_length = self.buffer_len(); - buffer[1] = (skip_length - 2) as u8; - match self { - &DhcpOption::EndOfList | &DhcpOption::Pad => unreachable!(), - &DhcpOption::MessageType(value) => { - buffer[0] = field::OPT_DHCP_MESSAGE_TYPE; - buffer[2] = value.into(); - } - &DhcpOption::ClientIdentifier(eth_addr) => { - buffer[0] = field::OPT_CLIENT_ID; - buffer[2] = u16::from(Hardware::Ethernet) as u8; - buffer[3..9].copy_from_slice(eth_addr.as_bytes()); - } - &DhcpOption::RequestedIp(ip) => { - buffer[0] = field::OPT_REQUESTED_IP; - buffer[2..6].copy_from_slice(ip.as_bytes()); - } - &DhcpOption::ServerIdentifier(ip) => { - buffer[0] = field::OPT_SERVER_IDENTIFIER; - buffer[2..6].copy_from_slice(ip.as_bytes()); - } - &DhcpOption::Router(ip) => { - buffer[0] = field::OPT_ROUTER; - buffer[2..6].copy_from_slice(ip.as_bytes()); - } - &DhcpOption::SubnetMask(mask) => { - buffer[0] = field::OPT_SUBNET_MASK; - buffer[2..6].copy_from_slice(mask.as_bytes()); - } - &DhcpOption::MaximumDhcpMessageSize(size) => { - buffer[0] = field::OPT_MAX_DHCP_MESSAGE_SIZE; - buffer[2..4].copy_from_slice(&size.to_be_bytes()[..]); - } - &DhcpOption::Other { kind, data: provided } => { - buffer[0] = kind; - buffer[2..skip_length].copy_from_slice(provided); - } - } - } + pub fn end(&mut self) -> Result<()> { + if self.buffer.is_empty() { + return Err(Error); } - &mut buffer[skip_length..] + + self.buffer[0] = field::OPT_END; + self.buffer = &mut []; + Ok(()) } } +/// A representation of a single DHCP option. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct DhcpOption<'a> { + pub kind: u8, + pub data: &'a [u8], +} + /// A read/write wrapper around a Dynamic Host Configuration Protocol packet buffer. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq, Copy, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } pub(crate) mod field { #![allow(non_snake_case)] #![allow(unused)] - use wire::field::*; + use crate::wire::field::*; pub const OP: usize = 0; pub const HTYPE: usize = 1; @@ -311,7 +233,7 @@ pub(crate) mod field { impl> Packet { /// Imbue a raw octet buffer with DHCP packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -326,13 +248,13 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// [set_header_len]: #method.set_header_len pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::MAGIC_NUMBER.end { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -431,19 +353,67 @@ impl> Packet { Ipv4Address::from_bytes(field) } - /// Returns true if the broadcast flag is set. - pub fn broadcast_flag(&self) -> bool { + pub fn flags(&self) -> Flags { let field = &self.buffer.as_ref()[field::FLAGS]; - NetworkEndian::read_u16(field) & 0b1 == 0b1 + Flags::from_bits_truncate(NetworkEndian::read_u16(field)) } -} -impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { - /// Return a pointer to the options. + /// Return an iterator over the options. #[inline] - pub fn options(&self) -> Result<&'a [u8]> { - let data = self.buffer.as_ref(); - data.get(field::OPTIONS).ok_or(Error::Malformed) + pub fn options(&self) -> impl Iterator> + '_ { + let mut buf = &self.buffer.as_ref()[field::OPTIONS]; + iter::from_fn(move || { + loop { + match buf.first().copied() { + // No more options, return. + None => return None, + Some(field::OPT_END) => return None, + + // Skip padding. + Some(field::OPT_PAD) => buf = &buf[1..], + Some(kind) => { + if buf.len() < 2 { + return None; + } + + let len = buf[1] as usize; + + if buf.len() < 2 + len { + return None; + } + + let opt = DhcpOption { + kind, + data: &buf[2..2 + len], + }; + + buf = &buf[2 + len..]; + return Some(opt); + } + } + } + }) + } + + pub fn get_sname(&self) -> Result<&str> { + let data = &self.buffer.as_ref()[field::SNAME]; + let len = data.iter().position(|&x| x == 0).ok_or(Error)?; + if len == 0 { + return Err(Error); + } + + let data = core::str::from_utf8(&data[..len]).map_err(|_| Error)?; + Ok(data) + } + + pub fn get_boot_file(&self) -> Result<&str> { + let data = &self.buffer.as_ref()[field::FILE]; + let len = data.iter().position(|&x| x == 0).ok_or(Error)?; + if len == 0 { + return Err(Error); + } + let data = core::str::from_utf8(&data[..len]).map_err(|_| Error)?; + Ok(data) } } @@ -554,19 +524,18 @@ impl + AsMut<[u8]>> Packet { field.copy_from_slice(value.as_bytes()); } - /// Sets the broadcast flag to the specified value. - pub fn set_broadcast_flag(&mut self, value: bool) { + /// Sets the flags to the specified value. + pub fn set_flags(&mut self, val: Flags) { let field = &mut self.buffer.as_mut()[field::FLAGS]; - NetworkEndian::write_u16(field, if value { 1 } else { 0 }); + NetworkEndian::write_u16(field, val.bits()); } } impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> { /// Return a pointer to the options. #[inline] - pub fn options_mut(&mut self) -> Result<&mut [u8]> { - let data = self.buffer.as_mut(); - data.get_mut(field::OPTIONS).ok_or(Error::Truncated) + pub fn options_mut(&mut self) -> DhcpOptionWriter<'_> { + DhcpOptionWriter::new(&mut self.buffer.as_mut()[field::OPTIONS]) } } @@ -614,7 +583,8 @@ impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'a mut T> { /// length) is set to `6`. /// /// The `options` field has a variable length. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Repr<'a> { /// This field is also known as `op` in the RFC. It indicates the type of DHCP message this /// packet represents. @@ -623,6 +593,11 @@ pub struct Repr<'a> { /// used by the client and server to associate messages and responses between a client and a /// server. pub transaction_id: u32, + /// seconds elapsed since client began address acquisition or renewal + /// process the DHCPREQUEST message MUST use the same value in the DHCP + /// message header's 'secs' field and be sent to the same IP broadcast + /// address as the original DHCPDISCOVER message. + pub secs: u16, /// This field is also known as `chaddr` in the RFC and for networks where the access layer is /// ethernet, it is the client MAC address. pub client_hardware_address: EthernetAddress, @@ -671,9 +646,19 @@ pub struct Repr<'a> { /// the client is interested in. pub parameter_request_list: Option<&'a [u8]>, /// DNS servers - pub dns_servers: Option<[Option; 3]>, + pub dns_servers: Option>, /// The maximum size dhcp packet the interface can receive pub max_size: Option, + /// The DHCP IP lease duration, specified in seconds. + pub lease_duration: Option, + /// The DHCP IP renew duration (T1 interval), in seconds, if specified in the packet. + pub renew_duration: Option, + /// The DHCP IP rebind duration (T2 interval), in seconds, if specified in the packet. + pub rebind_duration: Option, + /// When returned from [`Repr::parse`], this field will be `None`. + /// However, when calling [`Repr::emit`], this field should contain only + /// additional DHCP options not known to smoltcp. + pub additional_options: &'a [DhcpOption<'a>], } impl<'a> Repr<'a> { @@ -682,41 +667,69 @@ impl<'a> Repr<'a> { let mut len = field::OPTIONS.start; // message type and end-of-options options len += 3 + 1; - if self.requested_ip.is_some() { len += 6; } - if self.client_identifier.is_some() { len += 9; } - if self.server_identifier.is_some() { len += 6; } - if self.max_size.is_some() { len += 4; } - if let Some(list) = self.parameter_request_list { len += list.len() + 2; } + if self.requested_ip.is_some() { + len += 6; + } + if self.client_identifier.is_some() { + len += 9; + } + if self.server_identifier.is_some() { + len += 6; + } + if self.max_size.is_some() { + len += 4; + } + if self.router.is_some() { + len += 6; + } + if self.subnet_mask.is_some() { + len += 6; + } + if self.lease_duration.is_some() { + len += 6; + } + if let Some(dns_servers) = &self.dns_servers { + len += 2; + len += dns_servers.iter().count() * core::mem::size_of::(); + } + if let Some(list) = self.parameter_request_list { + len += list.len() + 2; + } + for opt in self.additional_options { + len += 2 + opt.data.len() + } len } /// Parse a DHCP packet and return a high-level representation. - pub fn parse(packet: &Packet<&'a T>) -> Result - where T: AsRef<[u8]> + ?Sized { - + pub fn parse(packet: &'a Packet<&'a T>) -> Result + where + T: AsRef<[u8]> + ?Sized, + { let transaction_id = packet.transaction_id(); let client_hardware_address = packet.client_hardware_address(); let client_ip = packet.client_ip(); let your_ip = packet.your_ip(); let server_ip = packet.server_ip(); let relay_agent_ip = packet.relay_agent_ip(); + let secs = packet.secs(); // only ethernet is supported right now match packet.hardware_type() { Hardware::Ethernet => { if packet.hardware_len() != 6 { - return Err(Error::Malformed); + return Err(Error); } } - Hardware::Unknown(_) => return Err(Error::Unrecognized), // unimplemented + Hardware::Unknown(_) => return Err(Error), // unimplemented } if packet.magic_number() != DHCP_MAGIC_NUMBER { - return Err(Error::Malformed); + return Err(Error); } - let mut message_type = Err(Error::Malformed); + let mut message_type = Err(Error); let mut requested_ip = None; let mut client_identifier = None; let mut server_identifier = None; @@ -725,68 +738,101 @@ impl<'a> Repr<'a> { let mut parameter_request_list = None; let mut dns_servers = None; let mut max_size = None; - - let mut options = packet.options()?; - while options.len() > 0 { - let (next_options, option) = DhcpOption::parse(options)?; - match option { - DhcpOption::EndOfList => break, - DhcpOption::Pad => {}, - DhcpOption::MessageType(value) => { + let mut lease_duration = None; + let mut renew_duration = None; + let mut rebind_duration = None; + + for option in packet.options() { + let data = option.data; + match (option.kind, data.len()) { + (field::OPT_DHCP_MESSAGE_TYPE, 1) => { + let value = MessageType::from(data[0]); if value.opcode() == packet.opcode() { message_type = Ok(value); } - }, - DhcpOption::RequestedIp(ip) => { - requested_ip = Some(ip); } - DhcpOption::ClientIdentifier(eth_addr) => { - client_identifier = Some(eth_addr); + (field::OPT_REQUESTED_IP, 4) => { + requested_ip = Some(Ipv4Address::from_bytes(data)); } - DhcpOption::ServerIdentifier(ip) => { - server_identifier = Some(ip); + (field::OPT_CLIENT_ID, 7) => { + let hardware_type = Hardware::from(u16::from(data[0])); + if hardware_type != Hardware::Ethernet { + return Err(Error); + } + client_identifier = Some(EthernetAddress::from_bytes(&data[1..])); + } + (field::OPT_SERVER_IDENTIFIER, 4) => { + server_identifier = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_ROUTER, 4) => { + router = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_SUBNET_MASK, 4) => { + subnet_mask = Some(Ipv4Address::from_bytes(data)); + } + (field::OPT_MAX_DHCP_MESSAGE_SIZE, 2) => { + max_size = Some(u16::from_be_bytes([data[0], data[1]])); } - DhcpOption::Router(ip) => { - router = Some(ip); + (field::OPT_RENEWAL_TIME_VALUE, 4) => { + renew_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) } - DhcpOption::SubnetMask(mask) => { - subnet_mask = Some(mask); - }, - DhcpOption::MaximumDhcpMessageSize(size) => { - max_size = Some(size); + (field::OPT_REBINDING_TIME_VALUE, 4) => { + rebind_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) } - DhcpOption::Other {kind: field::OPT_PARAMETER_REQUEST_LIST, data} => { + (field::OPT_IP_LEASE_TIME, 4) => { + lease_duration = Some(u32::from_be_bytes([data[0], data[1], data[2], data[3]])) + } + (field::OPT_PARAMETER_REQUEST_LIST, _) => { parameter_request_list = Some(data); } - DhcpOption::Other {kind: field::OPT_DOMAIN_NAME_SERVER, data} => { - let mut dns_servers_inner = [None; 3]; - for i in 0..3 { - let offset = 4 * i; - let end = offset + 4; - if end > data.len() { break } - dns_servers_inner[i] = Some(Ipv4Address::from_bytes(&data[offset..end])); + (field::OPT_DOMAIN_NAME_SERVER, _) => { + let mut servers = Vec::new(); + const IP_ADDR_BYTE_LEN: usize = 4; + for chunk in data.chunks(IP_ADDR_BYTE_LEN) { + // We ignore push failures because that will only happen + // if we attempt to push more than 4 addresses, and the only + // solution to that is to support more addresses. + servers.push(Ipv4Address::from_bytes(chunk)).ok(); } - dns_servers = Some(dns_servers_inner); + dns_servers = Some(servers); } - DhcpOption::Other {..} => {} + _ => {} } - options = next_options; } - let broadcast = packet.broadcast_flag(); + let broadcast = packet.flags().contains(Flags::BROADCAST); Ok(Repr { - transaction_id, client_hardware_address, client_ip, your_ip, server_ip, relay_agent_ip, - broadcast, requested_ip, server_identifier, router, - subnet_mask, client_identifier, parameter_request_list, dns_servers, max_size, + secs, + transaction_id, + client_hardware_address, + client_ip, + your_ip, + server_ip, + relay_agent_ip, + broadcast, + requested_ip, + server_identifier, + router, + subnet_mask, + client_identifier, + parameter_request_list, + dns_servers, + max_size, + lease_duration, + renew_duration, + rebind_duration, message_type: message_type?, + additional_options: &[], }) } /// Emit a high-level representation into a Dynamic Host /// Configuration Protocol packet. pub fn emit(&self, packet: &mut Packet<&mut T>) -> Result<()> - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { packet.set_sname_and_boot_file_to_zero(); packet.set_opcode(self.message_type.opcode()); packet.set_hardware_type(Hardware::Ethernet); @@ -794,40 +840,105 @@ impl<'a> Repr<'a> { packet.set_transaction_id(self.transaction_id); packet.set_client_hardware_address(self.client_hardware_address); packet.set_hops(0); - packet.set_secs(0); // TODO + packet.set_secs(self.secs); packet.set_magic_number(0x63825363); packet.set_client_ip(self.client_ip); packet.set_your_ip(self.your_ip); packet.set_server_ip(self.server_ip); packet.set_relay_agent_ip(self.relay_agent_ip); - packet.set_broadcast_flag(self.broadcast); + + let mut flags = Flags::empty(); + if self.broadcast { + flags |= Flags::BROADCAST; + } + packet.set_flags(flags); { - let mut options = packet.options_mut()?; - let tmp = options; options = DhcpOption::MessageType(self.message_type).emit(tmp); - if let Some(eth_addr) = self.client_identifier { - let tmp = options; options = DhcpOption::ClientIdentifier(eth_addr).emit(tmp); + let mut options = packet.options_mut(); + + options.emit(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[self.message_type.into()], + })?; + + if let Some(val) = &self.client_identifier { + let mut data = [0; 7]; + data[0] = u16::from(Hardware::Ethernet) as u8; + data[1..].copy_from_slice(val.as_bytes()); + + options.emit(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &data, + })?; } - if let Some(ip) = self.server_identifier { - let tmp = options; options = DhcpOption::ServerIdentifier(ip).emit(tmp); + + if let Some(val) = &self.server_identifier { + options.emit(DhcpOption { + kind: field::OPT_SERVER_IDENTIFIER, + data: val.as_bytes(), + })?; } - if let Some(ip) = self.router { - let tmp = options; options = DhcpOption::Router(ip).emit(tmp); + + if let Some(val) = &self.router { + options.emit(DhcpOption { + kind: field::OPT_ROUTER, + data: val.as_bytes(), + })?; + } + if let Some(val) = &self.subnet_mask { + options.emit(DhcpOption { + kind: field::OPT_SUBNET_MASK, + data: val.as_bytes(), + })?; } - if let Some(ip) = self.subnet_mask { - let tmp = options; options = DhcpOption::SubnetMask(ip).emit(tmp); + if let Some(val) = &self.requested_ip { + options.emit(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: val.as_bytes(), + })?; } - if let Some(ip) = self.requested_ip { - let tmp = options; options = DhcpOption::RequestedIp(ip).emit(tmp); + if let Some(val) = &self.max_size { + options.emit(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &val.to_be_bytes(), + })?; } - if let Some(size) = self.max_size { - let tmp = options; options = DhcpOption::MaximumDhcpMessageSize(size).emit(tmp); + if let Some(val) = &self.lease_duration { + options.emit(DhcpOption { + kind: field::OPT_IP_LEASE_TIME, + data: &val.to_be_bytes(), + })?; } - if let Some(list) = self.parameter_request_list { - let option = DhcpOption::Other{ kind: field::OPT_PARAMETER_REQUEST_LIST, data: list }; - let tmp = options; options = option.emit(tmp); + if let Some(val) = &self.parameter_request_list { + options.emit(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: val, + })?; } - DhcpOption::EndOfList.emit(options); + + if let Some(dns_servers) = &self.dns_servers { + const IP_SIZE: usize = core::mem::size_of::(); + let mut servers = [0; MAX_DNS_SERVER_COUNT * IP_SIZE]; + + let data_len = dns_servers + .iter() + .enumerate() + .inspect(|(i, ip)| { + servers[(i * IP_SIZE)..((i + 1) * IP_SIZE)].copy_from_slice(ip.as_bytes()); + }) + .count() + * IP_SIZE; + options.emit(DhcpOption { + kind: field::OPT_DOMAIN_NAME_SERVER, + data: &servers[..data_len], + })?; + } + + for option in self.additional_options { + options.emit(*option)?; + } + + options.end()?; } Ok(()) @@ -836,53 +947,80 @@ impl<'a> Repr<'a> { #[cfg(test)] mod test { - use wire::Ipv4Address; use super::*; + use crate::wire::Ipv4Address; const MAGIC_COOKIE: u32 = 0x63825363; static DISCOVER_BYTES: &[u8] = &[ - 0x01, 0x01, 0x06, 0x00, 0x00, 0x00, 0x3d, 0x1d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0b, 0x82, 0x01, - 0xfc, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, - 0x35, 0x01, 0x01, 0x3d, 0x07, 0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42, 0x32, 0x04, 0x00, 0x00, - 0x00, 0x00, 0x39, 0x2, 0x5, 0xdc, 0x37, 0x04, 0x01, 0x03, 0x06, 0x2a, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, 0x01, 0x06, 0x00, 0x00, 0x00, 0x3d, 0x1d, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0b, + 0x82, 0x01, 0xfc, 0x42, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x01, 0x3d, 0x07, 0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42, 0x32, 0x04, 0x00, + 0x00, 0x00, 0x00, 0x39, 0x2, 0x5, 0xdc, 0x37, 0x04, 0x01, 0x03, 0x06, 0x2a, 0xff, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + static ACK_DNS_SERVER_BYTES: &[u8] = &[ + 0x02, 0x01, 0x06, 0x00, 0xcc, 0x34, 0x75, 0xab, 0x00, 0x00, 0x80, 0x00, 0x0a, 0xff, 0x06, + 0x91, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0xff, 0x06, 0xfe, 0x34, 0x17, + 0xeb, 0xc9, 0xaa, 0x2f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x05, 0x36, 0x04, 0xa3, 0x01, 0x4a, 0x16, 0x01, 0x04, 0xff, 0xff, 0xff, 0x00, + 0x2b, 0x05, 0xdc, 0x03, 0x4e, 0x41, 0x50, 0x0f, 0x15, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x68, + 0x79, 0x73, 0x69, 0x63, 0x73, 0x2e, 0x6f, 0x78, 0x2e, 0x61, 0x63, 0x2e, 0x75, 0x6b, 0x00, + 0x03, 0x04, 0x0a, 0xff, 0x06, 0xfe, 0x06, 0x10, 0xa3, 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, + 0x07, 0xa3, 0x01, 0x4a, 0x03, 0xa3, 0x01, 0x4a, 0x04, 0x2c, 0x10, 0xa3, 0x01, 0x4a, 0x03, + 0xa3, 0x01, 0x4a, 0x04, 0xa3, 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, 0x07, 0x2e, 0x01, 0x08, + 0xff, ]; - static ACK_BYTES: &[u8] = &[ - 0x02, 0x01, 0x06, 0x00, 0xcc, 0x34, 0x75, 0xab, 0x00, 0x00, 0x80, 0x00, 0x0a, 0xff, 0x06, 0x91, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x0a, 0xff, 0x06, 0xfe, 0x34, 0x17, 0xeb, 0xc9, - 0xaa, 0x2f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, - 0x35, 0x01, 0x05, 0x36, 0x04, 0xa3, 0x01, 0x4a, 0x16, 0x01, 0x04, 0xff, 0xff, 0xff, 0x00, 0x2b, - 0x05, 0xdc, 0x03, 0x4e, 0x41, 0x50, 0x0f, 0x15, 0x6e, 0x61, 0x74, 0x2e, 0x70, 0x68, 0x79, 0x73, - 0x69, 0x63, 0x73, 0x2e, 0x6f, 0x78, 0x2e, 0x61, 0x63, 0x2e, 0x75, 0x6b, 0x00, 0x03, 0x04, 0x0a, - 0xff, 0x06, 0xfe, 0x06, 0x10, 0xa3, 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, 0x07, 0xa3, 0x01, 0x4a, - 0x03, 0xa3, 0x01, 0x4a, 0x04, 0x2c, 0x10, 0xa3, 0x01, 0x4a, 0x03, 0xa3, 0x01, 0x4a, 0x04, 0xa3, - 0x01, 0x4a, 0x06, 0xa3, 0x01, 0x4a, 0x07, 0x2e, 0x01, 0x08, 0xff + static ACK_LEASE_TIME_BYTES: &[u8] = &[ + 0x02, 0x01, 0x06, 0x00, 0x00, 0x00, 0x00, 0x06, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x0a, 0x22, 0x10, 0x0b, 0x0a, 0x22, 0x10, 0x0a, 0x00, 0x00, 0x00, 0x00, 0x04, 0x91, + 0x62, 0xd2, 0xa8, 0x6f, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x63, 0x82, 0x53, 0x63, + 0x35, 0x01, 0x05, 0x36, 0x04, 0x0a, 0x22, 0x10, 0x0a, 0x33, 0x04, 0x00, 0x00, 0x02, 0x56, + 0x01, 0x04, 0xff, 0xff, 0xff, 0x00, 0x03, 0x04, 0x0a, 0x22, 0x10, 0x0a, 0xff, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; const IP_NULL: Ipv4Address = Ipv4Address([0, 0, 0, 0]); @@ -904,34 +1042,44 @@ mod test { assert_eq!(packet.server_ip(), IP_NULL); assert_eq!(packet.relay_agent_ip(), IP_NULL); assert_eq!(packet.client_hardware_address(), CLIENT_MAC); - let options = packet.options().unwrap(); - assert_eq!(options.len(), 3 + 9 + 6 + 4 + 6 + 1 + 7); - - let (options, message_type) = DhcpOption::parse(options).unwrap(); - assert_eq!(message_type, DhcpOption::MessageType(MessageType::Discover)); - assert_eq!(options.len(), 9 + 6 + 4 + 6 + 1 + 7); - - let (options, client_id) = DhcpOption::parse(options).unwrap(); - assert_eq!(client_id, DhcpOption::ClientIdentifier(CLIENT_MAC)); - assert_eq!(options.len(), 6 + 4 + 6 + 1 + 7); - - let (options, client_id) = DhcpOption::parse(options).unwrap(); - assert_eq!(client_id, DhcpOption::RequestedIp(IP_NULL)); - assert_eq!(options.len(), 4 + 6 + 1 + 7); - - let (options, msg_size) = DhcpOption::parse(options).unwrap(); - assert_eq!(msg_size, DhcpOption::MaximumDhcpMessageSize(DHCP_SIZE)); - assert_eq!(options.len(), 6 + 1 + 7); - - let (options, client_id) = DhcpOption::parse(options).unwrap(); - assert_eq!(client_id, DhcpOption::Other { - kind: field::OPT_PARAMETER_REQUEST_LIST, data: &[1, 3, 6, 42] - }); - assert_eq!(options.len(), 1 + 7); - - let (options, client_id) = DhcpOption::parse(options).unwrap(); - assert_eq!(client_id, DhcpOption::EndOfList); - assert_eq!(options.len(), 7); // padding + + let mut options = packet.options(); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[0x01] + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &[0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42], + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: &[0x00, 0x00, 0x00, 0x00], + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &DHCP_SIZE.to_be_bytes(), + }) + ); + assert_eq!( + options.next(), + Some(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: &[1, 3, 6, 42] + }) + ); + assert_eq!(options.next(), None); } #[test] @@ -946,25 +1094,46 @@ mod test { packet.set_hops(0); packet.set_transaction_id(0x3d1d); packet.set_secs(0); - packet.set_broadcast_flag(false); + packet.set_flags(Flags::empty()); packet.set_client_ip(IP_NULL); packet.set_your_ip(IP_NULL); packet.set_server_ip(IP_NULL); packet.set_relay_agent_ip(IP_NULL); packet.set_client_hardware_address(CLIENT_MAC); - { - let mut options = packet.options_mut().unwrap(); - let tmp = options; options = DhcpOption::MessageType(MessageType::Discover).emit(tmp); - let tmp = options; options = DhcpOption::ClientIdentifier(CLIENT_MAC).emit(tmp); - let tmp = options; options = DhcpOption::RequestedIp(IP_NULL).emit(tmp); - let tmp = options; options = DhcpOption::MaximumDhcpMessageSize(DHCP_SIZE).emit(tmp); - let option = DhcpOption::Other { - kind: field::OPT_PARAMETER_REQUEST_LIST, data: &[1, 3, 6, 42], - }; - let tmp = options; options = option.emit(tmp); - DhcpOption::EndOfList.emit(options); - } + let mut options = packet.options_mut(); + + options + .emit(DhcpOption { + kind: field::OPT_DHCP_MESSAGE_TYPE, + data: &[0x01], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_CLIENT_ID, + data: &[0x01, 0x00, 0x0b, 0x82, 0x01, 0xfc, 0x42], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_REQUESTED_IP, + data: &[0x00, 0x00, 0x00, 0x00], + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_MAX_DHCP_MESSAGE_SIZE, + data: &DHCP_SIZE.to_be_bytes(), + }) + .unwrap(); + options + .emit(DhcpOption { + kind: field::OPT_PARAMETER_REQUEST_LIST, + data: &[1, 3, 6, 42], + }) + .unwrap(); + options.end().unwrap(); let packet = &mut packet.into_inner()[..]; for byte in &mut packet[269..276] { @@ -974,7 +1143,33 @@ mod test { assert_eq!(packet, DISCOVER_BYTES); } - fn discover_repr() -> Repr<'static> { + const fn offer_repr() -> Repr<'static> { + Repr { + message_type: MessageType::Offer, + transaction_id: 0x3d1d, + client_hardware_address: CLIENT_MAC, + client_ip: IP_NULL, + your_ip: IP_NULL, + server_ip: IP_NULL, + router: Some(IP_NULL), + subnet_mask: Some(IP_NULL), + relay_agent_ip: IP_NULL, + secs: 0, + broadcast: false, + requested_ip: None, + client_identifier: Some(CLIENT_MAC), + server_identifier: None, + parameter_request_list: None, + dns_servers: None, + max_size: None, + renew_duration: None, + rebind_duration: None, + lease_duration: Some(0xffff_ffff), // Infinite lease + additional_options: &[], + } + } + + const fn discover_repr() -> Repr<'static> { Repr { message_type: MessageType::Discover, transaction_id: 0x3d1d, @@ -986,12 +1181,17 @@ mod test { subnet_mask: None, relay_agent_ip: IP_NULL, broadcast: false, + secs: 0, max_size: Some(DHCP_SIZE), + renew_duration: None, + rebind_duration: None, + lease_duration: None, requested_ip: Some(IP_NULL), client_identifier: Some(CLIENT_MAC), server_identifier: None, parameter_request_list: Some(&[1, 3, 6, 42]), dns_servers: None, + additional_options: &[], } } @@ -1008,7 +1208,7 @@ mod test { let mut bytes = vec![0xa5; repr.buffer_len()]; let mut packet = Packet::new_unchecked(&mut bytes); repr.emit(&mut packet).unwrap(); - let packet = &packet.into_inner()[..]; + let packet = &*packet.into_inner(); let packet_len = packet.len(); assert_eq!(packet, &DISCOVER_BYTES[..packet_len]); for byte in &DISCOVER_BYTES[packet_len..] { @@ -1016,32 +1216,95 @@ mod test { } } + #[test] + fn test_emit_offer() { + let repr = offer_repr(); + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet).unwrap(); + } + + #[test] + fn test_emit_offer_dns() { + let repr = { + let mut repr = offer_repr(); + repr.dns_servers = Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]), + ]) + .unwrap(), + ); + repr + }; + let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut packet = Packet::new_unchecked(&mut bytes); + repr.emit(&mut packet).unwrap(); + + let packet = Packet::new_unchecked(&bytes); + let repr_parsed = Repr::parse(&packet).unwrap(); + + assert_eq!( + repr_parsed.dns_servers, + Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]), + ]) + .unwrap() + ) + ); + } + #[test] fn test_emit_dhcp_option() { static DATA: &[u8] = &[1, 3, 6]; - let mut bytes = vec![0xa5; 5]; - let dhcp_option = DhcpOption::Other { + let dhcp_option = DhcpOption { kind: field::OPT_PARAMETER_REQUEST_LIST, data: DATA, }; - { - let rest = dhcp_option.emit(&mut bytes); - assert_eq!(rest.len(), 0); - } - assert_eq!(&bytes[0..2], &[field::OPT_PARAMETER_REQUEST_LIST, DATA.len() as u8]); + + let mut bytes = vec![0xa5; 5]; + let mut writer = DhcpOptionWriter::new(&mut bytes); + writer.emit(dhcp_option).unwrap(); + + assert_eq!( + &bytes[0..2], + &[field::OPT_PARAMETER_REQUEST_LIST, DATA.len() as u8] + ); assert_eq!(&bytes[2..], DATA); } #[test] fn test_parse_ack_dns_servers() { - let packet = Packet::new_unchecked(ACK_BYTES); + let packet = Packet::new_unchecked(ACK_DNS_SERVER_BYTES); let repr = Repr::parse(&packet).unwrap(); + // The packet described by ACK_BYTES advertises 4 DNS servers // Here we ensure that we correctly parse the first 3 into our fixed // length-3 array (see issue #305) - assert_eq!(repr.dns_servers, Some([ - Some(Ipv4Address([163, 1, 74, 6])), - Some(Ipv4Address([163, 1, 74, 7])), - Some(Ipv4Address([163, 1, 74, 3]))])); + assert_eq!( + repr.dns_servers, + Some( + Vec::from_slice(&[ + Ipv4Address([163, 1, 74, 6]), + Ipv4Address([163, 1, 74, 7]), + Ipv4Address([163, 1, 74, 3]) + ]) + .unwrap() + ) + ); + } + + #[test] + fn test_parse_ack_lease_duration() { + let packet = Packet::new_unchecked(ACK_LEASE_TIME_BYTES); + let repr = Repr::parse(&packet).unwrap(); + + // Verify that the lease time in the ACK is properly parsed. The packet contains a lease + // duration of 598s. + assert_eq!(repr.lease_duration, Some(598)); } } diff --git a/src/wire/dns.rs b/src/wire/dns.rs new file mode 100644 index 000000000..7cd008324 --- /dev/null +++ b/src/wire/dns.rs @@ -0,0 +1,793 @@ +#![allow(dead_code)] + +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; +use core::iter; +use core::iter::Iterator; + +use super::{Error, Result}; +#[cfg(feature = "proto-ipv4")] +use crate::wire::Ipv4Address; +#[cfg(feature = "proto-ipv6")] +use crate::wire::Ipv6Address; + +enum_with_unknown! { + /// DNS OpCodes + pub enum Opcode(u8) { + Query = 0x00, + Status = 0x01, + } +} +enum_with_unknown! { + /// DNS OpCodes + pub enum Rcode(u8) { + NoError = 0x00, + FormErr = 0x01, + ServFail = 0x02, + NXDomain = 0x03, + NotImp = 0x04, + Refused = 0x05, + YXDomain = 0x06, + YXRRSet = 0x07, + NXRRSet = 0x08, + NotAuth = 0x09, + NotZone = 0x0a, + } +} + +enum_with_unknown! { + /// DNS record types + pub enum Type(u16) { + A = 0x0001, + Ns = 0x0002, + Cname = 0x0005, + Soa = 0x0006, + Aaaa = 0x001c, + } +} + +bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Flags: u16 { + const RESPONSE = 0b1000_0000_0000_0000; + const AUTHORITATIVE = 0b0000_0100_0000_0000; + const TRUNCATED = 0b0000_0010_0000_0000; + const RECURSION_DESIRED = 0b0000_0001_0000_0000; + const RECURSION_AVAILABLE = 0b0000_0000_1000_0000; + const AUTHENTIC_DATA = 0b0000_0000_0010_0000; + const CHECK_DISABLED = 0b0000_0000_0001_0000; + } +} + +mod field { + use crate::wire::field::*; + + pub const ID: Field = 0..2; + pub const FLAGS: Field = 2..4; + pub const QDCOUNT: Field = 4..6; + pub const ANCOUNT: Field = 6..8; + pub const NSCOUNT: Field = 8..10; + pub const ARCOUNT: Field = 10..12; + + pub const HEADER_END: usize = 12; +} + +// DNS class IN (Internet) +const CLASS_IN: u16 = 1; + +/// A read/write wrapper around a DNS packet buffer. +#[derive(Debug, PartialEq, Eq)] +pub struct Packet> { + buffer: T, +} + +impl> Packet { + /// Imbue a raw octet buffer with DNS packet structure. + pub const fn new_unchecked(buffer: T) -> Packet { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is smaller than + /// the header length. + pub fn check_len(&self) -> Result<()> { + let len = self.buffer.as_ref().len(); + if len < field::HEADER_END { + Err(Error) + } else { + Ok(()) + } + } + + /// Consume the packet, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + pub fn payload(&self) -> &[u8] { + &self.buffer.as_ref()[field::HEADER_END..] + } + + pub fn transaction_id(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ID]; + NetworkEndian::read_u16(field) + } + + pub fn flags(&self) -> Flags { + let field = &self.buffer.as_ref()[field::FLAGS]; + Flags::from_bits_truncate(NetworkEndian::read_u16(field)) + } + + pub fn opcode(&self) -> Opcode { + let field = &self.buffer.as_ref()[field::FLAGS]; + let flags = NetworkEndian::read_u16(field); + Opcode::from((flags >> 11 & 0xF) as u8) + } + + pub fn rcode(&self) -> Rcode { + let field = &self.buffer.as_ref()[field::FLAGS]; + let flags = NetworkEndian::read_u16(field); + Rcode::from((flags & 0xF) as u8) + } + + pub fn question_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::QDCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn answer_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ANCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn authority_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::NSCOUNT]; + NetworkEndian::read_u16(field) + } + + pub fn additional_record_count(&self) -> u16 { + let field = &self.buffer.as_ref()[field::ARCOUNT]; + NetworkEndian::read_u16(field) + } + + /// Parse part of a name from `bytes`, following pointers if any. + pub fn parse_name<'a>(&'a self, mut bytes: &'a [u8]) -> impl Iterator> { + let mut packet = self.buffer.as_ref(); + + iter::from_fn(move || loop { + if bytes.is_empty() { + return Some(Err(Error)); + } + match bytes[0] { + 0x00 => return None, + x if x & 0xC0 == 0x00 => { + let len = (x & 0x3F) as usize; + if bytes.len() < 1 + len { + return Some(Err(Error)); + } + let label = &bytes[1..1 + len]; + bytes = &bytes[1 + len..]; + return Some(Ok(label)); + } + x if x & 0xC0 == 0xC0 => { + if bytes.len() < 2 { + return Some(Err(Error)); + } + let y = bytes[1]; + let ptr = ((x & 0x3F) as usize) << 8 | (y as usize); + if packet.len() <= ptr { + return Some(Err(Error)); + } + + // RFC1035 says: "In this scheme, an entire domain name or a list of labels at + // the end of a domain name is replaced with a pointer to a ***prior*** occurance + // of the same name. + // + // Is it unclear if this means the pointer MUST point backwards in the packet or not. Either way, + // pointers that don't point backwards are never seen in the fields, so use this to check that + // there are no pointer loops. + + // Split packet into parts before and after `ptr`. + // parse the part after, keep only the part before in `packet`. This ensure we never + // parse the same byte twice, therefore eliminating pointer loops. + + bytes = &packet[ptr..]; + packet = &packet[..ptr]; + } + _ => return Some(Err(Error)), + } + }) + } +} + +impl + AsMut<[u8]>> Packet { + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + &mut data[field::HEADER_END..] + } + + pub fn set_transaction_id(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ID]; + NetworkEndian::write_u16(field, val) + } + + pub fn set_flags(&mut self, val: Flags) { + let field = &mut self.buffer.as_mut()[field::FLAGS]; + let mask = Flags::all().bits; + let old = NetworkEndian::read_u16(field); + NetworkEndian::write_u16(field, (old & !mask) | val.bits()); + } + + pub fn set_opcode(&mut self, val: Opcode) { + let field = &mut self.buffer.as_mut()[field::FLAGS]; + let mask = 0x3800; + let val: u8 = val.into(); + let val = (val as u16) << 11; + let old = NetworkEndian::read_u16(field); + NetworkEndian::write_u16(field, (old & !mask) | val); + } + + pub fn set_question_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::QDCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_answer_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ANCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_authority_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::NSCOUNT]; + NetworkEndian::write_u16(field, val) + } + pub fn set_additional_record_count(&mut self, val: u16) { + let field = &mut self.buffer.as_mut()[field::ARCOUNT]; + NetworkEndian::write_u16(field, val) + } +} + +/// Parse part of a name from `bytes`, not following pointers. +/// Returns the unused part of `bytes`, and the pointer offset if the sequence ends with a pointer. +fn parse_name_part<'a>( + mut bytes: &'a [u8], + mut f: impl FnMut(&'a [u8]), +) -> Result<(&'a [u8], Option)> { + loop { + let x = *bytes.first().ok_or(Error)?; + bytes = &bytes[1..]; + match x { + 0x00 => return Ok((bytes, None)), + x if x & 0xC0 == 0x00 => { + let len = (x & 0x3F) as usize; + let label = bytes.get(..len).ok_or(Error)?; + bytes = &bytes[len..]; + f(label); + } + x if x & 0xC0 == 0xC0 => { + let y = *bytes.first().ok_or(Error)?; + bytes = &bytes[1..]; + + let ptr = ((x & 0x3F) as usize) << 8 | (y as usize); + return Ok((bytes, Some(ptr))); + } + _ => return Err(Error), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Question<'a> { + pub name: &'a [u8], + pub type_: Type, +} + +impl<'a> Question<'a> { + pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], Question<'a>)> { + let (rest, _) = parse_name_part(buffer, |_| ())?; + let name = &buffer[..buffer.len() - rest.len()]; + + if rest.len() < 4 { + return Err(Error); + } + let type_ = NetworkEndian::read_u16(&rest[0..2]).into(); + let class = NetworkEndian::read_u16(&rest[2..4]); + let rest = &rest[4..]; + + if class != CLASS_IN { + return Err(Error); + } + + Ok((rest, Question { name, type_ })) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + self.name.len() + 4 + } + + /// Emit a high-level representation into a DNS packet. + pub fn emit(&self, packet: &mut [u8]) { + packet[..self.name.len()].copy_from_slice(self.name); + let rest = &mut packet[self.name.len()..]; + NetworkEndian::write_u16(&mut rest[0..2], self.type_.into()); + NetworkEndian::write_u16(&mut rest[2..4], CLASS_IN); + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Record<'a> { + pub name: &'a [u8], + pub ttl: u32, + pub data: RecordData<'a>, +} + +impl<'a> RecordData<'a> { + pub fn parse(type_: Type, data: &'a [u8]) -> Result> { + match type_ { + #[cfg(feature = "proto-ipv4")] + Type::A => { + if data.len() != 4 { + return Err(Error); + } + Ok(RecordData::A(Ipv4Address::from_bytes(data))) + } + #[cfg(feature = "proto-ipv6")] + Type::Aaaa => { + if data.len() != 16 { + return Err(Error); + } + Ok(RecordData::Aaaa(Ipv6Address::from_bytes(data))) + } + Type::Cname => Ok(RecordData::Cname(data)), + x => Ok(RecordData::Other(x, data)), + } + } +} + +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum RecordData<'a> { + #[cfg(feature = "proto-ipv4")] + A(Ipv4Address), + #[cfg(feature = "proto-ipv6")] + Aaaa(Ipv6Address), + Cname(&'a [u8]), + Other(Type, &'a [u8]), +} + +impl<'a> Record<'a> { + pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], Record<'a>)> { + let (rest, _) = parse_name_part(buffer, |_| ())?; + let name = &buffer[..buffer.len() - rest.len()]; + + if rest.len() < 10 { + return Err(Error); + } + let type_ = NetworkEndian::read_u16(&rest[0..2]).into(); + let class = NetworkEndian::read_u16(&rest[2..4]); + let ttl = NetworkEndian::read_u32(&rest[4..8]); + let len = NetworkEndian::read_u16(&rest[8..10]) as usize; + let rest = &rest[10..]; + + if class != CLASS_IN { + return Err(Error); + } + + let data = rest.get(..len).ok_or(Error)?; + let rest = &rest[len..]; + + Ok(( + rest, + Record { + name, + ttl, + data: RecordData::parse(type_, data)?, + }, + )) + } +} + +/// High-level DNS packet representation. +/// +/// Currently only supports query packets. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + pub transaction_id: u16, + pub opcode: Opcode, + pub flags: Flags, + pub question: Question<'a>, +} + +impl<'a> Repr<'a> { + /// Return the length of a packet that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + field::HEADER_END + self.question.buffer_len() + } + + /// Emit a high-level representation into a DNS packet. + pub fn emit(&self, packet: &mut Packet<&mut T>) + where + T: AsRef<[u8]> + AsMut<[u8]>, + { + packet.set_transaction_id(self.transaction_id); + packet.set_flags(self.flags); + packet.set_opcode(self.opcode); + packet.set_question_count(1); + packet.set_answer_record_count(0); + packet.set_authority_record_count(0); + packet.set_additional_record_count(0); + self.question.emit(packet.payload_mut()) + } +} + +#[cfg(feature = "proto-ipv4")] // tests assume ipv4 +#[cfg(test)] +mod test { + use super::*; + use std::vec::Vec; + + #[test] + fn test_parse_name() { + let bytes = &[ + 0x78, 0x6c, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, + 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, 0x63, 0x6f, + 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, + 0x05, 0xf3, 0x00, 0x11, 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, + 0x04, 0x63, 0x31, 0x30, 0x72, 0xc0, 0x10, 0xc0, 0x2e, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x05, 0x00, 0x04, 0x1f, 0x0d, 0x53, 0x24, + ]; + let packet = Packet::new_unchecked(bytes); + + let name_vec = |bytes| { + let mut v = Vec::new(); + packet + .parse_name(bytes) + .try_for_each(|label| label.map(|label| v.push(label))) + .map(|_| v) + }; + + //assert_eq!(parse_name_len(bytes, 0x0c), Ok(18)); + assert_eq!( + name_vec(&bytes[0x0c..]), + Ok(vec![&b"www"[..], &b"facebook"[..], &b"com"[..]]) + ); + //assert_eq!(parse_name_len(bytes, 0x22), Ok(2)); + assert_eq!( + name_vec(&bytes[0x22..]), + Ok(vec![&b"www"[..], &b"facebook"[..], &b"com"[..]]) + ); + //assert_eq!(parse_name_len(bytes, 0x2e), Ok(17)); + assert_eq!( + name_vec(&bytes[0x2e..]), + Ok(vec![ + &b"star-mini"[..], + &b"c10r"[..], + &b"facebook"[..], + &b"com"[..] + ]) + ); + //assert_eq!(parse_name_len(bytes, 0x3f), Ok(2)); + assert_eq!( + name_vec(&bytes[0x3f..]), + Ok(vec![ + &b"star-mini"[..], + &b"c10r"[..], + &b"facebook"[..], + &b"com"[..] + ]) + ); + } + + struct Parsed<'a> { + packet: Packet<&'a [u8]>, + questions: Vec>, + answers: Vec>, + authorities: Vec>, + additionals: Vec>, + } + + impl<'a> Parsed<'a> { + fn parse(bytes: &'a [u8]) -> Result { + let packet = Packet::new_unchecked(bytes); + let mut questions = Vec::new(); + let mut answers = Vec::new(); + let mut authorities = Vec::new(); + let mut additionals = Vec::new(); + + let mut payload = &bytes[12..]; + + for _ in 0..packet.question_count() { + let (p, r) = Question::parse(payload)?; + questions.push(r); + payload = p; + } + for _ in 0..packet.answer_record_count() { + let (p, r) = Record::parse(payload)?; + answers.push(r); + payload = p; + } + for _ in 0..packet.authority_record_count() { + let (p, r) = Record::parse(payload)?; + authorities.push(r); + payload = p; + } + for _ in 0..packet.additional_record_count() { + let (p, r) = Record::parse(payload)?; + additionals.push(r); + payload = p; + } + + // Check that there are no bytes left + assert_eq!(payload.len(), 0); + + Ok(Parsed { + packet, + questions, + answers, + authorities, + additionals, + }) + } + } + + #[test] + fn test_parse_request() { + let p = Parsed::parse(&[ + 0x51, 0x84, 0x01, 0x20, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x5184); + assert_eq!( + p.packet.flags(), + Flags::RECURSION_DESIRED | Flags::AUTHENTIC_DATA + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 0); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!(p.questions.len(), 1); + assert_eq!( + p.questions[0].name, + &[0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers.len(), 0); + assert_eq!(p.authorities.len(), 0); + assert_eq!(p.additionals.len(), 0); + } + + #[test] + fn test_parse_response() { + let p = Parsed::parse(&[ + 0x51, 0x84, 0x81, 0x80, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x06, 0x67, + 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, + 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0xca, 0x00, 0x04, 0xac, 0xd9, + 0xa8, 0xae, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x5184); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 1); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[0x06, 0x67, 0x6f, 0x6f, 0x67, 0x6c, 0x65, 0x03, 0x63, 0x6f, 0x6d, 0x00] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 202); + assert_eq!( + p.answers[0].data, + RecordData::A(Ipv4Address::new(0xac, 0xd9, 0xa8, 0xae)) + ); + } + + #[test] + fn test_parse_response_multiple_a() { + let p = Parsed::parse(&[ + 0x4b, 0x9e, 0x81, 0x80, 0x00, 0x01, 0x00, 0x04, 0x00, 0x00, 0x00, 0x00, 0x09, 0x72, + 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, + 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, + 0x04, 0x0d, 0xe0, 0x77, 0x35, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, 0x00, 0x00, + 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x28, 0xc0, 0x0c, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x43, 0xc0, 0x0c, 0x00, 0x01, 0x00, + 0x01, 0x00, 0x00, 0x00, 0x09, 0x00, 0x04, 0x0d, 0xe0, 0x77, 0x62, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x4b9e); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 4); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[ + 0x09, 0x72, 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, + 0x00 + ] + ); + assert_eq!(p.questions[0].type_, Type::A); + + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 9); + assert_eq!( + p.answers[0].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x35)) + ); + + assert_eq!(p.answers[1].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[1].ttl, 9); + assert_eq!( + p.answers[1].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x28)) + ); + + assert_eq!(p.answers[2].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[2].ttl, 9); + assert_eq!( + p.answers[2].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x43)) + ); + + assert_eq!(p.answers[3].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[3].ttl, 9); + assert_eq!( + p.answers[3].data, + RecordData::A(Ipv4Address::new(0x0d, 0xe0, 0x77, 0x62)) + ); + } + + #[test] + fn test_parse_response_cname() { + let p = Parsed::parse(&[ + 0x78, 0x6c, 0x81, 0x80, 0x00, 0x01, 0x00, 0x02, 0x00, 0x00, 0x00, 0x00, 0x03, 0x77, + 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, 0x63, 0x6f, + 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, 0x0c, 0x00, 0x05, 0x00, 0x01, 0x00, 0x00, + 0x05, 0xf3, 0x00, 0x11, 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, + 0x04, 0x63, 0x31, 0x30, 0x72, 0xc0, 0x10, 0xc0, 0x2e, 0x00, 0x01, 0x00, 0x01, 0x00, + 0x00, 0x00, 0x05, 0x00, 0x04, 0x1f, 0x0d, 0x53, 0x24, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x786c); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NoError); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 2); + assert_eq!(p.packet.authority_record_count(), 0); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!( + p.questions[0].name, + &[ + 0x03, 0x77, 0x77, 0x77, 0x08, 0x66, 0x61, 0x63, 0x65, 0x62, 0x6f, 0x6f, 0x6b, 0x03, + 0x63, 0x6f, 0x6d, 0x00 + ] + ); + assert_eq!(p.questions[0].type_, Type::A); + + // cname + assert_eq!(p.answers[0].name, &[0xc0, 0x0c]); + assert_eq!(p.answers[0].ttl, 1523); + assert_eq!( + p.answers[0].data, + RecordData::Cname(&[ + 0x09, 0x73, 0x74, 0x61, 0x72, 0x2d, 0x6d, 0x69, 0x6e, 0x69, 0x04, 0x63, 0x31, 0x30, + 0x72, 0xc0, 0x10 + ]) + ); + // a + assert_eq!(p.answers[1].name, &[0xc0, 0x2e]); + assert_eq!(p.answers[1].ttl, 5); + assert_eq!( + p.answers[1].data, + RecordData::A(Ipv4Address::new(0x1f, 0x0d, 0x53, 0x24)) + ); + } + + #[test] + fn test_parse_response_nxdomain() { + let p = Parsed::parse(&[ + 0x63, 0xc4, 0x81, 0x83, 0x00, 0x01, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x13, 0x61, + 0x68, 0x61, 0x73, 0x64, 0x67, 0x68, 0x6c, 0x61, 0x6b, 0x73, 0x6a, 0x68, 0x62, 0x61, + 0x61, 0x73, 0x6c, 0x64, 0x03, 0x63, 0x6f, 0x6d, 0x00, 0x00, 0x01, 0x00, 0x01, 0xc0, + 0x20, 0x00, 0x06, 0x00, 0x01, 0x00, 0x00, 0x03, 0x83, 0x00, 0x3d, 0x01, 0x61, 0x0c, + 0x67, 0x74, 0x6c, 0x64, 0x2d, 0x73, 0x65, 0x72, 0x76, 0x65, 0x72, 0x73, 0x03, 0x6e, + 0x65, 0x74, 0x00, 0x05, 0x6e, 0x73, 0x74, 0x6c, 0x64, 0x0c, 0x76, 0x65, 0x72, 0x69, + 0x73, 0x69, 0x67, 0x6e, 0x2d, 0x67, 0x72, 0x73, 0xc0, 0x20, 0x5f, 0xce, 0x8b, 0x85, + 0x00, 0x00, 0x07, 0x08, 0x00, 0x00, 0x03, 0x84, 0x00, 0x09, 0x3a, 0x80, 0x00, 0x01, + 0x51, 0x80, + ]) + .unwrap(); + + assert_eq!(p.packet.transaction_id(), 0x63c4); + assert_eq!( + p.packet.flags(), + Flags::RESPONSE | Flags::RECURSION_DESIRED | Flags::RECURSION_AVAILABLE + ); + assert_eq!(p.packet.opcode(), Opcode::Query); + assert_eq!(p.packet.rcode(), Rcode::NXDomain); + assert_eq!(p.packet.question_count(), 1); + assert_eq!(p.packet.answer_record_count(), 0); + assert_eq!(p.packet.authority_record_count(), 1); + assert_eq!(p.packet.additional_record_count(), 0); + + assert_eq!(p.questions[0].type_, Type::A); + + // SOA authority + assert_eq!(p.authorities[0].name, &[0xc0, 0x20]); // com. + assert_eq!(p.authorities[0].ttl, 899); + assert!(matches!( + p.authorities[0].data, + RecordData::Other(Type::Soa, _) + )); + } + + #[test] + fn test_emit() { + let name = &[ + 0x09, 0x72, 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, + 0x00, + ]; + + let repr = Repr { + transaction_id: 0x1234, + flags: Flags::RECURSION_DESIRED, + opcode: Opcode::Query, + question: Question { + name, + type_: Type::A, + }, + }; + + let mut buf = Vec::new(); + buf.resize(repr.buffer_len(), 0); + repr.emit(&mut Packet::new_unchecked(&mut buf)); + + let want = &[ + 0x12, 0x34, 0x01, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x09, 0x72, + 0x75, 0x73, 0x74, 0x2d, 0x6c, 0x61, 0x6e, 0x67, 0x03, 0x6f, 0x72, 0x67, 0x00, 0x00, + 0x01, 0x00, 0x01, + ]; + assert_eq!(&buf, want); + } +} diff --git a/src/wire/ethernet.rs b/src/wire/ethernet.rs index 6847b62e2..53dc1eacb 100644 --- a/src/wire/ethernet.rs +++ b/src/wire/ethernet.rs @@ -1,7 +1,7 @@ -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; -use {Error, Result}; +use super::{Error, Result}; enum_with_unknown! { /// Ethernet protocol type. @@ -14,17 +14,18 @@ enum_with_unknown! { impl fmt::Display for EtherType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &EtherType::Ipv4 => write!(f, "IPv4"), - &EtherType::Ipv6 => write!(f, "IPv6"), - &EtherType::Arp => write!(f, "ARP"), - &EtherType::Unknown(id) => write!(f, "0x{:04x}", id) + match *self { + EtherType::Ipv4 => write!(f, "IPv4"), + EtherType::Ipv6 => write!(f, "IPv6"), + EtherType::Arp => write!(f, "ARP"), + EtherType::Unknown(id) => write!(f, "0x{id:04x}"), } } } /// A six-octet Ethernet II address. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Address(pub [u8; 6]); impl Address { @@ -42,14 +43,13 @@ impl Address { } /// Return an Ethernet address as a sequence of octets, in big-endian. - pub fn as_bytes(&self) -> &[u8] { + pub const fn as_bytes(&self) -> &[u8] { &self.0 } /// Query whether the address is an unicast address. pub fn is_unicast(&self) -> bool { - !(self.is_broadcast() || - self.is_multicast()) + !(self.is_broadcast() || self.is_multicast()) } /// Query whether this address is the broadcast address. @@ -58,12 +58,12 @@ impl Address { } /// Query whether the "multicast" bit in the OUI is set. - pub fn is_multicast(&self) -> bool { + pub const fn is_multicast(&self) -> bool { self.0[0] & 0x01 != 0 } /// Query whether the "locally administered" bit in the OUI is set. - pub fn is_local(&self) -> bool { + pub const fn is_local(&self) -> bool { self.0[0] & 0x02 != 0 } } @@ -71,29 +71,36 @@ impl Address { impl fmt::Display for Address { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { let bytes = self.0; - write!(f, "{:02x}-{:02x}-{:02x}-{:02x}-{:02x}-{:02x}", - bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5]) + write!( + f, + "{:02x}-{:02x}-{:02x}-{:02x}-{:02x}-{:02x}", + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5] + ) } } /// A read/write wrapper around an Ethernet II frame buffer. #[derive(Debug, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Frame> { - buffer: T + buffer: T, } mod field { - use wire::field::*; + use crate::wire::field::*; - pub const DESTINATION: Field = 0..6; - pub const SOURCE: Field = 6..12; - pub const ETHERTYPE: Field = 12..14; - pub const PAYLOAD: Rest = 14..; + pub const DESTINATION: Field = 0..6; + pub const SOURCE: Field = 6..12; + pub const ETHERTYPE: Field = 12..14; + pub const PAYLOAD: Rest = 14..; } +/// The Ethernet header length +pub const HEADER_LEN: usize = field::PAYLOAD.start; + impl> Frame { /// Imbue a raw octet buffer with Ethernet frame structure. - pub fn new_unchecked(buffer: T) -> Frame { + pub const fn new_unchecked(buffer: T) -> Frame { Frame { buffer } } @@ -108,11 +115,11 @@ impl> Frame { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); - if len < field::PAYLOAD.start { - Err(Error::Truncated) + if len < HEADER_LEN { + Err(Error) } else { Ok(()) } @@ -124,14 +131,14 @@ impl> Frame { } /// Return the length of a frame header. - pub fn header_len() -> usize { - field::PAYLOAD.start + pub const fn header_len() -> usize { + HEADER_LEN } /// Return the length of a buffer required to hold a packet with the payload /// of a given length. - pub fn buffer_len(payload_len: usize) -> usize { - field::PAYLOAD.start + payload_len + pub const fn buffer_len(payload_len: usize) -> usize { + HEADER_LEN + payload_len } /// Return the destination address field. @@ -204,21 +211,29 @@ impl> AsRef<[u8]> for Frame { impl> fmt::Display for Frame { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "EthernetII src={} dst={} type={}", - self.src_addr(), self.dst_addr(), self.ethertype()) + write!( + f, + "EthernetII src={} dst={} type={}", + self.src_addr(), + self.dst_addr(), + self.ethertype() + ) } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Frame { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { let frame = match Frame::new_checked(buffer) { - Err(err) => return write!(f, "{}({})", indent, err), - Ok(frame) => frame + Err(err) => return write!(f, "{indent}({err})"), + Ok(frame) => frame, }; - write!(f, "{}{}", indent, frame)?; + write!(f, "{indent}{frame}")?; match frame.ethertype() { #[cfg(feature = "proto-ipv4")] @@ -236,17 +251,18 @@ impl> PrettyPrint for Frame { indent.increase(f)?; super::Ipv6Packet::<&[u8]>::pretty_print(&frame.payload(), f, indent) } - _ => Ok(()) + _ => Ok(()), } } } /// A high-level representation of an Internet Protocol version 4 packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Repr { - pub src_addr: Address, - pub dst_addr: Address, - pub ethertype: EtherType, + pub src_addr: Address, + pub dst_addr: Address, + pub ethertype: EtherType, } impl Repr { @@ -261,8 +277,8 @@ impl Repr { } /// Return the length of a header that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { - field::PAYLOAD.start + pub const fn buffer_len(&self) -> usize { + HEADER_LEN } /// Emit a high-level representation into an Ethernet II frame. @@ -294,32 +310,32 @@ mod test_ipv4 { // Tests that are valid only with "proto-ipv4" use super::*; - static FRAME_BYTES: [u8; 64] = - [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, - 0x08, 0x00, - 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0xff]; - - static PAYLOAD_BYTES: [u8; 50] = - [0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0xff]; + static FRAME_BYTES: [u8; 64] = [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x08, 0x00, 0xaa, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0xff, + ]; + + static PAYLOAD_BYTES: [u8; 50] = [ + 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0xff, + ]; #[test] fn test_deconstruct() { let frame = Frame::new_unchecked(&FRAME_BYTES[..]); - assert_eq!(frame.dst_addr(), Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06])); - assert_eq!(frame.src_addr(), Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16])); + assert_eq!( + frame.dst_addr(), + Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]) + ); + assert_eq!( + frame.src_addr(), + Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16]) + ); assert_eq!(frame.ethertype(), EtherType::Ipv4); assert_eq!(frame.payload(), &PAYLOAD_BYTES[..]); } @@ -342,28 +358,30 @@ mod test_ipv6 { // Tests that are valid only with "proto-ipv6" use super::*; - static FRAME_BYTES: [u8; 54] = - [0x01, 0x02, 0x03, 0x04, 0x05, 0x06, - 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, - 0x86, 0xdd, - 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]; - - static PAYLOAD_BYTES: [u8; 40] = - [0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]; + static FRAME_BYTES: [u8; 54] = [ + 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x86, 0xdd, 0x60, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]; + + static PAYLOAD_BYTES: [u8; 40] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]; #[test] fn test_deconstruct() { let frame = Frame::new_unchecked(&FRAME_BYTES[..]); - assert_eq!(frame.dst_addr(), Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06])); - assert_eq!(frame.src_addr(), Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16])); + assert_eq!( + frame.dst_addr(), + Address([0x01, 0x02, 0x03, 0x04, 0x05, 0x06]) + ); + assert_eq!( + frame.src_addr(), + Address([0x11, 0x12, 0x13, 0x14, 0x15, 0x16]) + ); assert_eq!(frame.ethertype(), EtherType::Ipv6); assert_eq!(frame.payload(), &PAYLOAD_BYTES[..]); } diff --git a/src/wire/icmp.rs b/src/wire/icmp.rs index 684ea5e29..6bbc574cc 100644 --- a/src/wire/icmp.rs +++ b/src/wire/icmp.rs @@ -1,9 +1,10 @@ #[cfg(feature = "proto-ipv4")] -use super::icmpv4; +use crate::wire::icmpv4; #[cfg(feature = "proto-ipv6")] -use super::icmpv6; +use crate::wire::icmpv6; #[derive(Clone, PartialEq, Eq, Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr<'a> { #[cfg(feature = "proto-ipv4")] Ipv4(icmpv4::Repr<'a>), diff --git a/src/wire/icmpv4.rs b/src/wire/icmpv4.rs index a066bfd8e..60e12153f 100644 --- a/src/wire/icmpv4.rs +++ b/src/wire/icmpv4.rs @@ -1,14 +1,14 @@ -use core::{cmp, fmt}; use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt}; -use {Error, Result}; -use phy::ChecksumCapabilities; -use super::ip::checksum; -use super::{Ipv4Packet, Ipv4Repr}; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{Ipv4Packet, Ipv4Repr}; enum_with_unknown! { /// Internet protocol control message type. - pub doc enum Message(u8) { + pub enum Message(u8) { /// Echo reply EchoReply = 0, /// Destination unreachable @@ -34,25 +34,25 @@ enum_with_unknown! { impl fmt::Display for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Message::EchoReply => write!(f, "echo reply"), - &Message::DstUnreachable => write!(f, "destination unreachable"), - &Message::Redirect => write!(f, "message redirect"), - &Message::EchoRequest => write!(f, "echo request"), - &Message::RouterAdvert => write!(f, "router advertisement"), - &Message::RouterSolicit => write!(f, "router solicitation"), - &Message::TimeExceeded => write!(f, "time exceeded"), - &Message::ParamProblem => write!(f, "parameter problem"), - &Message::Timestamp => write!(f, "timestamp"), - &Message::TimestampReply => write!(f, "timestamp reply"), - &Message::Unknown(id) => write!(f, "{}", id) + match *self { + Message::EchoReply => write!(f, "echo reply"), + Message::DstUnreachable => write!(f, "destination unreachable"), + Message::Redirect => write!(f, "message redirect"), + Message::EchoRequest => write!(f, "echo request"), + Message::RouterAdvert => write!(f, "router advertisement"), + Message::RouterSolicit => write!(f, "router solicitation"), + Message::TimeExceeded => write!(f, "time exceeded"), + Message::ParamProblem => write!(f, "parameter problem"), + Message::Timestamp => write!(f, "timestamp"), + Message::TimestampReply => write!(f, "timestamp reply"), + Message::Unknown(id) => write!(f, "{id}"), } } } enum_with_unknown! { /// Internet protocol control message subtype for type "Destination Unreachable". - pub doc enum DstUnreachable(u8) { + pub enum DstUnreachable(u8) { /// Destination network unreachable NetUnreachable = 0, /// Destination host unreachable @@ -90,48 +90,33 @@ enum_with_unknown! { impl fmt::Display for DstUnreachable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &DstUnreachable::NetUnreachable => - write!(f, "destination network unreachable"), - &DstUnreachable::HostUnreachable => - write!(f, "destination host unreachable"), - &DstUnreachable::ProtoUnreachable => - write!(f, "destination protocol unreachable"), - &DstUnreachable::PortUnreachable => - write!(f, "destination port unreachable"), - &DstUnreachable::FragRequired => - write!(f, "fragmentation required, and DF flag set"), - &DstUnreachable::SrcRouteFailed => - write!(f, "source route failed"), - &DstUnreachable::DstNetUnknown => - write!(f, "destination network unknown"), - &DstUnreachable::DstHostUnknown => - write!(f, "destination host unknown"), - &DstUnreachable::SrcHostIsolated => - write!(f, "source host isolated"), - &DstUnreachable::NetProhibited => - write!(f, "network administratively prohibited"), - &DstUnreachable::HostProhibited => - write!(f, "host administratively prohibited"), - &DstUnreachable::NetUnreachToS => - write!(f, "network unreachable for ToS"), - &DstUnreachable::HostUnreachToS => - write!(f, "host unreachable for ToS"), - &DstUnreachable::CommProhibited => - write!(f, "communication administratively prohibited"), - &DstUnreachable::HostPrecedViol => - write!(f, "host precedence violation"), - &DstUnreachable::PrecedCutoff => - write!(f, "precedence cutoff in effect"), - &DstUnreachable::Unknown(id) => - write!(f, "{}", id) + match *self { + DstUnreachable::NetUnreachable => write!(f, "destination network unreachable"), + DstUnreachable::HostUnreachable => write!(f, "destination host unreachable"), + DstUnreachable::ProtoUnreachable => write!(f, "destination protocol unreachable"), + DstUnreachable::PortUnreachable => write!(f, "destination port unreachable"), + DstUnreachable::FragRequired => write!(f, "fragmentation required, and DF flag set"), + DstUnreachable::SrcRouteFailed => write!(f, "source route failed"), + DstUnreachable::DstNetUnknown => write!(f, "destination network unknown"), + DstUnreachable::DstHostUnknown => write!(f, "destination host unknown"), + DstUnreachable::SrcHostIsolated => write!(f, "source host isolated"), + DstUnreachable::NetProhibited => write!(f, "network administratively prohibited"), + DstUnreachable::HostProhibited => write!(f, "host administratively prohibited"), + DstUnreachable::NetUnreachToS => write!(f, "network unreachable for ToS"), + DstUnreachable::HostUnreachToS => write!(f, "host unreachable for ToS"), + DstUnreachable::CommProhibited => { + write!(f, "communication administratively prohibited") + } + DstUnreachable::HostPrecedViol => write!(f, "host precedence violation"), + DstUnreachable::PrecedCutoff => write!(f, "precedence cutoff in effect"), + DstUnreachable::Unknown(id) => write!(f, "{id}"), } } } enum_with_unknown! { /// Internet protocol control message subtype for type "Redirect Message". - pub doc enum Redirect(u8) { + pub enum Redirect(u8) { /// Redirect Datagram for the Network Net = 0, /// Redirect Datagram for the Host @@ -145,7 +130,7 @@ enum_with_unknown! { enum_with_unknown! { /// Internet protocol control message subtype for type "Time Exceeded". - pub doc enum TimeExceeded(u8) { + pub enum TimeExceeded(u8) { /// TTL expired in transit TtlExpired = 0, /// Fragment reassembly time exceeded @@ -153,9 +138,19 @@ enum_with_unknown! { } } +impl fmt::Display for TimeExceeded { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + TimeExceeded::TtlExpired => write!(f, "time-to-live exceeded in transit"), + TimeExceeded::FragExpired => write!(f, "fragment reassembly time exceeded"), + TimeExceeded::Unknown(id) => write!(f, "{id}"), + } + } +} + enum_with_unknown! { /// Internet protocol control message subtype for type "Parameter Problem". - pub doc enum ParamProblem(u8) { + pub enum ParamProblem(u8) { /// Pointer indicates the error AtPointer = 0, /// Missing a required option @@ -166,19 +161,20 @@ enum_with_unknown! { } /// A read/write wrapper around an Internet Control Message Protocol version 4 packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } mod field { - use wire::field::*; + use crate::wire::field::*; - pub const TYPE: usize = 0; - pub const CODE: usize = 1; - pub const CHECKSUM: Field = 2..4; + pub const TYPE: usize = 0; + pub const CODE: usize = 1; + pub const CHECKSUM: Field = 2..4; - pub const UNUSED: Field = 4..8; + pub const UNUSED: Field = 4..8; pub const ECHO_IDENT: Field = 4..6; pub const ECHO_SEQNO: Field = 6..8; @@ -188,7 +184,7 @@ mod field { impl> Packet { /// Imbue a raw octet buffer with ICMPv4 packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -203,7 +199,7 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_header_len]. /// @@ -211,7 +207,7 @@ impl> Packet { pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::HEADER_END { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -267,10 +263,10 @@ impl> Packet { /// The result depends on the value of the message type field. pub fn header_len(&self) -> usize { match self.msg_type() { - Message::EchoRequest => field::ECHO_SEQNO.end, - Message::EchoReply => field::ECHO_SEQNO.end, + Message::EchoRequest => field::ECHO_SEQNO.end, + Message::EchoReply => field::ECHO_SEQNO.end, Message::DstUnreachable => field::UNUSED.end, - _ => field::UNUSED.end // make a conservative assumption + _ => field::UNUSED.end, // make a conservative assumption } } @@ -279,7 +275,9 @@ impl> Packet { /// # Fuzzing /// This function always returns `true` when fuzzing. pub fn verify_checksum(&self) -> bool { - if cfg!(fuzzing) { return true } + if cfg!(fuzzing) { + return true; + } let data = self.buffer.as_ref(); checksum::data(data) == !0 @@ -366,51 +364,58 @@ impl> AsRef<[u8]> for Packet { /// A high-level representation of an Internet Control Message Protocol version 4 packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub enum Repr<'a> { EchoRequest { - ident: u16, + ident: u16, seq_no: u16, - data: &'a [u8] + data: &'a [u8], }, EchoReply { - ident: u16, + ident: u16, seq_no: u16, - data: &'a [u8] + data: &'a [u8], }, DstUnreachable { reason: DstUnreachable, header: Ipv4Repr, - data: &'a [u8] + data: &'a [u8], + }, + TimeExceeded { + reason: TimeExceeded, + header: Ipv4Repr, + data: &'a [u8], }, - #[doc(hidden)] - __Nonexhaustive } impl<'a> Repr<'a> { /// Parse an Internet Control Message Protocol version 4 packet and return /// a high-level representation. - pub fn parse(packet: &Packet<&'a T>, checksum_caps: &ChecksumCapabilities) - -> Result> - where T: AsRef<[u8]> + ?Sized { + pub fn parse( + packet: &Packet<&'a T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { // Valid checksum is expected. - if checksum_caps.icmpv4.rx() && !packet.verify_checksum() { return Err(Error::Checksum) } + if checksum_caps.icmpv4.rx() && !packet.verify_checksum() { + return Err(Error); + } match (packet.msg_type(), packet.msg_code()) { - (Message::EchoRequest, 0) => { - Ok(Repr::EchoRequest { - ident: packet.echo_ident(), - seq_no: packet.echo_seq_no(), - data: packet.data() - }) - }, - - (Message::EchoReply, 0) => { - Ok(Repr::EchoReply { - ident: packet.echo_ident(), - seq_no: packet.echo_seq_no(), - data: packet.data() - }) - }, + (Message::EchoRequest, 0) => Ok(Repr::EchoRequest { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.data(), + }), + + (Message::EchoReply, 0) => Ok(Repr::EchoReply { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.data(), + }), (Message::DstUnreachable, code) => { let ip_packet = Ipv4Packet::new_checked(packet.data())?; @@ -418,73 +423,124 @@ impl<'a> Repr<'a> { let payload = &packet.data()[ip_packet.header_len() as usize..]; // RFC 792 requires exactly eight bytes to be returned. // We allow more, since there isn't a reason not to, but require at least eight. - if payload.len() < 8 { return Err(Error::Truncated) } + if payload.len() < 8 { + return Err(Error); + } Ok(Repr::DstUnreachable { reason: DstUnreachable::from(code), header: Ipv4Repr { src_addr: ip_packet.src_addr(), dst_addr: ip_packet.dst_addr(), - protocol: ip_packet.protocol(), + next_header: ip_packet.next_header(), payload_len: payload.len(), - hop_limit: ip_packet.hop_limit() + hop_limit: ip_packet.hop_limit(), }, - data: payload + data: payload, }) } - _ => Err(Error::Unrecognized) + + (Message::TimeExceeded, code) => { + let ip_packet = Ipv4Packet::new_checked(packet.data())?; + + let payload = &packet.data()[ip_packet.header_len() as usize..]; + // RFC 792 requires exactly eight bytes to be returned. + // We allow more, since there isn't a reason not to, but require at least eight. + if payload.len() < 8 { + return Err(Error); + } + + Ok(Repr::TimeExceeded { + reason: TimeExceeded::from(code), + header: Ipv4Repr { + src_addr: ip_packet.src_addr(), + dst_addr: ip_packet.dst_addr(), + next_header: ip_packet.next_header(), + payload_len: payload.len(), + hop_limit: ip_packet.hop_limit(), + }, + data: payload, + }) + } + + _ => Err(Error), } } /// Return the length of a packet that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { match self { - &Repr::EchoRequest { data, .. } | - &Repr::EchoReply { data, .. } => { + &Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => { field::ECHO_SEQNO.end + data.len() - }, - &Repr::DstUnreachable { header, data, .. } => { + } + &Repr::DstUnreachable { header, data, .. } + | &Repr::TimeExceeded { header, data, .. } => { field::UNUSED.end + header.buffer_len() + data.len() } - &Repr::__Nonexhaustive => unreachable!() } } /// Emit a high-level representation into an Internet Control Message Protocol version 4 /// packet. pub fn emit(&self, packet: &mut Packet<&mut T>, checksum_caps: &ChecksumCapabilities) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { packet.set_msg_code(0); - match self { - &Repr::EchoRequest { ident, seq_no, data } => { + match *self { + Repr::EchoRequest { + ident, + seq_no, + data, + } => { packet.set_msg_type(Message::EchoRequest); packet.set_msg_code(0); packet.set_echo_ident(ident); packet.set_echo_seq_no(seq_no); let data_len = cmp::min(packet.data_mut().len(), data.len()); packet.data_mut()[..data_len].copy_from_slice(&data[..data_len]) - }, + } - &Repr::EchoReply { ident, seq_no, data } => { + Repr::EchoReply { + ident, + seq_no, + data, + } => { packet.set_msg_type(Message::EchoReply); packet.set_msg_code(0); packet.set_echo_ident(ident); packet.set_echo_seq_no(seq_no); let data_len = cmp::min(packet.data_mut().len(), data.len()); packet.data_mut()[..data_len].copy_from_slice(&data[..data_len]) - }, + } - &Repr::DstUnreachable { reason, header, data } => { + Repr::DstUnreachable { + reason, + header, + data, + } => { packet.set_msg_type(Message::DstUnreachable); packet.set_msg_code(reason.into()); let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut()); header.emit(&mut ip_packet, checksum_caps); let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; - payload.copy_from_slice(&data[..]) + payload.copy_from_slice(data) } - &Repr::__Nonexhaustive => unreachable!() + Repr::TimeExceeded { + reason, + header, + data, + } => { + packet.set_msg_type(Message::TimeExceeded); + packet.set_msg_code(reason.into()); + + let mut ip_packet = Ipv4Packet::new_unchecked(packet.data_mut()); + header.emit(&mut ip_packet, checksum_caps); + let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; + payload.copy_from_slice(data) + } } if checksum_caps.icmpv4.tx() { @@ -500,14 +556,18 @@ impl<'a> Repr<'a> { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self, &ChecksumCapabilities::default()) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "ICMPv4 ({})", err)?; + write!(f, "ICMPv4 ({err})")?; write!(f, " type={:?}", self.msg_type())?; match self.msg_type() { - Message::DstUnreachable => - write!(f, " code={:?}", DstUnreachable::from(self.msg_code())), - _ => write!(f, " code={}", self.msg_code()) + Message::DstUnreachable => { + write!(f, " code={:?}", DstUnreachable::from(self.msg_code())) + } + Message::TimeExceeded => { + write!(f, " code={:?}", TimeExceeded::from(self.msg_code())) + } + _ => write!(f, " code={}", self.msg_code()), } } } @@ -516,38 +576,59 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { impl<'a> fmt::Display for Repr<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Repr::EchoRequest { ident, seq_no, data } => - write!(f, "ICMPv4 echo request id={} seq={} len={}", - ident, seq_no, data.len()), - &Repr::EchoReply { ident, seq_no, data } => - write!(f, "ICMPv4 echo reply id={} seq={} len={}", - ident, seq_no, data.len()), - &Repr::DstUnreachable { reason, .. } => - write!(f, "ICMPv4 destination unreachable ({})", - reason), - &Repr::__Nonexhaustive => unreachable!() + match *self { + Repr::EchoRequest { + ident, + seq_no, + data, + } => write!( + f, + "ICMPv4 echo request id={} seq={} len={}", + ident, + seq_no, + data.len() + ), + Repr::EchoReply { + ident, + seq_no, + data, + } => write!( + f, + "ICMPv4 echo reply id={} seq={} len={}", + ident, + seq_no, + data.len() + ), + Repr::DstUnreachable { reason, .. } => { + write!(f, "ICMPv4 destination unreachable ({reason})") + } + Repr::TimeExceeded { reason, .. } => { + write!(f, "ICMPv4 time exceeded ({reason})") + } } } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { let packet = match Packet::new_checked(buffer) { - Err(err) => return write!(f, "{}({})", indent, err), - Ok(packet) => packet + Err(err) => return write!(f, "{indent}({err})"), + Ok(packet) => packet, }; - write!(f, "{}{}", indent, packet)?; + write!(f, "{indent}{packet}")?; match packet.msg_type() { - Message::DstUnreachable => { + Message::DstUnreachable | Message::TimeExceeded => { indent.increase(f)?; super::Ipv4Packet::<&[u8]>::pretty_print(&packet.data(), f, indent) } - _ => Ok(()) + _ => Ok(()), } } } @@ -556,13 +637,11 @@ impl> PrettyPrint for Packet { mod test { use super::*; - static ECHO_PACKET_BYTES: [u8; 12] = - [0x08, 0x00, 0x8e, 0xfe, - 0x12, 0x34, 0xab, 0xcd, - 0xaa, 0x00, 0x00, 0xff]; + static ECHO_PACKET_BYTES: [u8; 12] = [ + 0x08, 0x00, 0x8e, 0xfe, 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff, + ]; - static ECHO_DATA_BYTES: [u8; 4] = - [0xaa, 0x00, 0x00, 0xff]; + static ECHO_DATA_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; #[test] fn test_echo_deconstruct() { @@ -573,7 +652,7 @@ mod test { assert_eq!(packet.echo_ident(), 0x1234); assert_eq!(packet.echo_seq_no(), 0xabcd); assert_eq!(packet.data(), &ECHO_DATA_BYTES[..]); - assert_eq!(packet.verify_checksum(), true); + assert!(packet.verify_checksum()); } #[test] @@ -593,7 +672,7 @@ mod test { Repr::EchoRequest { ident: 0x1234, seq_no: 0xabcd, - data: &ECHO_DATA_BYTES + data: &ECHO_DATA_BYTES, } } @@ -615,10 +694,9 @@ mod test { #[test] fn test_check_len() { - let bytes = [0x0b, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00]; - assert_eq!(Packet::new_checked(&[]), Err(Error::Truncated)); - assert_eq!(Packet::new_checked(&bytes[..4]), Err(Error::Truncated)); + let bytes = [0x0b, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert_eq!(Packet::new_checked(&[]), Err(Error)); + assert_eq!(Packet::new_checked(&bytes[..4]), Err(Error)); assert!(Packet::new_checked(&bytes[..]).is_ok()); } } diff --git a/src/wire/icmpv6.rs b/src/wire/icmpv6.rs index fd180630c..d7130aa6a 100644 --- a/src/wire/icmpv6.rs +++ b/src/wire/icmpv6.rs @@ -1,17 +1,19 @@ -use core::{cmp, fmt}; use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt}; -use {Error, Result}; -use phy::ChecksumCapabilities; -use super::ip::checksum; -use super::{IpAddress, IpProtocol, Ipv6Packet, Ipv6Repr}; -use super::MldRepr; -#[cfg(feature = "ethernet")] -use super::NdiscRepr; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::MldRepr; +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +use crate::wire::NdiscRepr; +#[cfg(feature = "proto-rpl")] +use crate::wire::RplRepr; +use crate::wire::{IpAddress, IpProtocol, Ipv6Packet, Ipv6Repr}; enum_with_unknown! { /// Internet protocol control message type. - pub doc enum Message(u8) { + pub enum Message(u8) { /// Destination Unreachable. DstUnreachable = 0x01, /// Packet Too Big. @@ -37,7 +39,9 @@ enum_with_unknown! { /// Redirect Redirect = 0x89, /// Multicast Listener Report - MldReport = 0x8f + MldReport = 0x8f, + /// RPL Control Message + RplControl = 0x9b, } } @@ -55,10 +59,13 @@ impl Message { /// is an [NDISC] message type. /// /// [NDISC]: https://tools.ietf.org/html/rfc4861 - pub fn is_ndisc(&self) -> bool { + pub const fn is_ndisc(&self) -> bool { match *self { - Message::RouterSolicit | Message::RouterAdvert | Message::NeighborSolicit | - Message::NeighborAdvert | Message::Redirect => true, + Message::RouterSolicit + | Message::RouterAdvert + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect => true, _ => false, } } @@ -67,7 +74,7 @@ impl Message { /// is an [MLD] message type. /// /// [MLD]: https://tools.ietf.org/html/rfc3810 - pub fn is_mld(&self) -> bool { + pub const fn is_mld(&self) -> bool { match *self { Message::MldQuery | Message::MldReport => true, _ => false, @@ -77,28 +84,29 @@ impl Message { impl fmt::Display for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Message::DstUnreachable => write!(f, "destination unreachable"), - &Message::PktTooBig => write!(f, "packet too big"), - &Message::TimeExceeded => write!(f, "time exceeded"), - &Message::ParamProblem => write!(f, "parameter problem"), - &Message::EchoReply => write!(f, "echo reply"), - &Message::EchoRequest => write!(f, "echo request"), - &Message::RouterSolicit => write!(f, "router solicitation"), - &Message::RouterAdvert => write!(f, "router advertisement"), - &Message::NeighborSolicit => write!(f, "neighbor solicitation"), - &Message::NeighborAdvert => write!(f, "neighbor advert"), - &Message::Redirect => write!(f, "redirect"), - &Message::MldQuery => write!(f, "multicast listener query"), - &Message::MldReport => write!(f, "multicast listener report"), - &Message::Unknown(id) => write!(f, "{}", id) + match *self { + Message::DstUnreachable => write!(f, "destination unreachable"), + Message::PktTooBig => write!(f, "packet too big"), + Message::TimeExceeded => write!(f, "time exceeded"), + Message::ParamProblem => write!(f, "parameter problem"), + Message::EchoReply => write!(f, "echo reply"), + Message::EchoRequest => write!(f, "echo request"), + Message::RouterSolicit => write!(f, "router solicitation"), + Message::RouterAdvert => write!(f, "router advertisement"), + Message::NeighborSolicit => write!(f, "neighbor solicitation"), + Message::NeighborAdvert => write!(f, "neighbor advert"), + Message::Redirect => write!(f, "redirect"), + Message::MldQuery => write!(f, "multicast listener query"), + Message::MldReport => write!(f, "multicast listener report"), + Message::RplControl => write!(f, "RPL control message"), + Message::Unknown(id) => write!(f, "{id}"), } } } enum_with_unknown! { /// Internet protocol control message subtype for type "Destination Unreachable". - pub doc enum DstUnreachable(u8) { + pub enum DstUnreachable(u8) { /// No Route to destination. NoRoute = 0, /// Communication with destination administratively prohibited. @@ -118,30 +126,27 @@ enum_with_unknown! { impl fmt::Display for DstUnreachable { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &DstUnreachable::NoRoute => - write!(f, "no route to destination"), - &DstUnreachable::AdminProhibit => - write!(f, "communication with destination administratively prohibited"), - &DstUnreachable::BeyondScope => - write!(f, "beyond scope of source address"), - &DstUnreachable::AddrUnreachable => - write!(f, "address unreachable"), - &DstUnreachable::PortUnreachable => - write!(f, "port unreachable"), - &DstUnreachable::FailedPolicy => - write!(f, "source address failed ingress/egress policy"), - &DstUnreachable::RejectRoute => - write!(f, "reject route to destination"), - &DstUnreachable::Unknown(id) => - write!(f, "{}", id) + match *self { + DstUnreachable::NoRoute => write!(f, "no route to destination"), + DstUnreachable::AdminProhibit => write!( + f, + "communication with destination administratively prohibited" + ), + DstUnreachable::BeyondScope => write!(f, "beyond scope of source address"), + DstUnreachable::AddrUnreachable => write!(f, "address unreachable"), + DstUnreachable::PortUnreachable => write!(f, "port unreachable"), + DstUnreachable::FailedPolicy => { + write!(f, "source address failed ingress/egress policy") + } + DstUnreachable::RejectRoute => write!(f, "reject route to destination"), + DstUnreachable::Unknown(id) => write!(f, "{id}"), } } } enum_with_unknown! { /// Internet protocol control message subtype for the type "Parameter Problem". - pub doc enum ParamProblem(u8) { + pub enum ParamProblem(u8) { /// Erroneous header field encountered. ErroneousHdrField = 0, /// Unrecognized Next Header type encountered. @@ -153,22 +158,18 @@ enum_with_unknown! { impl fmt::Display for ParamProblem { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &ParamProblem::ErroneousHdrField => - write!(f, "erroneous header field."), - &ParamProblem::UnrecognizedNxtHdr => - write!(f, "unrecognized next header type."), - &ParamProblem::UnrecognizedOption => - write!(f, "unrecognized IPv6 option."), - &ParamProblem::Unknown(id) => - write!(f, "{}", id) + match *self { + ParamProblem::ErroneousHdrField => write!(f, "erroneous header field."), + ParamProblem::UnrecognizedNxtHdr => write!(f, "unrecognized next header type."), + ParamProblem::UnrecognizedOption => write!(f, "unrecognized IPv6 option."), + ParamProblem::Unknown(id) => write!(f, "{id}"), } } } enum_with_unknown! { /// Internet protocol control message subtype for the type "Time Exceeded". - pub doc enum TimeExceeded(u8) { + pub enum TimeExceeded(u8) { /// Hop limit exceeded in transit. HopLimitExceeded = 0, /// Fragment reassembly time exceeded. @@ -178,82 +179,80 @@ enum_with_unknown! { impl fmt::Display for TimeExceeded { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &TimeExceeded::HopLimitExceeded => - write!(f, "hop limit exceeded in transit"), - &TimeExceeded::FragReassemExceeded => - write!(f, "fragment reassembly time exceeded"), - &TimeExceeded::Unknown(id) => - write!(f, "{}", id) + match *self { + TimeExceeded::HopLimitExceeded => write!(f, "hop limit exceeded in transit"), + TimeExceeded::FragReassemExceeded => write!(f, "fragment reassembly time exceeded"), + TimeExceeded::Unknown(id) => write!(f, "{id}"), } } } /// A read/write wrapper around an Internet Control Message Protocol version 6 packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - pub(super) buffer: T + pub(super) buffer: T, } // Ranges and constants describing key boundaries in the ICMPv6 header. pub(super) mod field { - use wire::field::*; + use crate::wire::field::*; // ICMPv6: See https://tools.ietf.org/html/rfc4443 - pub const TYPE: usize = 0; - pub const CODE: usize = 1; - pub const CHECKSUM: Field = 2..4; + pub const TYPE: usize = 0; + pub const CODE: usize = 1; + pub const CHECKSUM: Field = 2..4; - pub const UNUSED: Field = 4..8; - pub const MTU: Field = 4..8; - pub const POINTER: Field = 4..8; - pub const ECHO_IDENT: Field = 4..6; - pub const ECHO_SEQNO: Field = 6..8; + pub const UNUSED: Field = 4..8; + pub const MTU: Field = 4..8; + pub const POINTER: Field = 4..8; + pub const ECHO_IDENT: Field = 4..6; + pub const ECHO_SEQNO: Field = 6..8; - pub const HEADER_END: usize = 8; + pub const HEADER_END: usize = 8; // NDISC: See https://tools.ietf.org/html/rfc4861 // Router Advertisement message offsets - pub const CUR_HOP_LIMIT: usize = 4; - pub const ROUTER_FLAGS: usize = 5; - pub const ROUTER_LT: Field = 6..8; - pub const REACHABLE_TM: Field = 8..12; - pub const RETRANS_TM: Field = 12..16; + pub const CUR_HOP_LIMIT: usize = 4; + pub const ROUTER_FLAGS: usize = 5; + pub const ROUTER_LT: Field = 6..8; + pub const REACHABLE_TM: Field = 8..12; + pub const RETRANS_TM: Field = 12..16; // Neighbor Solicitation message offsets - pub const TARGET_ADDR: Field = 8..24; + pub const TARGET_ADDR: Field = 8..24; // Neighbor Advertisement message offsets - pub const NEIGH_FLAGS: usize = 4; + pub const NEIGH_FLAGS: usize = 4; // Redirected Header message offsets - pub const DEST_ADDR: Field = 24..40; + pub const DEST_ADDR: Field = 24..40; // MLD: // - https://tools.ietf.org/html/rfc3810 // - https://tools.ietf.org/html/rfc3810 // Multicast Listener Query message - pub const MAX_RESP_CODE: Field = 4..6; - pub const QUERY_RESV: Field = 6..8; - pub const QUERY_MCAST_ADDR: Field = 8..24; - pub const SQRV: usize = 24; - pub const QQIC: usize = 25; - pub const QUERY_NUM_SRCS: Field = 26..28; + pub const MAX_RESP_CODE: Field = 4..6; + pub const QUERY_RESV: Field = 6..8; + pub const QUERY_MCAST_ADDR: Field = 8..24; + pub const SQRV: usize = 24; + pub const QQIC: usize = 25; + pub const QUERY_NUM_SRCS: Field = 26..28; // Multicast Listener Report Message - pub const RECORD_RESV: Field = 4..6; - pub const NR_MCAST_RCRDS: Field = 6..8; + pub const RECORD_RESV: Field = 4..6; + pub const NR_MCAST_RCRDS: Field = 6..8; // Multicast Address Record Offsets - pub const RECORD_TYPE: usize = 0; - pub const AUX_DATA_LEN: usize = 1; - pub const RECORD_NUM_SRCS: Field = 2..4; + pub const RECORD_TYPE: usize = 0; + pub const AUX_DATA_LEN: usize = 1; + pub const RECORD_NUM_SRCS: Field = 2..4; pub const RECORD_MCAST_ADDR: Field = 4..20; } impl> Packet { /// Imbue a raw octet buffer with ICMPv6 packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -268,18 +267,71 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); - if len < field::HEADER_END { - Err(Error::Truncated) - } else { - if len < self.header_len() { - Err(Error::Truncated) - } else { - Ok(()) + + if len < 4 { + return Err(Error); + } + + match self.msg_type() { + Message::DstUnreachable + | Message::PktTooBig + | Message::TimeExceeded + | Message::ParamProblem + | Message::EchoRequest + | Message::EchoReply + | Message::MldQuery + | Message::RouterSolicit + | Message::RouterAdvert + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect + | Message::MldReport => { + if len < field::HEADER_END || len < self.header_len() { + return Err(Error); + } } + #[cfg(feature = "proto-rpl")] + Message::RplControl => match super::rpl::RplControlMessage::from(self.msg_code()) { + super::rpl::RplControlMessage::DodagInformationSolicitation => { + // TODO(thvdveld): replace magic number + if len < 6 { + return Err(Error); + } + } + super::rpl::RplControlMessage::DodagInformationObject => { + // TODO(thvdveld): replace magic number + if len < 28 { + return Err(Error); + } + } + super::rpl::RplControlMessage::DestinationAdvertisementObject => { + // TODO(thvdveld): replace magic number + if len < 8 || (self.dao_dodag_id_present() && len < 24) { + return Err(Error); + } + } + super::rpl::RplControlMessage::DestinationAdvertisementObjectAck => { + // TODO(thvdveld): replace magic number + if len < 8 || (self.dao_dodag_id_present() && len < 24) { + return Err(Error); + } + } + super::rpl::RplControlMessage::SecureDodagInformationSolicitation + | super::rpl::RplControlMessage::SecureDodagInformationObject + | super::rpl::RplControlMessage::SecureDesintationAdvertismentObject + | super::rpl::RplControlMessage::SecureDestinationAdvertisementObjectAck + | super::rpl::RplControlMessage::ConsistencyCheck => return Err(Error), + super::rpl::RplControlMessage::Unknown(_) => return Err(Error), + }, + #[cfg(not(feature = "proto-rpl"))] + Message::RplControl => return Err(Error), + Message::Unknown(_) => return Err(Error), } + + Ok(()) } /// Consume the packet, returning the underlying buffer. @@ -336,29 +388,28 @@ impl> Packet { NetworkEndian::read_u32(&data[field::POINTER]) } - /// Return the header length. The result depends on the value of /// the message type field. pub fn header_len(&self) -> usize { match self.msg_type() { - Message::DstUnreachable => field::UNUSED.end, - Message::PktTooBig => field::MTU.end, - Message::TimeExceeded => field::UNUSED.end, - Message::ParamProblem => field::POINTER.end, - Message::EchoRequest => field::ECHO_SEQNO.end, - Message::EchoReply => field::ECHO_SEQNO.end, - Message::RouterSolicit => field::UNUSED.end, - Message::RouterAdvert => field::RETRANS_TM.end, + Message::DstUnreachable => field::UNUSED.end, + Message::PktTooBig => field::MTU.end, + Message::TimeExceeded => field::UNUSED.end, + Message::ParamProblem => field::POINTER.end, + Message::EchoRequest => field::ECHO_SEQNO.end, + Message::EchoReply => field::ECHO_SEQNO.end, + Message::RouterSolicit => field::UNUSED.end, + Message::RouterAdvert => field::RETRANS_TM.end, Message::NeighborSolicit => field::TARGET_ADDR.end, - Message::NeighborAdvert => field::TARGET_ADDR.end, - Message::Redirect => field::DEST_ADDR.end, - Message::MldQuery => field::QUERY_NUM_SRCS.end, - Message::MldReport => field::NR_MCAST_RCRDS.end, + Message::NeighborAdvert => field::TARGET_ADDR.end, + Message::Redirect => field::DEST_ADDR.end, + Message::MldQuery => field::QUERY_NUM_SRCS.end, + Message::MldReport => field::NR_MCAST_RCRDS.end, // For packets that are not included in RFC 4443, do not // include the last 32 bits of the ICMPv6 header in // `header_bytes`. This must be done so that these bytes // can be accessed in the `payload`. - _ => field::CHECKSUM.end + _ => field::CHECKSUM.end, } } @@ -367,13 +418,14 @@ impl> Packet { /// # Fuzzing /// This function always returns `true` when fuzzing. pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { - if cfg!(fuzzing) { return true } + if cfg!(fuzzing) { + return true; + } let data = self.buffer.as_ref(); checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, - data.len() as u32), - checksum::data(data) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, data.len() as u32), + checksum::data(data), ]) == !0 } } @@ -412,21 +464,23 @@ impl + AsMut<[u8]>> Packet { #[inline] pub fn clear_reserved(&mut self) { match self.msg_type() { - Message::RouterSolicit | Message::NeighborSolicit | - Message::NeighborAdvert | Message::Redirect => { + Message::RouterSolicit + | Message::NeighborSolicit + | Message::NeighborAdvert + | Message::Redirect => { let data = self.buffer.as_mut(); NetworkEndian::write_u32(&mut data[field::UNUSED], 0); - }, + } Message::MldQuery => { let data = self.buffer.as_mut(); NetworkEndian::write_u16(&mut data[field::QUERY_RESV], 0); - data[field::SQRV] = data[field::SQRV] & 0xf; - }, + data[field::SQRV] &= 0xf; + } Message::MldReport => { let data = self.buffer.as_mut(); NetworkEndian::write_u16(&mut data[field::RECORD_RESV], 0); } - ty => panic!("Message type `{}` does not have any reserved fields.", ty), + ty => panic!("Message type `{ty}` does not have any reserved fields."), } } @@ -482,9 +536,8 @@ impl + AsMut<[u8]>> Packet { let checksum = { let data = self.buffer.as_ref(); !checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, - data.len() as u32), - checksum::data(data) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Icmpv6, data.len() as u32), + checksum::data(data), ]) }; self.set_checksum(checksum) @@ -507,71 +560,81 @@ impl> AsRef<[u8]> for Packet { /// A high-level representation of an Internet Control Message Protocol version 6 packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub enum Repr<'a> { DstUnreachable { reason: DstUnreachable, header: Ipv6Repr, - data: &'a [u8] + data: &'a [u8], }, PktTooBig { mtu: u32, header: Ipv6Repr, - data: &'a [u8] + data: &'a [u8], }, TimeExceeded { reason: TimeExceeded, header: Ipv6Repr, - data: &'a [u8] + data: &'a [u8], }, ParamProblem { - reason: ParamProblem, + reason: ParamProblem, pointer: u32, - header: Ipv6Repr, - data: &'a [u8] + header: Ipv6Repr, + data: &'a [u8], }, EchoRequest { - ident: u16, + ident: u16, seq_no: u16, - data: &'a [u8] + data: &'a [u8], }, EchoReply { - ident: u16, + ident: u16, seq_no: u16, - data: &'a [u8] + data: &'a [u8], }, - #[cfg(feature = "ethernet")] + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] Ndisc(NdiscRepr<'a>), Mld(MldRepr<'a>), - #[doc(hidden)] - __Nonexhaustive + #[cfg(feature = "proto-rpl")] + Rpl(RplRepr<'a>), } impl<'a> Repr<'a> { /// Parse an Internet Control Message Protocol version 6 packet and return /// a high-level representation. - pub fn parse(src_addr: &IpAddress, dst_addr: &IpAddress, - packet: &Packet<&'a T>, checksum_caps: &ChecksumCapabilities) - -> Result> - where T: AsRef<[u8]> + ?Sized { - fn create_packet_from_payload<'a, T>(packet: &Packet<&'a T>) - -> Result<(&'a [u8], Ipv6Repr)> - where T: AsRef<[u8]> + ?Sized { + pub fn parse( + src_addr: &IpAddress, + dst_addr: &IpAddress, + packet: &Packet<&'a T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { + fn create_packet_from_payload<'a, T>(packet: &Packet<&'a T>) -> Result<(&'a [u8], Ipv6Repr)> + where + T: AsRef<[u8]> + ?Sized, + { let ip_packet = Ipv6Packet::new_checked(packet.payload())?; - let payload = &packet.payload()[ip_packet.header_len() as usize..]; - if payload.len() < 8 { return Err(Error::Truncated) } + let payload = &packet.payload()[ip_packet.header_len()..]; + if payload.len() < 8 { + return Err(Error); + } let repr = Ipv6Repr { src_addr: ip_packet.src_addr(), dst_addr: ip_packet.dst_addr(), next_header: ip_packet.next_header(), payload_len: payload.len(), - hop_limit: ip_packet.hop_limit() + hop_limit: ip_packet.hop_limit(), }; Ok((payload, repr)) } // Valid checksum is expected. if checksum_caps.icmpv6.rx() && !packet.verify_checksum(src_addr, dst_addr) { - return Err(Error::Checksum) + return Err(Error); } match (packet.msg_type(), packet.msg_code()) { @@ -580,152 +643,168 @@ impl<'a> Repr<'a> { Ok(Repr::DstUnreachable { reason: DstUnreachable::from(code), header: repr, - data: payload + data: payload, }) - }, + } (Message::PktTooBig, 0) => { let (payload, repr) = create_packet_from_payload(packet)?; Ok(Repr::PktTooBig { mtu: packet.pkt_too_big_mtu(), header: repr, - data: payload + data: payload, }) - }, + } (Message::TimeExceeded, code) => { let (payload, repr) = create_packet_from_payload(packet)?; Ok(Repr::TimeExceeded { reason: TimeExceeded::from(code), header: repr, - data: payload + data: payload, }) - }, + } (Message::ParamProblem, code) => { let (payload, repr) = create_packet_from_payload(packet)?; Ok(Repr::ParamProblem { reason: ParamProblem::from(code), pointer: packet.param_problem_ptr(), header: repr, - data: payload + data: payload, }) - }, - (Message::EchoRequest, 0) => { - Ok(Repr::EchoRequest { - ident: packet.echo_ident(), - seq_no: packet.echo_seq_no(), - data: packet.payload() - }) - }, - (Message::EchoReply, 0) => { - Ok(Repr::EchoReply { - ident: packet.echo_ident(), - seq_no: packet.echo_seq_no(), - data: packet.payload() - }) - }, - #[cfg(feature = "ethernet")] - (msg_type, 0) if msg_type.is_ndisc() => { - NdiscRepr::parse(packet).map(|repr| Repr::Ndisc(repr)) - }, - (msg_type, 0) if msg_type.is_mld() => { - MldRepr::parse(packet).map(|repr| Repr::Mld(repr)) - }, - _ => Err(Error::Unrecognized) + } + (Message::EchoRequest, 0) => Ok(Repr::EchoRequest { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.payload(), + }), + (Message::EchoReply, 0) => Ok(Repr::EchoReply { + ident: packet.echo_ident(), + seq_no: packet.echo_seq_no(), + data: packet.payload(), + }), + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + (msg_type, 0) if msg_type.is_ndisc() => NdiscRepr::parse(packet).map(Repr::Ndisc), + (msg_type, 0) if msg_type.is_mld() => MldRepr::parse(packet).map(Repr::Mld), + #[cfg(feature = "proto-rpl")] + (Message::RplControl, _) => RplRepr::parse(packet).map(Repr::Rpl), + _ => Err(Error), } } /// Return the length of a packet that will be emitted from this high-level representation. pub fn buffer_len(&self) -> usize { match self { - &Repr::DstUnreachable { header, data, .. } | &Repr::PktTooBig { header, data, .. } | - &Repr::TimeExceeded { header, data, .. } | &Repr::ParamProblem { header, data, .. } => { + &Repr::DstUnreachable { header, data, .. } + | &Repr::PktTooBig { header, data, .. } + | &Repr::TimeExceeded { header, data, .. } + | &Repr::ParamProblem { header, data, .. } => { field::UNUSED.end + header.buffer_len() + data.len() } - &Repr::EchoRequest { data, .. } | - &Repr::EchoReply { data, .. } => { + &Repr::EchoRequest { data, .. } | &Repr::EchoReply { data, .. } => { field::ECHO_SEQNO.end + data.len() - }, - #[cfg(feature = "ethernet")] - &Repr::Ndisc(ndisc) => { - ndisc.buffer_len() - }, - &Repr::Mld(mld) => { - mld.buffer_len() - }, - &Repr::__Nonexhaustive => unreachable!() + } + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + &Repr::Ndisc(ndisc) => ndisc.buffer_len(), + &Repr::Mld(mld) => mld.buffer_len(), + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => rpl.buffer_len(), } } /// Emit a high-level representation into an Internet Control Message Protocol version 6 /// packet. - pub fn emit(&self, src_addr: &IpAddress, dst_addr: &IpAddress, - packet: &mut Packet<&mut T>, checksum_caps: &ChecksumCapabilities) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { + pub fn emit( + &self, + src_addr: &IpAddress, + dst_addr: &IpAddress, + packet: &mut Packet<&mut T>, + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { fn emit_contained_packet(buffer: &mut [u8], header: Ipv6Repr, data: &[u8]) { let mut ip_packet = Ipv6Packet::new_unchecked(buffer); header.emit(&mut ip_packet); let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; - payload.copy_from_slice(&data[..]); + payload.copy_from_slice(data); } - match self { - &Repr::DstUnreachable { reason, header, data } => { + match *self { + Repr::DstUnreachable { + reason, + header, + data, + } => { packet.set_msg_type(Message::DstUnreachable); packet.set_msg_code(reason.into()); - emit_contained_packet(packet.payload_mut(), header, &data); - }, + emit_contained_packet(packet.payload_mut(), header, data); + } - &Repr::PktTooBig { mtu, header, data } => { + Repr::PktTooBig { mtu, header, data } => { packet.set_msg_type(Message::PktTooBig); packet.set_msg_code(0); packet.set_pkt_too_big_mtu(mtu); - emit_contained_packet(packet.payload_mut(), header, &data); - }, + emit_contained_packet(packet.payload_mut(), header, data); + } - &Repr::TimeExceeded { reason, header, data } => { + Repr::TimeExceeded { + reason, + header, + data, + } => { packet.set_msg_type(Message::TimeExceeded); packet.set_msg_code(reason.into()); - emit_contained_packet(packet.payload_mut(), header, &data); - }, + emit_contained_packet(packet.payload_mut(), header, data); + } - &Repr::ParamProblem { reason, pointer, header, data } => { + Repr::ParamProblem { + reason, + pointer, + header, + data, + } => { packet.set_msg_type(Message::ParamProblem); packet.set_msg_code(reason.into()); packet.set_param_problem_ptr(pointer); - emit_contained_packet(packet.payload_mut(), header, &data); - }, + emit_contained_packet(packet.payload_mut(), header, data); + } - &Repr::EchoRequest { ident, seq_no, data } => { + Repr::EchoRequest { + ident, + seq_no, + data, + } => { packet.set_msg_type(Message::EchoRequest); packet.set_msg_code(0); packet.set_echo_ident(ident); packet.set_echo_seq_no(seq_no); let data_len = cmp::min(packet.payload_mut().len(), data.len()); packet.payload_mut()[..data_len].copy_from_slice(&data[..data_len]) - }, + } - &Repr::EchoReply { ident, seq_no, data } => { + Repr::EchoReply { + ident, + seq_no, + data, + } => { packet.set_msg_type(Message::EchoReply); packet.set_msg_code(0); packet.set_echo_ident(ident); packet.set_echo_seq_no(seq_no); let data_len = cmp::min(packet.payload_mut().len(), data.len()); packet.payload_mut()[..data_len].copy_from_slice(&data[..data_len]) - }, + } - #[cfg(feature = "ethernet")] - &Repr::Ndisc(ndisc) => { - ndisc.emit(packet) - }, + #[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] + Repr::Ndisc(ndisc) => ndisc.emit(packet), - &Repr::Mld(mld) => { - mld.emit(packet) - }, + Repr::Mld(mld) => mld.emit(packet), - &Repr::__Nonexhaustive => unreachable!(), + #[cfg(feature = "proto-rpl")] + Repr::Rpl(ref rpl) => rpl.emit(packet), } if checksum_caps.icmpv6.tx() { @@ -739,60 +818,39 @@ impl<'a> Repr<'a> { #[cfg(test)] mod test { - use wire::{Ipv6Address, Ipv6Repr, IpProtocol}; - use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; use super::*; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; + use crate::wire::{IpProtocol, Ipv6Address, Ipv6Repr}; + + static ECHO_PACKET_BYTES: [u8; 12] = [ + 0x80, 0x00, 0x19, 0xb3, 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff, + ]; + + static ECHO_PACKET_PAYLOAD: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; + + static PKT_TOO_BIG_BYTES: [u8; 60] = [ + 0x02, 0x00, 0x0f, 0xc9, 0x00, 0x00, 0x05, 0xdc, 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, + 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; + + static PKT_TOO_BIG_IP_PAYLOAD: [u8; 52] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xbf, 0x00, 0x00, 0x35, 0x00, + 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; - static ECHO_PACKET_BYTES: [u8; 12] = - [0x80, 0x00, 0x19, 0xb3, - 0x12, 0x34, 0xab, 0xcd, - 0xaa, 0x00, 0x00, 0xff]; - - static ECHO_PACKET_PAYLOAD: [u8; 4] = - [0xaa, 0x00, 0x00, 0xff]; - - static PKT_TOO_BIG_BYTES: [u8; 60] = - [0x02, 0x00, 0x0f, 0xc9, - 0x00, 0x00, 0x05, 0xdc, - 0x60, 0x00, 0x00, 0x00, - 0x00, 0x0c, 0x11, 0x40, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02, - 0xbf, 0x00, 0x00, 0x35, - 0x00, 0x0c, 0x12, 0x4d, - 0xaa, 0x00, 0x00, 0xff]; - - static PKT_TOO_BIG_IP_PAYLOAD: [u8; 52] = - [0x60, 0x00, 0x00, 0x00, - 0x00, 0x0c, 0x11, 0x40, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02, - 0xbf, 0x00, 0x00, 0x35, - 0x00, 0x0c, 0x12, 0x4d, - 0xaa, 0x00, 0x00, 0xff]; - - static PKT_TOO_BIG_UDP_PAYLOAD: [u8; 12] = - [0xbf, 0x00, 0x00, 0x35, - 0x00, 0x0c, 0x12, 0x4d, - 0xaa, 0x00, 0x00, 0xff]; + static PKT_TOO_BIG_UDP_PAYLOAD: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; fn echo_packet_repr() -> Repr<'static> { Repr::EchoRequest { ident: 0x1234, seq_no: 0xabcd, - data: &ECHO_PACKET_PAYLOAD + data: &ECHO_PACKET_PAYLOAD, } } @@ -800,17 +858,17 @@ mod test { Repr::PktTooBig { mtu: 1500, header: Ipv6Repr { - src_addr: Ipv6Address([0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01]), - dst_addr: Ipv6Address([0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02]), + src_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ]), + dst_addr: Ipv6Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x02, + ]), next_header: IpProtocol::Udp, payload_len: 12, - hop_limit: 0x40 + hop_limit: 0x40, }, data: &PKT_TOO_BIG_UDP_PAYLOAD, } @@ -825,7 +883,7 @@ mod test { assert_eq!(packet.echo_ident(), 0x1234); assert_eq!(packet.echo_seq_no(), 0xabcd); assert_eq!(packet.payload(), &ECHO_PACKET_PAYLOAD[..]); - assert_eq!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2), true); + assert!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2)); assert!(!packet.msg_type().is_error()); } @@ -837,16 +895,23 @@ mod test { packet.set_msg_code(0); packet.set_echo_ident(0x1234); packet.set_echo_seq_no(0xabcd); - packet.payload_mut().copy_from_slice(&ECHO_PACKET_PAYLOAD[..]); + packet + .payload_mut() + .copy_from_slice(&ECHO_PACKET_PAYLOAD[..]); packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); - assert_eq!(&packet.into_inner()[..], &ECHO_PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &ECHO_PACKET_BYTES[..]); } #[test] fn test_echo_repr_parse() { let packet = Packet::new_unchecked(&ECHO_PACKET_BYTES[..]); - let repr = Repr::parse(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &packet, &ChecksumCapabilities::default()).unwrap(); + let repr = Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default(), + ) + .unwrap(); assert_eq!(repr, echo_packet_repr()); } @@ -855,9 +920,13 @@ mod test { let repr = echo_packet_repr(); let mut bytes = vec![0xa5; repr.buffer_len()]; let mut packet = Packet::new_unchecked(&mut bytes); - repr.emit(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &mut packet, &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &ECHO_PACKET_BYTES[..]); + repr.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &ECHO_PACKET_BYTES[..]); } #[test] @@ -868,7 +937,7 @@ mod test { assert_eq!(packet.checksum(), 0x0fc9); assert_eq!(packet.pkt_too_big_mtu(), 1500); assert_eq!(packet.payload(), &PKT_TOO_BIG_IP_PAYLOAD[..]); - assert_eq!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2), true); + assert!(packet.verify_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2)); assert!(packet.msg_type().is_error()); } @@ -879,16 +948,23 @@ mod test { packet.set_msg_type(Message::PktTooBig); packet.set_msg_code(0); packet.set_pkt_too_big_mtu(1500); - packet.payload_mut().copy_from_slice(&PKT_TOO_BIG_IP_PAYLOAD[..]); + packet + .payload_mut() + .copy_from_slice(&PKT_TOO_BIG_IP_PAYLOAD[..]); packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); - assert_eq!(&packet.into_inner()[..], &PKT_TOO_BIG_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PKT_TOO_BIG_BYTES[..]); } #[test] fn test_too_big_repr_parse() { let packet = Packet::new_unchecked(&PKT_TOO_BIG_BYTES[..]); - let repr = Repr::parse(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &packet, &ChecksumCapabilities::default()).unwrap(); + let repr = Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default(), + ) + .unwrap(); assert_eq!(repr, too_big_packet_repr()); } @@ -897,8 +973,12 @@ mod test { let repr = too_big_packet_repr(); let mut bytes = vec![0xa5; repr.buffer_len()]; let mut packet = Packet::new_unchecked(&mut bytes); - repr.emit(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &mut packet, &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &PKT_TOO_BIG_BYTES[..]); + repr.emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &PKT_TOO_BIG_BYTES[..]); } } diff --git a/src/wire/ieee802154.rs b/src/wire/ieee802154.rs new file mode 100644 index 000000000..2431313f8 --- /dev/null +++ b/src/wire/ieee802154.rs @@ -0,0 +1,1085 @@ +use core::fmt; + +use byteorder::{ByteOrder, LittleEndian}; + +use super::{Error, Result}; +use crate::wire::ipv6::Address as Ipv6Address; + +enum_with_unknown! { + /// IEEE 802.15.4 frame type. + pub enum FrameType(u8) { + Beacon = 0b000, + Data = 0b001, + Acknowledgement = 0b010, + MacCommand = 0b011, + Multipurpose = 0b101, + FragmentOrFrak = 0b110, + Extended = 0b111, + } +} + +impl fmt::Display for FrameType { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + FrameType::Beacon => write!(f, "Beacon"), + FrameType::Data => write!(f, "Data"), + FrameType::Acknowledgement => write!(f, "Ack"), + FrameType::MacCommand => write!(f, "MAC command"), + FrameType::Multipurpose => write!(f, "Multipurpose"), + FrameType::FragmentOrFrak => write!(f, "FragmentOrFrak"), + FrameType::Extended => write!(f, "Extended"), + FrameType::Unknown(id) => write!(f, "0b{id:04b}"), + } + } +} +enum_with_unknown! { + /// IEEE 802.15.4 addressing mode for destination and source addresses. + pub enum AddressingMode(u8) { + Absent = 0b00, + Short = 0b10, + Extended = 0b11, + } +} + +impl AddressingMode { + /// Return the size in octets of the address. + const fn size(&self) -> usize { + match self { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + AddressingMode::Unknown(_) => 0, // TODO(thvdveld): what do we need to here? + } + } +} + +impl fmt::Display for AddressingMode { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + AddressingMode::Absent => write!(f, "Absent"), + AddressingMode::Short => write!(f, "Short"), + AddressingMode::Extended => write!(f, "Extended"), + AddressingMode::Unknown(id) => write!(f, "0b{id:04b}"), + } + } +} + +/// A IEEE 802.15.4 PAN. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub struct Pan(pub u16); + +impl Pan { + pub const BROADCAST: Self = Self(0xffff); + + /// Return the PAN ID as bytes. + pub fn as_bytes(&self) -> [u8; 2] { + let mut pan = [0u8; 2]; + LittleEndian::write_u16(&mut pan, self.0); + pan + } +} + +impl fmt::Display for Pan { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "{:0x}", self.0) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Pan { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{:02x}", self.0) + } +} + +/// A IEEE 802.15.4 address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +pub enum Address { + Absent, + Short([u8; 2]), + Extended([u8; 8]), +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + match self { + Self::Absent => defmt::write!(f, "not-present"), + Self::Short(bytes) => defmt::write!(f, "{:02x}:{:02x}", bytes[0], bytes[1]), + Self::Extended(bytes) => defmt::write!( + f, + "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + bytes[0], + bytes[1], + bytes[2], + bytes[3], + bytes[4], + bytes[5], + bytes[6], + bytes[7] + ), + } + } +} + +#[cfg(test)] +impl Default for Address { + fn default() -> Self { + Address::Extended([0u8; 8]) + } +} + +impl Address { + /// The broadcast address. + pub const BROADCAST: Address = Address::Short([0xff; 2]); + + /// Query whether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + !self.is_broadcast() + } + + /// Query whether this address is the broadcast address. + pub fn is_broadcast(&self) -> bool { + *self == Self::BROADCAST + } + + const fn short_from_bytes(a: [u8; 2]) -> Self { + Self::Short(a) + } + + const fn extended_from_bytes(a: [u8; 8]) -> Self { + Self::Extended(a) + } + + pub fn from_bytes(a: &[u8]) -> Self { + if a.len() == 2 { + let mut b = [0u8; 2]; + b.copy_from_slice(a); + Address::Short(b) + } else if a.len() == 8 { + let mut b = [0u8; 8]; + b.copy_from_slice(a); + Address::Extended(b) + } else { + panic!("Not an IEEE802.15.4 address"); + } + } + + pub const fn as_bytes(&self) -> &[u8] { + match self { + Address::Absent => &[], + Address::Short(value) => value, + Address::Extended(value) => value, + } + } + + /// Convert the extended address to an Extended Unique Identifier (EUI-64) + pub fn as_eui_64(&self) -> Option<[u8; 8]> { + match self { + Address::Absent | Address::Short(_) => None, + Address::Extended(value) => { + let mut bytes = [0; 8]; + bytes.copy_from_slice(&value[..]); + + bytes[0] ^= 1 << 1; + + Some(bytes) + } + } + } + + /// Convert an extended address to a link-local IPv6 address using the EUI-64 format from + /// RFC2464. + pub fn as_link_local_address(&self) -> Option { + let mut bytes = [0; 16]; + bytes[0] = 0xfe; + bytes[1] = 0x80; + bytes[8..].copy_from_slice(&self.as_eui_64()?); + + Some(Ipv6Address::from_bytes(&bytes)) + } +} + +impl fmt::Display for Address { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match self { + Self::Absent => write!(f, "not-present"), + Self::Short(bytes) => write!(f, "{:02x}:{:02x}", bytes[0], bytes[1]), + Self::Extended(bytes) => write!( + f, + "{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}:{:02x}", + bytes[0], bytes[1], bytes[2], bytes[3], bytes[4], bytes[5], bytes[6], bytes[7] + ), + } + } +} + +enum_with_unknown! { + /// IEEE 802.15.4 addressing mode for destination and source addresses. + pub enum FrameVersion(u8) { + Ieee802154_2003 = 0b00, + Ieee802154_2006 = 0b01, + Ieee802154 = 0b10, + } +} + +/// A read/write wrapper around an IEEE 802.15.4 frame buffer. +#[derive(Debug, Clone)] +pub struct Frame> { + buffer: T, +} + +mod field { + use crate::wire::field::*; + + pub const FRAMECONTROL: Field = 0..2; + pub const SEQUENCE_NUMBER: usize = 2; + pub const ADDRESSING: Rest = 3..; +} + +macro_rules! fc_bit_field { + ($field:ident, $bit:literal) => { + #[inline] + pub fn $field(&self) -> bool { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + + ((raw >> $bit) & 0b1) == 0b1 + } + }; +} + +macro_rules! set_fc_bit_field { + ($field:ident, $bit:literal) => { + #[inline] + pub fn $field(&mut self, val: bool) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + raw |= ((val as u16) << $bit); + + data.copy_from_slice(&raw.to_le_bytes()); + } + }; +} + +impl> Frame { + /// Input a raw octet buffer with Ethernet frame structure. + pub const fn new_unchecked(buffer: T) -> Frame { + Frame { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result> { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + if matches!(packet.dst_addressing_mode(), AddressingMode::Unknown(_)) { + return Err(Error); + } + + if matches!(packet.src_addressing_mode(), AddressingMode::Unknown(_)) { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + // We need at least 3 bytes + if self.buffer.as_ref().len() < 3 { + return Err(Error); + } + + let mut offset = field::ADDRESSING.start + 2; + + // Calculate the size of the addressing field. + offset += self.dst_addressing_mode().size(); + offset += self.src_addressing_mode().size(); + + if !self.pan_id_compression() { + offset += 2; + } + + if offset > self.buffer.as_ref().len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the FrameType field. + #[inline] + pub fn frame_type(&self) -> FrameType { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let ft = (raw & 0b11) as u8; + FrameType::from(ft) + } + + fc_bit_field!(security_enabled, 3); + fc_bit_field!(frame_pending, 4); + fc_bit_field!(ack_request, 5); + fc_bit_field!(pan_id_compression, 6); + + fc_bit_field!(sequence_number_suppression, 8); + fc_bit_field!(ie_present, 9); + + /// Return the destination addressing mode. + #[inline] + pub fn dst_addressing_mode(&self) -> AddressingMode { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let am = ((raw >> 10) & 0b11) as u8; + AddressingMode::from(am) + } + + /// Return the frame version. + #[inline] + pub fn frame_version(&self) -> FrameVersion { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let fv = ((raw >> 12) & 0b11) as u8; + FrameVersion::from(fv) + } + + /// Return the source addressing mode. + #[inline] + pub fn src_addressing_mode(&self) -> AddressingMode { + let data = self.buffer.as_ref(); + let raw = LittleEndian::read_u16(&data[field::FRAMECONTROL]); + let am = ((raw >> 14) & 0b11) as u8; + AddressingMode::from(am) + } + + /// Return the sequence number of the frame. + #[inline] + pub fn sequence_number(&self) -> Option { + match self.frame_type() { + FrameType::Beacon + | FrameType::Data + | FrameType::Acknowledgement + | FrameType::MacCommand + | FrameType::Multipurpose => { + let data = self.buffer.as_ref(); + let raw = data[field::SEQUENCE_NUMBER]; + Some(raw) + } + FrameType::Extended | FrameType::FragmentOrFrak | FrameType::Unknown(_) => None, + } + } + + /// Return the addressing fields. + #[inline] + fn addressing_fields(&self) -> Option<&[u8]> { + match self.frame_type() { + FrameType::Beacon + | FrameType::Data + | FrameType::MacCommand + | FrameType::Multipurpose => (), + FrameType::Acknowledgement if self.frame_version() == FrameVersion::Ieee802154 => (), + FrameType::Acknowledgement + | FrameType::Extended + | FrameType::FragmentOrFrak + | FrameType::Unknown(_) => return None, + } + + let mut offset = 2; + + // Calculate the size of the addressing field. + offset += self.dst_addressing_mode().size(); + offset += self.src_addressing_mode().size(); + + if !self.pan_id_compression() { + offset += 2; + } + + Some(&self.buffer.as_ref()[field::ADDRESSING][..offset]) + } + + /// Return the destination PAN field. + #[inline] + pub fn dst_pan_id(&self) -> Option { + let addressing_fields = self.addressing_fields()?; + match self.dst_addressing_mode() { + AddressingMode::Absent => None, + AddressingMode::Short | AddressingMode::Extended => { + Some(Pan(LittleEndian::read_u16(&addressing_fields[0..2]))) + } + AddressingMode::Unknown(_) => None, + } + } + + /// Return the destination address field. + #[inline] + pub fn dst_addr(&self) -> Option

{ + let addressing_fields = self.addressing_fields()?; + match self.dst_addressing_mode() { + AddressingMode::Absent => Some(Address::Absent), + AddressingMode::Short => { + let mut raw = [0u8; 2]; + raw.clone_from_slice(&addressing_fields[2..4]); + raw.reverse(); + Some(Address::short_from_bytes(raw)) + } + AddressingMode::Extended => { + let mut raw = [0u8; 8]; + raw.clone_from_slice(&addressing_fields[2..10]); + raw.reverse(); + Some(Address::extended_from_bytes(raw)) + } + AddressingMode::Unknown(_) => None, + } + } + + /// Return the destination PAN field. + #[inline] + pub fn src_pan_id(&self) -> Option { + if self.pan_id_compression() { + return None; + } + + let addressing_fields = self.addressing_fields()?; + let offset = self.dst_addressing_mode().size() + 2; + + match self.src_addressing_mode() { + AddressingMode::Absent => None, + AddressingMode::Short | AddressingMode::Extended => Some(Pan(LittleEndian::read_u16( + &addressing_fields[offset..offset + 2], + ))), + AddressingMode::Unknown(_) => None, + } + } + + /// Return the source address field. + #[inline] + pub fn src_addr(&self) -> Option
{ + let addressing_fields = self.addressing_fields()?; + let mut offset = match self.dst_addressing_mode() { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + _ => return None, // TODO(thvdveld): what do we do here? + } + 2; + + if !self.pan_id_compression() { + offset += 2; + } + + match self.src_addressing_mode() { + AddressingMode::Absent => Some(Address::Absent), + AddressingMode::Short => { + let mut raw = [0u8; 2]; + raw.clone_from_slice(&addressing_fields[offset..offset + 2]); + raw.reverse(); + Some(Address::short_from_bytes(raw)) + } + AddressingMode::Extended => { + let mut raw = [0u8; 8]; + raw.clone_from_slice(&addressing_fields[offset..offset + 8]); + raw.reverse(); + Some(Address::extended_from_bytes(raw)) + } + AddressingMode::Unknown(_) => None, + } + } + + /// Return the index where the auxiliary security header starts. + fn aux_security_header_start(&self) -> usize { + // We start with 3, because 2 bytes for frame control and the sequence number. + let mut index = 3; + index += self.addressing_fields().unwrap().len(); + index + } + + /// Return the index where the payload starts. + fn payload_start(&self) -> usize { + let mut index = self.aux_security_header_start(); + + if self.security_enabled() { + // We add 5 because 1 byte for control bits and 4 bytes for frame counter. + index += 5; + index += if let Some(len) = self.key_identifier_length() { + len as usize + } else { + 0 + }; + } + + index + } + + /// Return the length of the key identifier field. + fn key_identifier_length(&self) -> Option { + Some(match self.key_identifier_mode() { + 0 => 0, + 1 => 1, + 2 => 5, + 3 => 9, + _ => return None, + }) + } + + /// Return the security level of the auxiliary security header. + pub fn security_level(&self) -> u8 { + let index = self.aux_security_header_start(); + let b = self.buffer.as_ref()[index..][0]; + b & 0b111 + } + + /// Return the key identifier mode used by the auxiliary security header. + pub fn key_identifier_mode(&self) -> u8 { + let index = self.aux_security_header_start(); + let b = self.buffer.as_ref()[index..][0]; + (b >> 3) & 0b11 + } + + /// Return the frame counter field. + pub fn frame_counter(&self) -> u32 { + let index = self.aux_security_header_start(); + let b = &self.buffer.as_ref()[index..]; + LittleEndian::read_u32(&b[1..1 + 4]) + } + + /// Return the Key Identifier field. + fn key_identifier(&self) -> &[u8] { + let index = self.aux_security_header_start(); + let b = &self.buffer.as_ref()[index..]; + let length = if let Some(len) = self.key_identifier_length() { + len as usize + } else { + 0 + }; + &b[5..][..length] + } + + /// Return the Key Source field. + pub fn key_source(&self) -> Option<&[u8]> { + let ki = self.key_identifier(); + let len = ki.len(); + if len > 1 { + Some(&ki[..len - 1]) + } else { + None + } + } + + /// Return the Key Index field. + pub fn key_index(&self) -> Option { + let ki = self.key_identifier(); + let len = ki.len(); + + if len > 0 { + Some(ki[len - 1]) + } else { + None + } + } + + /// Return the Message Integrity Code (MIC). + pub fn message_integrity_code(&self) -> Option<&[u8]> { + let mic_len = match self.security_level() { + 0 | 4 => return None, + 1 | 5 => 4, + 2 | 6 => 8, + 3 | 7 => 16, + _ => panic!(), + }; + + let data = &self.buffer.as_ref(); + let len = data.len(); + + Some(&data[len - mic_len..]) + } + + /// Return the MAC header. + pub fn mac_header(&self) -> &[u8] { + let data = &self.buffer.as_ref(); + &data[..self.payload_start()] + } +} + +impl<'a, T: AsRef<[u8]> + ?Sized> Frame<&'a T> { + /// Return a pointer to the payload. + #[inline] + pub fn payload(&self) -> Option<&'a [u8]> { + match self.frame_type() { + FrameType::Data => { + let index = self.payload_start(); + let data = &self.buffer.as_ref(); + + Some(&data[index..]) + } + _ => None, + } + } +} + +impl + AsMut<[u8]>> Frame { + /// Set the frame type. + #[inline] + pub fn set_frame_type(&mut self, frame_type: FrameType) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b111)) | (u8::from(frame_type) as u16 & 0b111); + data.copy_from_slice(&raw.to_le_bytes()); + } + + set_fc_bit_field!(set_security_enabled, 3); + set_fc_bit_field!(set_frame_pending, 4); + set_fc_bit_field!(set_ack_request, 5); + set_fc_bit_field!(set_pan_id_compression, 6); + + /// Set the frame version. + #[inline] + pub fn set_frame_version(&mut self, version: FrameVersion) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 12)) | ((u8::from(version) as u16 & 0b11) << 12); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Set the frame sequence number. + #[inline] + pub fn set_sequence_number(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::SEQUENCE_NUMBER] = value; + } + + /// Set the destination PAN ID. + #[inline] + pub fn set_dst_pan_id(&mut self, value: Pan) { + // NOTE the destination addressing mode must be different than Absent. + // This is the reason why we set it to Extended. + self.set_dst_addressing_mode(AddressingMode::Extended); + + let data = self.buffer.as_mut(); + data[field::ADDRESSING][..2].copy_from_slice(&value.as_bytes()); + } + + /// Set the destination address. + #[inline] + pub fn set_dst_addr(&mut self, value: Address) { + match value { + Address::Absent => self.set_dst_addressing_mode(AddressingMode::Absent), + Address::Short(mut value) => { + value.reverse(); + self.set_dst_addressing_mode(AddressingMode::Short); + let data = self.buffer.as_mut(); + data[field::ADDRESSING][2..2 + 2].copy_from_slice(&value); + value.reverse(); + } + Address::Extended(mut value) => { + value.reverse(); + self.set_dst_addressing_mode(AddressingMode::Extended); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[2..2 + 8].copy_from_slice(&value); + value.reverse(); + } + } + } + + /// Set the destination addressing mode. + #[inline] + fn set_dst_addressing_mode(&mut self, value: AddressingMode) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 10)) | ((u8::from(value) as u16 & 0b11) << 10); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Set the source PAN ID. + #[inline] + pub fn set_src_pan_id(&mut self, value: Pan) { + let offset = match self.dst_addressing_mode() { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + _ => unreachable!(), + } + 2; + + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 2].copy_from_slice(&value.as_bytes()); + } + + /// Set the source address. + #[inline] + pub fn set_src_addr(&mut self, value: Address) { + let offset = match self.dst_addressing_mode() { + AddressingMode::Absent => 0, + AddressingMode::Short => 2, + AddressingMode::Extended => 8, + _ => unreachable!(), + } + 2; + + let offset = offset + if self.pan_id_compression() { 0 } else { 2 }; + + match value { + Address::Absent => self.set_src_addressing_mode(AddressingMode::Absent), + Address::Short(mut value) => { + value.reverse(); + self.set_src_addressing_mode(AddressingMode::Short); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 2].copy_from_slice(&value); + value.reverse(); + } + Address::Extended(mut value) => { + value.reverse(); + self.set_src_addressing_mode(AddressingMode::Extended); + let data = &mut self.buffer.as_mut()[field::ADDRESSING]; + data[offset..offset + 8].copy_from_slice(&value); + value.reverse(); + } + } + } + + /// Set the source addressing mode. + #[inline] + fn set_src_addressing_mode(&mut self, value: AddressingMode) { + let data = &mut self.buffer.as_mut()[field::FRAMECONTROL]; + let mut raw = LittleEndian::read_u16(data); + + raw = (raw & !(0b11 << 14)) | ((u8::from(value) as u16 & 0b11) << 14); + data.copy_from_slice(&raw.to_le_bytes()); + } + + /// Return a mutable pointer to the payload. + #[inline] + pub fn payload_mut(&mut self) -> Option<&mut [u8]> { + match self.frame_type() { + FrameType::Data => { + let index = self.payload_start(); + let data = self.buffer.as_mut(); + Some(&mut data[index..]) + } + _ => None, + } + } +} + +impl> fmt::Display for Frame { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + write!(f, "IEEE802.15.4 frame type={}", self.frame_type())?; + + if let Some(seq) = self.sequence_number() { + write!(f, " seq={:02x}", seq)?; + } + + if let Some(pan) = self.dst_pan_id() { + write!(f, " dst-pan={}", pan)?; + } + + if let Some(pan) = self.src_pan_id() { + write!(f, " src-pan={}", pan)?; + } + + if let Some(addr) = self.dst_addr() { + write!(f, " dst={}", addr)?; + } + + if let Some(addr) = self.src_addr() { + write!(f, " src={}", addr)?; + } + + Ok(()) + } +} + +#[cfg(feature = "defmt")] +impl> defmt::Format for Frame { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "IEEE802.15.4 frame type={}", self.frame_type()); + + if let Some(seq) = self.sequence_number() { + defmt::write!(f, " seq={:02x}", seq); + } + + if let Some(pan) = self.dst_pan_id() { + defmt::write!(f, " dst-pan={}", pan); + } + + if let Some(pan) = self.src_pan_id() { + defmt::write!(f, " src-pan={}", pan); + } + + if let Some(addr) = self.dst_addr() { + defmt::write!(f, " dst={}", addr); + } + + if let Some(addr) = self.src_addr() { + defmt::write!(f, " src={}", addr); + } + } +} + +/// A high-level representation of an IEEE802.15.4 frame. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr { + pub frame_type: FrameType, + pub security_enabled: bool, + pub frame_pending: bool, + pub ack_request: bool, + pub sequence_number: Option, + pub pan_id_compression: bool, + pub frame_version: FrameVersion, + pub dst_pan_id: Option, + pub dst_addr: Option
, + pub src_pan_id: Option, + pub src_addr: Option
, +} + +impl Repr { + /// Parse an IEEE 802.15.4 frame and return a high-level representation. + pub fn parse + ?Sized>(packet: &Frame<&T>) -> Result { + // Ensure the basic accessors will work. + packet.check_len()?; + + Ok(Repr { + frame_type: packet.frame_type(), + security_enabled: packet.security_enabled(), + frame_pending: packet.frame_pending(), + ack_request: packet.ack_request(), + sequence_number: packet.sequence_number(), + pan_id_compression: packet.pan_id_compression(), + frame_version: packet.frame_version(), + dst_pan_id: packet.dst_pan_id(), + dst_addr: packet.dst_addr(), + src_pan_id: packet.src_pan_id(), + src_addr: packet.src_addr(), + }) + } + + /// Return the length of a buffer required to hold a packet with the payload of a given length. + #[inline] + pub const fn buffer_len(&self) -> usize { + 3 + 2 + + match self.dst_addr { + Some(Address::Absent) | None => 0, + Some(Address::Short(_)) => 2, + Some(Address::Extended(_)) => 8, + } + + if !self.pan_id_compression { 2 } else { 0 } + + match self.src_addr { + Some(Address::Absent) | None => 0, + Some(Address::Short(_)) => 2, + Some(Address::Extended(_)) => 8, + } + } + + /// Emit a high-level representation into an IEEE802.15.4 frame. + pub fn emit + AsMut<[u8]>>(&self, frame: &mut Frame) { + frame.set_frame_type(self.frame_type); + frame.set_security_enabled(self.security_enabled); + frame.set_frame_pending(self.frame_pending); + frame.set_ack_request(self.ack_request); + frame.set_pan_id_compression(self.pan_id_compression); + frame.set_frame_version(self.frame_version); + + if let Some(sequence_number) = self.sequence_number { + frame.set_sequence_number(sequence_number); + } + + if let Some(dst_pan_id) = self.dst_pan_id { + frame.set_dst_pan_id(dst_pan_id); + } + if let Some(dst_addr) = self.dst_addr { + frame.set_dst_addr(dst_addr); + } + + if !self.pan_id_compression && self.src_pan_id.is_some() { + frame.set_src_pan_id(self.src_pan_id.unwrap()); + } + + if let Some(src_addr) = self.src_addr { + frame.set_src_addr(src_addr); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn test_broadcast() { + assert!(Address::BROADCAST.is_broadcast()); + assert!(!Address::BROADCAST.is_unicast()); + } + + #[test] + fn prepare_frame() { + let mut buffer = [0u8; 128]; + + let repr = Repr { + frame_type: FrameType::Data, + security_enabled: false, + frame_pending: false, + ack_request: true, + pan_id_compression: true, + frame_version: FrameVersion::Ieee802154, + sequence_number: Some(1), + dst_pan_id: Some(Pan(0xabcd)), + dst_addr: Some(Address::BROADCAST), + src_pan_id: None, + src_addr: Some(Address::Extended([ + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00, + ])), + }; + + let buffer_len = repr.buffer_len(); + + let mut frame = Frame::new_unchecked(&mut buffer[..buffer_len]); + repr.emit(&mut frame); + + println!("{frame:2x?}"); + + assert_eq!(frame.frame_type(), FrameType::Data); + assert!(!frame.security_enabled()); + assert!(!frame.frame_pending()); + assert!(frame.ack_request()); + assert!(frame.pan_id_compression()); + assert_eq!(frame.frame_version(), FrameVersion::Ieee802154); + assert_eq!(frame.sequence_number(), Some(1)); + assert_eq!(frame.dst_pan_id(), Some(Pan(0xabcd))); + assert_eq!(frame.dst_addr(), Some(Address::BROADCAST)); + assert_eq!(frame.src_pan_id(), None); + assert_eq!( + frame.src_addr(), + Some(Address::Extended([ + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00 + ])) + ); + } + + macro_rules! vector_test { + ($name:ident $bytes:expr ; $($test_method:ident -> $expected:expr,)*) => { + #[test] + #[allow(clippy::bool_assert_comparison)] + fn $name() -> Result<()> { + let frame = &$bytes; + let frame = Frame::new_checked(frame)?; + + $( + assert_eq!(frame.$test_method(), $expected, stringify!($test_method)); + )* + + Ok(()) + } + } + } + + vector_test! { + extended_addr + [ + 0b0000_0001, 0b1100_1100, // frame control + 0b0, // seq + 0xcd, 0xab, // pan id + 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, // dst addr + 0x03, 0x04, // pan id + 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x02, // src addr + ]; + frame_type -> FrameType::Data, + dst_addr -> Some(Address::Extended([0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00])), + src_addr -> Some(Address::Extended([0x02, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00])), + dst_pan_id -> Some(Pan(0xabcd)), + } + + vector_test! { + short_addr + [ + 0x01, 0x98, // frame control + 0x00, // sequence number + 0x34, 0x12, 0x78, 0x56, // PAN identifier and address of destination + 0x34, 0x12, 0xbc, 0x9a, // PAN identifier and address of source + ]; + frame_type -> FrameType::Data, + security_enabled -> false, + frame_pending -> false, + ack_request -> false, + pan_id_compression -> false, + dst_addressing_mode -> AddressingMode::Short, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Short, + dst_pan_id -> Some(Pan(0x1234)), + dst_addr -> Some(Address::Short([0x56, 0x78])), + src_pan_id -> Some(Pan(0x1234)), + src_addr -> Some(Address::Short([0x9a, 0xbc])), + } + + vector_test! { + zolertia_remote + [ + 0x41, 0xd8, // frame control + 0x01, // sequence number + 0xcd, 0xab, // Destination PAN id + 0xff, 0xff, // Short destination address + 0xc7, 0xd9, 0xb5, 0x14, 0x00, 0x4b, 0x12, 0x00, // Extended source address + 0x2b, 0x00, 0x00, 0x00, // payload + ]; + frame_type -> FrameType::Data, + security_enabled -> false, + frame_pending -> false, + ack_request -> false, + pan_id_compression -> true, + dst_addressing_mode -> AddressingMode::Short, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Extended, + payload -> Some(&[0x2b, 0x00, 0x00, 0x00][..]), + } + + vector_test! { + security + [ + 0x69,0xdc, // frame control + 0x32, // sequence number + 0xcd,0xab, // destination PAN id + 0xbf,0x9b,0x15,0x06,0x00,0x4b,0x12,0x00, // extended destination address + 0xc7,0xd9,0xb5,0x14,0x00,0x4b,0x12,0x00, // extended source address + 0x05, // security control field + 0x31,0x01,0x00,0x00, // frame counter + 0x3e,0xe8,0xfb,0x85,0xe4,0xcc,0xf4,0x48,0x90,0xfe,0x56,0x66,0xf7,0x1c,0x65,0x9e,0xf9, // data + 0x93,0xc8,0x34,0x2e,// MIC + ]; + frame_type -> FrameType::Data, + security_enabled -> true, + frame_pending -> false, + ack_request -> true, + pan_id_compression -> true, + dst_addressing_mode -> AddressingMode::Extended, + frame_version -> FrameVersion::Ieee802154_2006, + src_addressing_mode -> AddressingMode::Extended, + dst_pan_id -> Some(Pan(0xabcd)), + dst_addr -> Some(Address::Extended([0x00,0x12,0x4b,0x00,0x06,0x15,0x9b,0xbf])), + src_pan_id -> None, + src_addr -> Some(Address::Extended([0x00,0x12,0x4b,0x00,0x14,0xb5,0xd9,0xc7])), + security_level -> 5, + key_identifier_mode -> 0, + frame_counter -> 305, + key_source -> None, + key_index -> None, + payload -> Some(&[0x3e,0xe8,0xfb,0x85,0xe4,0xcc,0xf4,0x48,0x90,0xfe,0x56,0x66,0xf7,0x1c,0x65,0x9e,0xf9,0x93,0xc8,0x34,0x2e][..]), + message_integrity_code -> Some(&[0x93, 0xC8, 0x34, 0x2E][..]), + mac_header -> &[ + 0x69,0xdc, // frame control + 0x32, // sequence number + 0xcd,0xab, // destination PAN id + 0xbf,0x9b,0x15,0x06,0x00,0x4b,0x12,0x00, // extended destination address + 0xc7,0xd9,0xb5,0x14,0x00,0x4b,0x12,0x00, // extended source address + 0x05, // security control field + 0x31,0x01,0x00,0x00, // frame counter + ][..], + } +} diff --git a/src/wire/igmp.rs b/src/wire/igmp.rs index 49440a000..ac13ece1a 100644 --- a/src/wire/igmp.rs +++ b/src/wire/igmp.rs @@ -1,15 +1,15 @@ -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; -use {Error, Result}; -use super::ip::checksum; -use time::Duration; +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::ip::checksum; -use wire::Ipv4Address; +use crate::wire::Ipv4Address; enum_with_unknown! { /// Internet Group Management Protocol v1/v2 message version/type. - pub doc enum Message(u8) { + pub enum Message(u8) { /// Membership Query MembershipQuery = 0x11, /// Version 2 Membership Report @@ -23,12 +23,13 @@ enum_with_unknown! { /// A read/write wrapper around an Internet Group Management Protocol v1/v2 packet buffer. #[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { buffer: T, } mod field { - use wire::field::*; + use crate::wire::field::*; pub const TYPE: usize = 0; pub const MAX_RESP_CODE: usize = 1; @@ -38,12 +39,12 @@ mod field { impl fmt::Display for Message { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Message::MembershipQuery => write!(f, "membership query"), - &Message::MembershipReportV2 => write!(f, "version 2 membership report"), - &Message::LeaveGroup => write!(f, "leave group"), - &Message::MembershipReportV1 => write!(f, "version 1 membership report"), - &Message::Unknown(id) => write!(f, "{}", id), + match *self { + Message::MembershipQuery => write!(f, "membership query"), + Message::MembershipReportV2 => write!(f, "version 2 membership report"), + Message::LeaveGroup => write!(f, "leave group"), + Message::MembershipReportV1 => write!(f, "version 1 membership report"), + Message::Unknown(id) => write!(f, "{id}"), } } } @@ -53,7 +54,7 @@ impl fmt::Display for Message { /// [RFC 2236]: https://tools.ietf.org/html/rfc2236 impl> Packet { /// Imbue a raw octet buffer with IGMPv2 packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -68,11 +69,11 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); - if len < field::GROUP_ADDRESS.end as usize { - Err(Error::Truncated) + if len < field::GROUP_ADDRESS.end { + Err(Error) } else { Ok(()) } @@ -171,6 +172,7 @@ impl + AsMut<[u8]>> Packet { /// A high-level representation of an Internet Group Management Protocol v1/v2 header. #[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr { MembershipQuery { max_resp_time: Duration, @@ -188,6 +190,7 @@ pub enum Repr { /// Type of IGMP membership report version #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum IgmpVersion { /// IGMPv1 Version1, @@ -199,12 +202,13 @@ impl Repr { /// Parse an Internet Group Management Protocol v1/v2 packet and return /// a high-level representation. pub fn parse(packet: &Packet<&T>) -> Result - where T: AsRef<[u8]> + ?Sized + where + T: AsRef<[u8]> + ?Sized, { // Check if the address is 0.0.0.0 or multicast let addr = packet.group_addr(); if !addr.is_unspecified() && !addr.is_multicast() { - return Err(Error::Malformed); + return Err(Error); } // construct a packet based on the Type field @@ -223,13 +227,13 @@ impl Repr { version, }) } - Message::MembershipReportV2 => { - Ok(Repr::MembershipReport { - group_addr: packet.group_addr(), - version: IgmpVersion::Version2, - }) - } - Message::LeaveGroup => Ok(Repr::LeaveGroup { group_addr: packet.group_addr() }), + Message::MembershipReportV2 => Ok(Repr::MembershipReport { + group_addr: packet.group_addr(), + version: IgmpVersion::Version2, + }), + Message::LeaveGroup => Ok(Repr::LeaveGroup { + group_addr: packet.group_addr(), + }), Message::MembershipReportV1 => { // for backwards compatibility with IGMPv1 Ok(Repr::MembershipReport { @@ -237,36 +241,37 @@ impl Repr { version: IgmpVersion::Version1, }) } - _ => Err(Error::Unrecognized), + _ => Err(Error), } } /// Return the length of a packet that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { // always 8 bytes field::GROUP_ADDRESS.end } /// Emit a high-level representation into an Internet Group Management Protocol v2 packet. pub fn emit(&self, packet: &mut Packet<&mut T>) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, { - match self { - &Repr::MembershipQuery { + match *self { + Repr::MembershipQuery { max_resp_time, group_addr, - version + version, } => { packet.set_msg_type(Message::MembershipQuery); match version { - IgmpVersion::Version1 => - packet.set_max_resp_code(0), - IgmpVersion::Version2 => - packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time)), + IgmpVersion::Version1 => packet.set_max_resp_code(0), + IgmpVersion::Version2 => { + packet.set_max_resp_code(duration_to_max_resp_code(max_resp_time)) + } } packet.set_group_address(group_addr); } - &Repr::MembershipReport { + Repr::MembershipReport { group_addr, version, } => { @@ -277,7 +282,7 @@ impl Repr { packet.set_max_resp_code(0); packet.set_group_address(group_addr); } - &Repr::LeaveGroup { group_addr } => { + Repr::LeaveGroup { group_addr } => { packet.set_msg_type(Message::LeaveGroup); packet.set_group_address(group_addr); } @@ -289,22 +294,22 @@ impl Repr { fn max_resp_code_to_duration(value: u8) -> Duration { let value: u64 = value.into(); - let centisecs = if value < 128 { + let decisecs = if value < 128 { value } else { let mant = value & 0xF; let exp = (value >> 4) & 0x7; (mant | 0x10) << (exp + 3) }; - Duration::from_millis(centisecs * 100) + Duration::from_millis(decisecs * 100) } -fn duration_to_max_resp_code(duration: Duration) -> u8 { - let centisecs = duration.total_millis() / 100; - if centisecs < 128 { - centisecs as u8 - } else if centisecs < 31744 { - let mut mant = centisecs >> 3; +const fn duration_to_max_resp_code(duration: Duration) -> u8 { + let decisecs = duration.total_millis() / 100; + if decisecs < 128 { + decisecs as u8 + } else if decisecs < 31744 { + let mut mant = decisecs >> 3; let mut exp = 0u8; while mant > 0x1F && exp < 0x8 { mant >>= 1; @@ -319,52 +324,48 @@ fn duration_to_max_resp_code(duration: Duration) -> u8 { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), - Err(err) => write!(f, "IGMP ({})", err), + Ok(repr) => write!(f, "{repr}"), + Err(err) => write!(f, "IGMP ({err})"), } } } -impl<'a> fmt::Display for Repr { +impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Repr::MembershipQuery { + match *self { + Repr::MembershipQuery { max_resp_time, group_addr, version, - } => { - write!(f, - "IGMP membership query max_resp_time={} group_addr={} version={:?}", - max_resp_time, - group_addr, - version) - } - &Repr::MembershipReport { + } => write!( + f, + "IGMP membership query max_resp_time={max_resp_time} group_addr={group_addr} version={version:?}" + ), + Repr::MembershipReport { group_addr, version, - } => { - write!(f, - "IGMP membership report group_addr={} version={:?}", - group_addr, - version) - } - &Repr::LeaveGroup { group_addr } => { - write!(f, "IGMP leave group group_addr={})", group_addr) + } => write!( + f, + "IGMP membership report group_addr={group_addr} version={version:?}" + ), + Repr::LeaveGroup { group_addr } => { + write!(f, "IGMP leave group group_addr={group_addr})") } } } } -use super::pretty_print::{PrettyIndent, PrettyPrint}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, - f: &mut fmt::Formatter, - indent: &mut PrettyIndent) - -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { match Packet::new_checked(buffer) { - Err(err) => write!(f, "{}({})\n", indent, err), - Ok(packet) => write!(f, "{}{}\n", indent, packet), + Err(err) => writeln!(f, "{indent}({err})"), + Ok(packet) => writeln!(f, "{indent}{packet}"), } } } @@ -373,7 +374,6 @@ impl> PrettyPrint for Packet { mod test { use super::*; - static LEAVE_PACKET_BYTES: [u8; 8] = [0x17, 0x00, 0x02, 0x69, 0xe0, 0x00, 0x06, 0x96]; static REPORT_PACKET_BYTES: [u8; 8] = [0x16, 0x00, 0x08, 0xda, 0xe1, 0x00, 0x00, 0x25]; @@ -383,9 +383,11 @@ mod test { assert_eq!(packet.msg_type(), Message::LeaveGroup); assert_eq!(packet.max_resp_code(), 0); assert_eq!(packet.checksum(), 0x269); - assert_eq!(packet.group_addr(), - Ipv4Address::from_bytes(&[224, 0, 6, 150])); - assert_eq!(packet.verify_checksum(), true); + assert_eq!( + packet.group_addr(), + Ipv4Address::from_bytes(&[224, 0, 6, 150]) + ); + assert!(packet.verify_checksum()); } #[test] @@ -394,9 +396,11 @@ mod test { assert_eq!(packet.msg_type(), Message::MembershipReportV2); assert_eq!(packet.max_resp_code(), 0); assert_eq!(packet.checksum(), 0x08da); - assert_eq!(packet.group_addr(), - Ipv4Address::from_bytes(&[225, 0, 0, 37])); - assert_eq!(packet.verify_checksum(), true); + assert_eq!( + packet.group_addr(), + Ipv4Address::from_bytes(&[225, 0, 0, 37]) + ); + assert!(packet.verify_checksum()); } #[test] @@ -407,7 +411,7 @@ mod test { packet.set_max_resp_code(0); packet.set_group_address(Ipv4Address::from_bytes(&[224, 0, 6, 150])); packet.fill_checksum(); - assert_eq!(&packet.into_inner()[..], &LEAVE_PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &LEAVE_PACKET_BYTES[..]); } #[test] @@ -418,7 +422,7 @@ mod test { packet.set_max_resp_code(0); packet.set_group_address(Ipv4Address::from_bytes(&[225, 0, 0, 37])); packet.fill_checksum(); - assert_eq!(&packet.into_inner()[..], &REPORT_PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); } #[test] @@ -434,9 +438,7 @@ mod test { #[test] fn duration_to_max_resp_time_max() { for duration in 31744..65536 { - let time = duration_to_max_resp_code( - Duration::from_millis(duration * 100) - ); + let time = duration_to_max_resp_code(Duration::from_millis(duration * 100)); assert_eq!(time, 0xFF); } } diff --git a/src/wire/ip.rs b/src/wire/ip.rs index d29808685..1c6baa475 100644 --- a/src/wire/ip.rs +++ b/src/wire/ip.rs @@ -1,50 +1,46 @@ -use core::fmt; use core::convert::From; +use core::fmt; -use {Error, Result}; -use phy::ChecksumCapabilities; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; #[cfg(feature = "proto-ipv4")] -use super::{Ipv4Address, Ipv4Packet, Ipv4Repr, Ipv4Cidr}; +use crate::wire::{Ipv4Address, Ipv4Cidr, Ipv4Packet, Ipv4Repr}; #[cfg(feature = "proto-ipv6")] -use super::{Ipv6Address, Ipv6Cidr, Ipv6Packet, Ipv6Repr}; +use crate::wire::{Ipv6Address, Ipv6Cidr, Ipv6Packet, Ipv6Repr}; /// Internet protocol version. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Version { - Unspecified, #[cfg(feature = "proto-ipv4")] Ipv4, #[cfg(feature = "proto-ipv6")] Ipv6, - #[doc(hidden)] - __Nonexhaustive, } impl Version { /// Return the version of an IP packet stored in the provided buffer. /// /// This function never returns `Ok(IpVersion::Unspecified)`; instead, - /// unknown versions result in `Err(Error::Unrecognized)`. - pub fn of_packet(data: &[u8]) -> Result { + /// unknown versions result in `Err(Error)`. + pub const fn of_packet(data: &[u8]) -> Result { match data[0] >> 4 { #[cfg(feature = "proto-ipv4")] 4 => Ok(Version::Ipv4), #[cfg(feature = "proto-ipv6")] 6 => Ok(Version::Ipv6), - _ => Err(Error::Unrecognized) + _ => Err(Error), } } } impl fmt::Display for Version { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Version::Unspecified => write!(f, "IPv?"), + match *self { #[cfg(feature = "proto-ipv4")] - &Version::Ipv4 => write!(f, "IPv4"), + Version::Ipv4 => write!(f, "IPv4"), #[cfg(feature = "proto-ipv6")] - &Version::Ipv6 => write!(f, "IPv6"), - &Version::__Nonexhaustive => unreachable!() + Version::Ipv6 => write!(f, "IPv6"), } } } @@ -67,18 +63,18 @@ enum_with_unknown! { impl fmt::Display for Protocol { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Protocol::HopByHop => write!(f, "Hop-by-Hop"), - &Protocol::Icmp => write!(f, "ICMP"), - &Protocol::Igmp => write!(f, "IGMP"), - &Protocol::Tcp => write!(f, "TCP"), - &Protocol::Udp => write!(f, "UDP"), - &Protocol::Ipv6Route => write!(f, "IPv6-Route"), - &Protocol::Ipv6Frag => write!(f, "IPv6-Frag"), - &Protocol::Icmpv6 => write!(f, "ICMPv6"), - &Protocol::Ipv6NoNxt => write!(f, "IPv6-NoNxt"), - &Protocol::Ipv6Opts => write!(f, "IPv6-Opts"), - &Protocol::Unknown(id) => write!(f, "0x{:02x}", id) + match *self { + Protocol::HopByHop => write!(f, "Hop-by-Hop"), + Protocol::Icmp => write!(f, "ICMP"), + Protocol::Igmp => write!(f, "IGMP"), + Protocol::Tcp => write!(f, "TCP"), + Protocol::Udp => write!(f, "UDP"), + Protocol::Ipv6Route => write!(f, "IPv6-Route"), + Protocol::Ipv6Frag => write!(f, "IPv6-Frag"), + Protocol::Icmpv6 => write!(f, "ICMPv6"), + Protocol::Ipv6NoNxt => write!(f, "IPv6-NoNxt"), + Protocol::Ipv6Opts => write!(f, "IPv6-Opts"), + Protocol::Unknown(id) => write!(f, "0x{id:02x}"), } } } @@ -86,108 +82,91 @@ impl fmt::Display for Protocol { /// An internetworking address. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub enum Address { - /// An unspecified address. - /// May be used as a placeholder for storage where the address is not assigned yet. - Unspecified, /// An IPv4 address. #[cfg(feature = "proto-ipv4")] Ipv4(Ipv4Address), /// An IPv6 address. #[cfg(feature = "proto-ipv6")] Ipv6(Ipv6Address), - #[doc(hidden)] - __Nonexhaustive } impl Address { /// Create an address wrapping an IPv4 address with the given octets. #[cfg(feature = "proto-ipv4")] - pub fn v4(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { + pub const fn v4(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { Address::Ipv4(Ipv4Address::new(a0, a1, a2, a3)) } /// Create an address wrapping an IPv6 address with the given octets. #[cfg(feature = "proto-ipv6")] - pub fn v6(a0: u16, a1: u16, a2: u16, a3: u16, - a4: u16, a5: u16, a6: u16, a7: u16) -> Address { + #[allow(clippy::too_many_arguments)] + pub fn v6(a0: u16, a1: u16, a2: u16, a3: u16, a4: u16, a5: u16, a6: u16, a7: u16) -> Address { Address::Ipv6(Ipv6Address::new(a0, a1, a2, a3, a4, a5, a6, a7)) } + /// Return the protocol version. + pub const fn version(&self) -> Version { + match self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(_) => Version::Ipv4, + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(_) => Version::Ipv6, + } + } + /// Return an address as a sequence of octets, in big-endian. - pub fn as_bytes(&self) -> &[u8] { + pub const fn as_bytes(&self) -> &[u8] { match self { - &Address::Unspecified => &[], #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(ref addr) => addr.as_bytes(), + Address::Ipv4(addr) => addr.as_bytes(), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(ref addr) => addr.as_bytes(), - &Address::__Nonexhaustive => unreachable!() + Address::Ipv6(addr) => addr.as_bytes(), } } /// Query whether the address is a valid unicast address. pub fn is_unicast(&self) -> bool { match self { - &Address::Unspecified => false, #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(addr) => addr.is_unicast(), + Address::Ipv4(addr) => addr.is_unicast(), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(addr) => addr.is_unicast(), - &Address::__Nonexhaustive => unreachable!() + Address::Ipv6(addr) => addr.is_unicast(), } } /// Query whether the address is a valid multicast address. - pub fn is_multicast(&self) -> bool { + pub const fn is_multicast(&self) -> bool { match self { - &Address::Unspecified => false, #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(addr) => addr.is_multicast(), + Address::Ipv4(addr) => addr.is_multicast(), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(addr) => addr.is_multicast(), - &Address::__Nonexhaustive => unreachable!() + Address::Ipv6(addr) => addr.is_multicast(), } } /// Query whether the address is the broadcast address. pub fn is_broadcast(&self) -> bool { match self { - &Address::Unspecified => false, #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(addr) => addr.is_broadcast(), + Address::Ipv4(addr) => addr.is_broadcast(), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(_) => false, - &Address::__Nonexhaustive => unreachable!() + Address::Ipv6(_) => false, } } /// Query whether the address falls into the "unspecified" range. pub fn is_unspecified(&self) -> bool { match self { - &Address::Unspecified => true, #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(addr) => addr.is_unspecified(), + Address::Ipv4(addr) => addr.is_unspecified(), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(addr) => addr.is_unspecified(), - &Address::__Nonexhaustive => unreachable!() - } - } - - /// Return an unspecified address that has the same IP version as `self`. - pub fn to_unspecified(&self) -> Address { - match self { - &Address::Unspecified => Address::Unspecified, - #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(_) => Address::Ipv4(Ipv4Address::UNSPECIFIED), - #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(_) => Address::Ipv6(Ipv6Address::UNSPECIFIED), - &Address::__Nonexhaustive => unreachable!() + Address::Ipv6(addr) => addr.is_unspecified(), } } /// If `self` is a CIDR-compatible subnet mask, return `Some(prefix_len)`, /// where `prefix_len` is the number of leading zeroes. Return `None` otherwise. - pub fn to_prefix_len(&self) -> Option { + pub fn prefix_len(&self) -> Option { let mut ones = true; let mut prefix_len = 0; for byte in self.as_bytes() { @@ -201,11 +180,9 @@ impl Address { } else { ones = false; } - } else { - if one { - // 1 where 0 was expected - return None - } + } else if one { + // 1 where 0 was expected + return None; } mask >>= 1; } @@ -224,6 +201,18 @@ impl From<::std::net::IpAddr> for Address { } } +#[cfg(feature = "std")] +impl From
for ::std::net::IpAddr { + fn from(x: Address) -> ::std::net::IpAddr { + match x { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(ipv4) => ::std::net::IpAddr::V4(ipv4.into()), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(ipv6) => ::std::net::IpAddr::V6(ipv6.into()), + } + } +} + #[cfg(all(feature = "std", feature = "proto-ipv4"))] impl From<::std::net::Ipv4Addr> for Address { fn from(ipv4: ::std::net::Ipv4Addr) -> Address { @@ -238,12 +227,6 @@ impl From<::std::net::Ipv6Addr> for Address { } } -impl Default for Address { - fn default() -> Address { - Address::Unspecified - } -} - #[cfg(feature = "proto-ipv4")] impl From for Address { fn from(addr: Ipv4Address) -> Self { @@ -260,13 +243,23 @@ impl From for Address { impl fmt::Display for Address { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + #[cfg(feature = "proto-ipv4")] + Address::Ipv4(addr) => write!(f, "{addr}"), + #[cfg(feature = "proto-ipv6")] + Address::Ipv6(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { match self { - &Address::Unspecified => write!(f, "*"), #[cfg(feature = "proto-ipv4")] - &Address::Ipv4(addr) => write!(f, "{}", addr), + &Address::Ipv4(addr) => defmt::write!(f, "{:?}", addr), #[cfg(feature = "proto-ipv6")] - &Address::Ipv6(addr) => write!(f, "{}", addr), - &Address::__Nonexhaustive => unreachable!() + &Address::Ipv6(addr) => defmt::write!(f, "{:?}", addr), } } } @@ -279,48 +272,39 @@ pub enum Cidr { Ipv4(Ipv4Cidr), #[cfg(feature = "proto-ipv6")] Ipv6(Ipv6Cidr), - #[doc(hidden)] - __Nonexhaustive, } impl Cidr { /// Create a CIDR block from the given address and prefix length. /// /// # Panics - /// This function panics if the given address is unspecified, or - /// the given prefix length is invalid for the given address. + /// This function panics if the given prefix length is invalid for the given address. pub fn new(addr: Address, prefix_len: u8) -> Cidr { match addr { #[cfg(feature = "proto-ipv4")] Address::Ipv4(addr) => Cidr::Ipv4(Ipv4Cidr::new(addr, prefix_len)), #[cfg(feature = "proto-ipv6")] Address::Ipv6(addr) => Cidr::Ipv6(Ipv6Cidr::new(addr, prefix_len)), - Address::Unspecified => - panic!("a CIDR block cannot be based on an unspecified address"), - Address::__Nonexhaustive => - unreachable!() } } /// Return the IP address of this CIDR block. - pub fn address(&self) -> Address { - match self { + pub const fn address(&self) -> Address { + match *self { #[cfg(feature = "proto-ipv4")] - &Cidr::Ipv4(cidr) => Address::Ipv4(cidr.address()), + Cidr::Ipv4(cidr) => Address::Ipv4(cidr.address()), #[cfg(feature = "proto-ipv6")] - &Cidr::Ipv6(cidr) => Address::Ipv6(cidr.address()), - &Cidr::__Nonexhaustive => unreachable!() + Cidr::Ipv6(cidr) => Address::Ipv6(cidr.address()), } } /// Return the prefix length of this CIDR block. - pub fn prefix_len(&self) -> u8 { - match self { + pub const fn prefix_len(&self) -> u8 { + match *self { #[cfg(feature = "proto-ipv4")] - &Cidr::Ipv4(cidr) => cidr.prefix_len(), + Cidr::Ipv4(cidr) => cidr.prefix_len(), #[cfg(feature = "proto-ipv6")] - &Cidr::Ipv6(cidr) => cidr.prefix_len(), - &Cidr::__Nonexhaustive => unreachable!() + Cidr::Ipv6(cidr) => cidr.prefix_len(), } } @@ -329,21 +313,11 @@ impl Cidr { pub fn contains_addr(&self, addr: &Address) -> bool { match (self, addr) { #[cfg(feature = "proto-ipv4")] - (&Cidr::Ipv4(ref cidr), &Address::Ipv4(ref addr)) => - cidr.contains_addr(addr), + (Cidr::Ipv4(cidr), Address::Ipv4(addr)) => cidr.contains_addr(addr), #[cfg(feature = "proto-ipv6")] - (&Cidr::Ipv6(ref cidr), &Address::Ipv6(ref addr)) => - cidr.contains_addr(addr), - #[cfg(all(feature = "proto-ipv6", feature = "proto-ipv4"))] - (&Cidr::Ipv4(_), &Address::Ipv6(_)) | (&Cidr::Ipv6(_), &Address::Ipv4(_)) => - false, - (_, &Address::Unspecified) => - // a fully unspecified address covers both IPv4 and IPv6, - // and no CIDR block can do that. - false, - (&Cidr::__Nonexhaustive, _) | - (_, &Address::__Nonexhaustive) => - unreachable!() + (Cidr::Ipv6(cidr), Address::Ipv6(addr)) => cidr.contains_addr(addr), + #[allow(unreachable_patterns)] + _ => false, } } @@ -352,17 +326,11 @@ impl Cidr { pub fn contains_subnet(&self, subnet: &Cidr) -> bool { match (self, subnet) { #[cfg(feature = "proto-ipv4")] - (&Cidr::Ipv4(ref cidr), &Cidr::Ipv4(ref other)) => - cidr.contains_subnet(other), + (Cidr::Ipv4(cidr), Cidr::Ipv4(other)) => cidr.contains_subnet(other), #[cfg(feature = "proto-ipv6")] - (&Cidr::Ipv6(ref cidr), &Cidr::Ipv6(ref other)) => - cidr.contains_subnet(other), - #[cfg(all(feature = "proto-ipv6", feature = "proto-ipv4"))] - (&Cidr::Ipv4(_), &Cidr::Ipv6(_)) | (&Cidr::Ipv6(_), &Cidr::Ipv4(_)) => - false, - (&Cidr::__Nonexhaustive, _) | - (_, &Cidr::__Nonexhaustive) => - unreachable!() + (Cidr::Ipv6(cidr), Cidr::Ipv6(other)) => cidr.contains_subnet(other), + #[allow(unreachable_patterns)] + _ => false, } } } @@ -383,37 +351,43 @@ impl From for Cidr { impl fmt::Display for Cidr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + match *self { + #[cfg(feature = "proto-ipv4")] + Cidr::Ipv4(cidr) => write!(f, "{cidr}"), + #[cfg(feature = "proto-ipv6")] + Cidr::Ipv6(cidr) => write!(f, "{cidr}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { match self { #[cfg(feature = "proto-ipv4")] - &Cidr::Ipv4(cidr) => write!(f, "{}", cidr), + &Cidr::Ipv4(cidr) => defmt::write!(f, "{:?}", cidr), #[cfg(feature = "proto-ipv6")] - &Cidr::Ipv6(cidr) => write!(f, "{}", cidr), - &Cidr::__Nonexhaustive => unreachable!() + &Cidr::Ipv6(cidr) => defmt::write!(f, "{:?}", cidr), } } } /// An internet endpoint address. /// -/// An endpoint can be constructed from a port, in which case the address is unspecified. -#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +/// `Endpoint` always fully specifies both the address and the port. +/// +/// See also ['ListenEndpoint'], which allows not specifying the address +/// in order to listen on a given port on any address. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] pub struct Endpoint { pub addr: Address, - pub port: u16 + pub port: u16, } impl Endpoint { - /// An endpoint with unspecified address and port. - pub const UNSPECIFIED: Endpoint = Endpoint { addr: Address::Unspecified, port: 0 }; - /// Create an endpoint address from given address and port. - pub fn new(addr: Address, port: u16) -> Endpoint { - Endpoint { addr: addr, port: port } - } - - /// Query whether the endpoint has a specified address and port. - pub fn is_specified(&self) -> bool { - !self.addr.is_unspecified() && self.port != 0 + pub const fn new(addr: Address, port: u16) -> Endpoint { + Endpoint { addr: addr, port } } } @@ -431,7 +405,7 @@ impl From<::std::net::SocketAddr> for Endpoint { impl From<::std::net::SocketAddrV4> for Endpoint { fn from(x: ::std::net::SocketAddrV4) -> Endpoint { Endpoint { - addr: x.ip().clone().into(), + addr: (*x.ip()).into(), port: x.port(), } } @@ -441,7 +415,7 @@ impl From<::std::net::SocketAddrV4> for Endpoint { impl From<::std::net::SocketAddrV6> for Endpoint { fn from(x: ::std::net::SocketAddrV6) -> Endpoint { Endpoint { - addr: x.ip().clone().into(), + addr: (*x.ip()).into(), port: x.port(), } } @@ -453,38 +427,123 @@ impl fmt::Display for Endpoint { } } -impl From for Endpoint { - fn from(port: u16) -> Endpoint { - Endpoint { addr: Address::Unspecified, port: port } +#[cfg(feature = "defmt")] +impl defmt::Format for Endpoint { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{:?}:{=u16}", self.addr, self.port); } } impl> From<(T, u16)> for Endpoint { fn from((addr, port): (T, u16)) -> Endpoint { - Endpoint { addr: addr.into(), port: port } + Endpoint { + addr: addr.into(), + port, + } + } +} + +/// An internet endpoint address for listening. +/// +/// In contrast with [`Endpoint`], `ListenEndpoint` allows not specifying the address, +/// in order to listen on a given port at all our addresses. +/// +/// An endpoint can be constructed from a port, in which case the address is unspecified. +#[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] +pub struct ListenEndpoint { + pub addr: Option
, + pub port: u16, +} + +impl ListenEndpoint { + /// Query whether the endpoint has a specified address and port. + pub const fn is_specified(&self) -> bool { + self.addr.is_some() && self.port != 0 + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddr> for ListenEndpoint { + fn from(x: ::std::net::SocketAddr) -> ListenEndpoint { + ListenEndpoint { + addr: Some(x.ip().into()), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv4"))] +impl From<::std::net::SocketAddrV4> for ListenEndpoint { + fn from(x: ::std::net::SocketAddrV4) -> ListenEndpoint { + ListenEndpoint { + addr: Some((*x.ip()).into()), + port: x.port(), + } + } +} + +#[cfg(all(feature = "std", feature = "proto-ipv6"))] +impl From<::std::net::SocketAddrV6> for ListenEndpoint { + fn from(x: ::std::net::SocketAddrV6) -> ListenEndpoint { + ListenEndpoint { + addr: Some((*x.ip()).into()), + port: x.port(), + } + } +} + +impl fmt::Display for ListenEndpoint { + fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { + if let Some(addr) = self.addr { + write!(f, "{}:{}", addr, self.port) + } else { + write!(f, "*:{}", self.port) + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for ListenEndpoint { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{:?}:{=u16}", self.addr, self.port); + } +} + +impl From for ListenEndpoint { + fn from(port: u16) -> ListenEndpoint { + ListenEndpoint { addr: None, port } + } +} + +impl From for ListenEndpoint { + fn from(endpoint: Endpoint) -> ListenEndpoint { + ListenEndpoint { + addr: Some(endpoint.addr), + port: endpoint.port, + } + } +} + +impl> From<(T, u16)> for ListenEndpoint { + fn from((addr, port): (T, u16)) -> ListenEndpoint { + ListenEndpoint { + addr: Some(addr.into()), + port, + } } } /// An IP packet representation. /// -/// This enum abstracts the various versions of IP packets. It either contains a concrete -/// high-level representation for some IP protocol version, or an unspecified representation, -/// which permits the `IpAddress::Unspecified` addresses. +/// This enum abstracts the various versions of IP packets. It either contains an IPv4 +/// or IPv6 concrete high-level representation. #[derive(Debug, Clone, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr { - Unspecified { - src_addr: Address, - dst_addr: Address, - protocol: Protocol, - payload_len: usize, - hop_limit: u8 - }, #[cfg(feature = "proto-ipv4")] Ipv4(Ipv4Repr), #[cfg(feature = "proto-ipv6")] Ipv6(Ipv6Repr), - #[doc(hidden)] - __Nonexhaustive } #[cfg(feature = "proto-ipv4")] @@ -502,256 +561,131 @@ impl From for Repr { } impl Repr { + /// Create a new IpRepr, choosing the right IP version for the src/dst addrs. + /// + /// # Panics + /// + /// Panics if `src_addr` and `dst_addr` are different IP version. + pub fn new( + src_addr: Address, + dst_addr: Address, + next_header: Protocol, + payload_len: usize, + hop_limit: u8, + ) -> Self { + match (src_addr, dst_addr) { + #[cfg(feature = "proto-ipv4")] + (Address::Ipv4(src_addr), Address::Ipv4(dst_addr)) => Self::Ipv4(Ipv4Repr { + src_addr, + dst_addr, + next_header, + payload_len, + hop_limit, + }), + #[cfg(feature = "proto-ipv6")] + (Address::Ipv6(src_addr), Address::Ipv6(dst_addr)) => Self::Ipv6(Ipv6Repr { + src_addr, + dst_addr, + next_header, + payload_len, + hop_limit, + }), + #[allow(unreachable_patterns)] + _ => panic!("IP version mismatch: src={src_addr:?} dst={dst_addr:?}"), + } + } + /// Return the protocol version. - pub fn version(&self) -> Version { - match self { - &Repr::Unspecified { .. } => Version::Unspecified, + pub const fn version(&self) -> Version { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(_) => Version::Ipv4, + Repr::Ipv4(_) => Version::Ipv4, #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(_) => Version::Ipv6, - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(_) => Version::Ipv6, } } /// Return the source address. - pub fn src_addr(&self) -> Address { - match self { - &Repr::Unspecified { src_addr, .. } => src_addr, + pub const fn src_addr(&self) -> Address { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => Address::Ipv4(repr.src_addr), + Repr::Ipv4(repr) => Address::Ipv4(repr.src_addr), #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => Address::Ipv6(repr.src_addr), - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(repr) => Address::Ipv6(repr.src_addr), } } /// Return the destination address. - pub fn dst_addr(&self) -> Address { - match self { - &Repr::Unspecified { dst_addr, .. } => dst_addr, + pub const fn dst_addr(&self) -> Address { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => Address::Ipv4(repr.dst_addr), + Repr::Ipv4(repr) => Address::Ipv4(repr.dst_addr), #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => Address::Ipv6(repr.dst_addr), - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(repr) => Address::Ipv6(repr.dst_addr), } } - /// Return the protocol. - pub fn protocol(&self) -> Protocol { - match self { - &Repr::Unspecified { protocol, .. } => protocol, + /// Return the next header (protocol). + pub const fn next_header(&self) -> Protocol { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => repr.protocol, + Repr::Ipv4(repr) => repr.next_header, #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => repr.next_header, - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(repr) => repr.next_header, } } /// Return the payload length. - pub fn payload_len(&self) -> usize { - match self { - &Repr::Unspecified { payload_len, .. } => payload_len, + pub const fn payload_len(&self) -> usize { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => repr.payload_len, + Repr::Ipv4(repr) => repr.payload_len, #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => repr.payload_len, - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(repr) => repr.payload_len, } } /// Set the payload length. pub fn set_payload_len(&mut self, length: usize) { match self { - &mut Repr::Unspecified { ref mut payload_len, .. } => - *payload_len = length, #[cfg(feature = "proto-ipv4")] - &mut Repr::Ipv4(Ipv4Repr { ref mut payload_len, .. }) => - *payload_len = length, + Repr::Ipv4(Ipv4Repr { payload_len, .. }) => *payload_len = length, #[cfg(feature = "proto-ipv6")] - &mut Repr::Ipv6(Ipv6Repr { ref mut payload_len, .. }) => - *payload_len = length, - &mut Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(Ipv6Repr { payload_len, .. }) => *payload_len = length, } } /// Return the TTL value. - pub fn hop_limit(&self) -> u8 { - match self { - &Repr::Unspecified { hop_limit, .. } => hop_limit, + pub const fn hop_limit(&self) -> u8 { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(Ipv4Repr { hop_limit, .. }) => hop_limit, + Repr::Ipv4(Ipv4Repr { hop_limit, .. }) => hop_limit, #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(Ipv6Repr { hop_limit, ..}) => hop_limit, - &Repr::__Nonexhaustive => unreachable!() - } - } - - /// Convert an unspecified representation into a concrete one, or return - /// `Err(Error::Unaddressable)` if not possible. - /// - /// # Panics - /// This function panics if source and destination addresses belong to different families, - /// or the destination address is unspecified, since this indicates a logic error. - pub fn lower(&self, fallback_src_addrs: &[Cidr]) -> Result { - macro_rules! resolve_unspecified { - ($reprty:path, $ipty:path, $iprepr:expr, $fallbacks:expr) => { - if $iprepr.src_addr.is_unspecified() { - for cidr in $fallbacks { - match cidr.address() { - $ipty(addr) => { - $iprepr.src_addr = addr; - return Ok($reprty($iprepr)); - }, - _ => () - } - } - Err(Error::Unaddressable) - } else { - Ok($reprty($iprepr)) - } - } - } - - match self { - #[cfg(feature = "proto-ipv4")] - &Repr::Unspecified { - src_addr: src_addr @ Address::Unspecified, - dst_addr: Address::Ipv4(dst_addr), - protocol, payload_len, hop_limit - } | - &Repr::Unspecified { - src_addr: src_addr @ Address::Ipv4(_), - dst_addr: Address::Ipv4(dst_addr), - protocol, payload_len, hop_limit - } if src_addr.is_unspecified() => { - let mut src_addr = if let Address::Ipv4(src_ipv4_addr) = src_addr { - Some(src_ipv4_addr) - } else { - None - }; - for cidr in fallback_src_addrs { - if let Address::Ipv4(addr) = cidr.address() { - src_addr = Some(addr); - break; - } - } - Ok(Repr::Ipv4(Ipv4Repr { - src_addr: src_addr.ok_or(Error::Unaddressable)?, - dst_addr, protocol, payload_len, hop_limit - })) - } - - #[cfg(feature = "proto-ipv6")] - &Repr::Unspecified { - src_addr: src_addr @ Address::Unspecified, - dst_addr: Address::Ipv6(dst_addr), - protocol, payload_len, hop_limit - } | - &Repr::Unspecified { - src_addr: src_addr @ Address::Ipv6(_), - dst_addr: Address::Ipv6(dst_addr), - protocol, payload_len, hop_limit - } if src_addr.is_unspecified() => { - let mut src_addr = if let Address::Ipv6(src_ipv6_addr) = src_addr { - Some(src_ipv6_addr) - } else { - None - }; - for cidr in fallback_src_addrs { - if let Address::Ipv6(addr) = cidr.address() { - src_addr = Some(addr); - break; - } - } - Ok(Repr::Ipv6(Ipv6Repr { - src_addr: src_addr.ok_or(Error::Unaddressable)?, - next_header: protocol, - dst_addr, payload_len, hop_limit - })) - } - - #[cfg(feature = "proto-ipv4")] - &Repr::Unspecified { - src_addr: Address::Ipv4(src_addr), - dst_addr: Address::Ipv4(dst_addr), - protocol, payload_len, hop_limit - } => { - Ok(Repr::Ipv4(Ipv4Repr { - src_addr: src_addr, - dst_addr: dst_addr, - protocol: protocol, - payload_len: payload_len, hop_limit - })) - } - - #[cfg(feature = "proto-ipv6")] - &Repr::Unspecified { - src_addr: Address::Ipv6(src_addr), - dst_addr: Address::Ipv6(dst_addr), - protocol, payload_len, hop_limit - } => { - Ok(Repr::Ipv6(Ipv6Repr { - src_addr: src_addr, - dst_addr: dst_addr, - next_header: protocol, - payload_len: payload_len, - hop_limit: hop_limit - })) - } - - #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(mut repr) => - resolve_unspecified!(Repr::Ipv4, Address::Ipv4, repr, fallback_src_addrs), - - #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(mut repr) => - resolve_unspecified!(Repr::Ipv6, Address::Ipv6, repr, fallback_src_addrs), - - &Repr::Unspecified { .. } => - panic!("source and destination IP address families do not match"), - - &Repr::__Nonexhaustive => unreachable!() + Repr::Ipv6(Ipv6Repr { hop_limit, .. }) => hop_limit, } } /// Return the length of a header that will be emitted from this high-level representation. - /// - /// # Panics - /// This function panics if invoked on an unspecified representation. - pub fn buffer_len(&self) -> usize { - match self { - &Repr::Unspecified { .. } => - panic!("unspecified IP representation"), + pub const fn header_len(&self) -> usize { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => - repr.buffer_len(), + Repr::Ipv4(repr) => repr.buffer_len(), #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => - repr.buffer_len(), - &Repr::__Nonexhaustive => - unreachable!() + Repr::Ipv6(repr) => repr.buffer_len(), } } /// Emit this high-level representation into a buffer. - /// - /// # Panics - /// This function panics if invoked on an unspecified representation. - pub fn emit + AsMut<[u8]>>(&self, buffer: T, _checksum_caps: &ChecksumCapabilities) { - match self { - &Repr::Unspecified { .. } => - panic!("unspecified IP representation"), + pub fn emit + AsMut<[u8]>>( + &self, + buffer: T, + _checksum_caps: &ChecksumCapabilities, + ) { + match *self { #[cfg(feature = "proto-ipv4")] - &Repr::Ipv4(repr) => - repr.emit(&mut Ipv4Packet::new_unchecked(buffer), &_checksum_caps), + Repr::Ipv4(repr) => repr.emit(&mut Ipv4Packet::new_unchecked(buffer), _checksum_caps), #[cfg(feature = "proto-ipv6")] - &Repr::Ipv6(repr) => - repr.emit(&mut Ipv6Packet::new_unchecked(buffer)), - &Repr::__Nonexhaustive => - unreachable!() + Repr::Ipv6(repr) => repr.emit(&mut Ipv6Packet::new_unchecked(buffer)), } } @@ -759,11 +693,8 @@ impl Repr { /// high-level representation. /// /// This is the same as `repr.buffer_len() + repr.payload_len()`. - /// - /// # Panics - /// This function panics if invoked on an unspecified representation. - pub fn total_len(&self) -> usize { - self.buffer_len() + self.payload_len() + pub const fn buffer_len(&self) -> usize { + self.header_len() + self.payload_len() } } @@ -772,7 +703,7 @@ pub mod checksum { use super::*; - fn propagate_carries(word: u32) -> u16 { + const fn propagate_carries(word: u32) -> u16 { let sum = (word >> 16) + (word & 0xffff); ((sum >> 16) as u16) + (sum as u16) } @@ -819,36 +750,40 @@ pub mod checksum { } /// Compute an IP pseudo header checksum. - pub fn pseudo_header(src_addr: &Address, dst_addr: &Address, - protocol: Protocol, length: u32) -> u16 { + pub fn pseudo_header( + src_addr: &Address, + dst_addr: &Address, + next_header: Protocol, + length: u32, + ) -> u16 { match (src_addr, dst_addr) { #[cfg(feature = "proto-ipv4")] (&Address::Ipv4(src_addr), &Address::Ipv4(dst_addr)) => { let mut proto_len = [0u8; 4]; - proto_len[1] = protocol.into(); + proto_len[1] = next_header.into(); NetworkEndian::write_u16(&mut proto_len[2..4], length as u16); combine(&[ data(src_addr.as_bytes()), data(dst_addr.as_bytes()), - data(&proto_len[..]) + data(&proto_len[..]), ]) - }, + } #[cfg(feature = "proto-ipv6")] (&Address::Ipv6(src_addr), &Address::Ipv6(dst_addr)) => { let mut proto_len = [0u8; 8]; - proto_len[7] = protocol.into(); + proto_len[7] = next_header.into(); NetworkEndian::write_u32(&mut proto_len[0..4], length); combine(&[ data(src_addr.as_bytes()), data(dst_addr.as_bytes()), - data(&proto_len[..]) + data(&proto_len[..]), ]) } - _ => panic!("Unexpected pseudo header addresses: {}, {}", - src_addr, dst_addr) + #[allow(unreachable_patterns)] + _ => panic!("Unexpected pseudo header addresses: {src_addr}, {dst_addr}"), } } @@ -862,37 +797,51 @@ pub mod checksum { } } -use super::pretty_print::PrettyIndent; +use crate::wire::pretty_print::PrettyIndent; -pub fn pretty_print_ip_payload>(f: &mut fmt::Formatter, indent: &mut PrettyIndent, - ip_repr: T, payload: &[u8]) -> fmt::Result { - #[cfg(feature = "proto-ipv4")] - use wire::Icmpv4Packet; +pub fn pretty_print_ip_payload>( + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ip_repr: T, + payload: &[u8], +) -> fmt::Result { #[cfg(feature = "proto-ipv4")] use super::pretty_print::PrettyPrint; - use wire::{TcpPacket, TcpRepr, UdpPacket, UdpRepr}; - use wire::ip::checksum::format_checksum; + use crate::wire::ip::checksum::format_checksum; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Icmpv4Packet; + use crate::wire::{TcpPacket, TcpRepr, UdpPacket, UdpRepr}; let checksum_caps = ChecksumCapabilities::ignored(); let repr = ip_repr.into(); - match repr.protocol() { + match repr.next_header() { #[cfg(feature = "proto-ipv4")] Protocol::Icmp => { indent.increase(f)?; - Icmpv4Packet::<&[u8]>::pretty_print(&payload.as_ref(), f, indent) + Icmpv4Packet::<&[u8]>::pretty_print(&payload, f, indent) } Protocol::Udp => { indent.increase(f)?; - match UdpPacket::<&[u8]>::new_checked(payload.as_ref()) { - Err(err) => write!(f, "{}({})", indent, err), + match UdpPacket::<&[u8]>::new_checked(payload) { + Err(err) => write!(f, "{indent}({err})"), Ok(udp_packet) => { - match UdpRepr::parse(&udp_packet, &repr.src_addr(), - &repr.dst_addr(), &checksum_caps) { - Err(err) => write!(f, "{}{} ({})", indent, udp_packet, err), + match UdpRepr::parse( + &udp_packet, + &repr.src_addr(), + &repr.dst_addr(), + &checksum_caps, + ) { + Err(err) => write!(f, "{indent}{udp_packet} ({err})"), Ok(udp_repr) => { - write!(f, "{}{}", indent, udp_repr)?; - let valid = udp_packet.verify_checksum(&repr.src_addr(), - &repr.dst_addr()); + write!( + f, + "{}{} len={}", + indent, + udp_repr, + udp_packet.payload().len() + )?; + let valid = + udp_packet.verify_checksum(&repr.src_addr(), &repr.dst_addr()); format_checksum(f, valid) } } @@ -901,23 +850,27 @@ pub fn pretty_print_ip_payload>(f: &mut fmt::Formatter, indent: &m } Protocol::Tcp => { indent.increase(f)?; - match TcpPacket::<&[u8]>::new_checked(payload.as_ref()) { - Err(err) => write!(f, "{}({})", indent, err), + match TcpPacket::<&[u8]>::new_checked(payload) { + Err(err) => write!(f, "{indent}({err})"), Ok(tcp_packet) => { - match TcpRepr::parse(&tcp_packet, &repr.src_addr(), - &repr.dst_addr(), &checksum_caps) { - Err(err) => write!(f, "{}{} ({})", indent, tcp_packet, err), + match TcpRepr::parse( + &tcp_packet, + &repr.src_addr(), + &repr.dst_addr(), + &checksum_caps, + ) { + Err(err) => write!(f, "{indent}{tcp_packet} ({err})"), Ok(tcp_repr) => { - write!(f, "{}{}", indent, tcp_repr)?; - let valid = tcp_packet.verify_checksum(&repr.src_addr(), - &repr.dst_addr()); + write!(f, "{indent}{tcp_repr}")?; + let valid = + tcp_packet.verify_checksum(&repr.src_addr(), &repr.dst_addr()); format_checksum(f, valid) } } } } } - _ => Ok(()) + _ => Ok(()), } } @@ -926,17 +879,21 @@ pub(crate) mod test { #![allow(unused)] #[cfg(feature = "proto-ipv6")] - pub(crate) const MOCK_IP_ADDR_1: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 1])); + pub(crate) const MOCK_IP_ADDR_1: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, + ])); #[cfg(feature = "proto-ipv6")] - pub(crate) const MOCK_IP_ADDR_2: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 2])); + pub(crate) const MOCK_IP_ADDR_2: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2, + ])); #[cfg(feature = "proto-ipv6")] - pub(crate) const MOCK_IP_ADDR_3: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 3])); + pub(crate) const MOCK_IP_ADDR_3: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, + ])); #[cfg(feature = "proto-ipv6")] - pub(crate) const MOCK_IP_ADDR_4: IpAddress = IpAddress::Ipv6(Ipv6Address([0xfe, 0x80, 0, 0, 0, 0, 0, 0, - 0, 0, 0, 0, 0, 0, 0, 4])); + pub(crate) const MOCK_IP_ADDR_4: IpAddress = IpAddress::Ipv6(Ipv6Address([ + 0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 4, + ])); #[cfg(feature = "proto-ipv6")] pub(crate) const MOCK_UNSPECIFIED: IpAddress = IpAddress::Ipv6(Ipv6Address::UNSPECIFIED); @@ -951,200 +908,16 @@ pub(crate) mod test { #[cfg(all(feature = "proto-ipv4", not(feature = "proto-ipv6")))] pub(crate) const MOCK_UNSPECIFIED: IpAddress = IpAddress::Ipv4(Ipv4Address::UNSPECIFIED); - use super::*; - use wire::{IpAddress, IpProtocol,IpCidr}; - #[cfg(feature = "proto-ipv4")] - use wire::{Ipv4Address, Ipv4Repr}; - - macro_rules! generate_common_tests { - ($name:ident, $repr:ident, $ip_repr:path, $ip_addr:path, - $addr_from:path, $nxthdr:ident, $bytes_a:expr, $bytes_b:expr, - $unspecified:expr) => { - mod $name { - use super::*; - - #[test] - fn test_ip_repr_lower() { - let ip_addr_a = $addr_from(&$bytes_a); - let ip_addr_b = $addr_from(&$bytes_b); - let proto = IpProtocol::Icmp; - let payload_len = 10; - - assert_eq!( - Repr::Unspecified{ - src_addr: $ip_addr(ip_addr_a), - dst_addr: $ip_addr(ip_addr_b), - protocol: proto, - hop_limit: 0x2a, - payload_len, - }.lower(&[]), - Ok($ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 0x2a, - payload_len - })) - ); - - assert_eq!( - Repr::Unspecified{ - src_addr: IpAddress::Unspecified, - dst_addr: $ip_addr(ip_addr_b), - protocol: proto, - hop_limit: 64, - payload_len - }.lower(&[]), - Err(Error::Unaddressable) - ); - - assert_eq!( - Repr::Unspecified{ - src_addr: IpAddress::Unspecified, - dst_addr: $ip_addr(ip_addr_b), - protocol: proto, - hop_limit: 64, - payload_len - }.lower(&[IpCidr::new($ip_addr(ip_addr_a), 24)]), - Ok($ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 64, - payload_len - })) - ); - - assert_eq!( - Repr::Unspecified{ - src_addr: $ip_addr($unspecified), - dst_addr: $ip_addr(ip_addr_b), - protocol: proto, - hop_limit: 64, - payload_len - }.lower(&[IpCidr::new($ip_addr(ip_addr_a), 24)]), - Ok($ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 64, - payload_len - })) - ); - - assert_eq!( - Repr::Unspecified{ - src_addr: $ip_addr($unspecified), - dst_addr: $ip_addr(ip_addr_b), - protocol: proto, - hop_limit: 64, - payload_len - }.lower(&[]), - Ok($ip_repr($repr{ - src_addr: $unspecified, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 64, - payload_len - })) - ); - - assert_eq!( - $ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 255, - payload_len - }).lower(&[]), - Ok($ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 255, - payload_len - })) - ); - - assert_eq!( - $ip_repr($repr{ - src_addr: $unspecified, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 255, - payload_len - }).lower(&[]), - Err(Error::Unaddressable) - ); - - assert_eq!( - $ip_repr($repr{ - src_addr: $unspecified, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 64, - payload_len - }).lower(&[IpCidr::new($ip_addr(ip_addr_a), 24)]), - Ok($ip_repr($repr{ - src_addr: ip_addr_a, - dst_addr: ip_addr_b, - $nxthdr: proto, - hop_limit: 64, - payload_len - })) - ); - } - } - }; - (ipv4 $addr_bytes_a:expr, $addr_bytes_b:expr) => { - generate_common_tests!(ipv4, Ipv4Repr, Repr::Ipv4, IpAddress::Ipv4, - Ipv4Address::from_bytes, protocol, $addr_bytes_a, - $addr_bytes_b, Ipv4Address::UNSPECIFIED); - }; - (ipv6 $addr_bytes_a:expr, $addr_bytes_b:expr) => { - generate_common_tests!(ipv6, Ipv6Repr, Repr::Ipv6, IpAddress::Ipv6, - Ipv6Address::from_bytes, next_header, $addr_bytes_a, - $addr_bytes_b, Ipv6Address::UNSPECIFIED); - } - } - + use crate::wire::{IpAddress, IpCidr, IpProtocol}; #[cfg(feature = "proto-ipv4")] - generate_common_tests!(ipv4 - [1, 2, 3, 4], - [5, 6, 7, 8]); - - #[cfg(feature = "proto-ipv6")] - generate_common_tests!(ipv6 - [0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1], - [0xfe, 0x80, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 2]); - - #[test] - #[cfg(all(feature = "proto-ipv4", feature = "proto-ipv6"))] - #[should_panic(expected = "source and destination IP address families do not match")] - fn test_lower_between_families() { - Repr::Unspecified { - src_addr: Address::Ipv6(Ipv6Address::UNSPECIFIED), - dst_addr: Address::Ipv4(Ipv4Address::UNSPECIFIED), - protocol: IpProtocol::Icmpv6, - hop_limit: 0xff, - payload_len: 0 - }.lower(&[]); - } - - #[test] - fn endpoint_unspecified() { - assert!(!Endpoint::UNSPECIFIED.is_specified()); - } + use crate::wire::{Ipv4Address, Ipv4Repr}; #[test] #[cfg(feature = "proto-ipv4")] fn to_prefix_len_ipv4() { fn test_eq>(prefix_len: u8, mask: A) { - assert_eq!( - Some(prefix_len), - mask.into().to_prefix_len() - ); + assert_eq!(Some(prefix_len), mask.into().prefix_len()); } test_eq(0, Ipv4Address::new(0, 0, 0, 0)); @@ -1182,27 +955,40 @@ pub(crate) mod test { test_eq(32, Ipv4Address::new(255, 255, 255, 255)); } + #[test] #[cfg(feature = "proto-ipv4")] fn to_prefix_len_ipv4_error() { - assert_eq!(None, IpAddress::from(Ipv4Address::new(255,255,255,1)).to_prefix_len()); + assert_eq!( + None, + IpAddress::from(Ipv4Address::new(255, 255, 255, 1)).prefix_len() + ); } #[test] #[cfg(feature = "proto-ipv6")] fn to_prefix_len_ipv6() { fn test_eq>(prefix_len: u8, mask: A) { - assert_eq!( - Some(prefix_len), - mask.into().to_prefix_len() - ); + assert_eq!(Some(prefix_len), mask.into().prefix_len()); } test_eq(0, Ipv6Address::new(0, 0, 0, 0, 0, 0, 0, 0)); - test_eq(128, Ipv6Address::new(0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff)); + test_eq( + 128, + Ipv6Address::new( + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, + ), + ); } + #[test] #[cfg(feature = "proto-ipv6")] fn to_prefix_len_ipv6_error() { - assert_eq!(None, IpAddress::from(Ipv6Address::new(0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 1)).to_prefix_len()); + assert_eq!( + None, + IpAddress::from(Ipv6Address::new( + 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0xffff, 0, 1 + )) + .prefix_len() + ); } } diff --git a/src/wire/ipv4.rs b/src/wire/ipv4.rs index c2360cb9c..1027fc262 100644 --- a/src/wire/ipv4.rs +++ b/src/wire/ipv4.rs @@ -1,9 +1,9 @@ -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; -use {Error, Result}; -use phy::ChecksumCapabilities; -use super::ip::{checksum, pretty_print_ip_payload}; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::{checksum, pretty_print_ip_payload}; pub use super::IpProtocol as Protocol; @@ -21,16 +21,30 @@ pub use super::IpProtocol as Protocol; // accept a packet of the following size. pub const MIN_MTU: usize = 576; +/// Size of IPv4 adderess in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc791#section-3.2 +pub const ADDR_SIZE: usize = 4; + +#[derive(Debug, Eq, PartialEq, Ord, PartialOrd, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Key { + id: u16, + src_addr: Address, + dst_addr: Address, + protocol: Protocol, +} + /// A four-octet IPv4 address. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] -pub struct Address(pub [u8; 4]); +pub struct Address(pub [u8; ADDR_SIZE]); impl Address { /// An unspecified address. - pub const UNSPECIFIED: Address = Address([0x00; 4]); + pub const UNSPECIFIED: Address = Address([0x00; ADDR_SIZE]); /// The broadcast address. - pub const BROADCAST: Address = Address([0xff; 4]); + pub const BROADCAST: Address = Address([0xff; ADDR_SIZE]); /// All multicast-capable nodes pub const MULTICAST_ALL_SYSTEMS: Address = Address([224, 0, 0, 1]); @@ -39,7 +53,7 @@ impl Address { pub const MULTICAST_ALL_ROUTERS: Address = Address([224, 0, 0, 2]); /// Construct an IPv4 address from parts. - pub fn new(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { + pub const fn new(a0: u8, a1: u8, a2: u8, a3: u8) -> Address { Address([a0, a1, a2, a3]) } @@ -48,35 +62,33 @@ impl Address { /// # Panics /// The function panics if `data` is not four octets long. pub fn from_bytes(data: &[u8]) -> Address { - let mut bytes = [0; 4]; + let mut bytes = [0; ADDR_SIZE]; bytes.copy_from_slice(data); Address(bytes) } /// Return an IPv4 address as a sequence of octets, in big-endian. - pub fn as_bytes(&self) -> &[u8] { + pub const fn as_bytes(&self) -> &[u8] { &self.0 } /// Query whether the address is an unicast address. pub fn is_unicast(&self) -> bool { - !(self.is_broadcast() || - self.is_multicast() || - self.is_unspecified()) + !(self.is_broadcast() || self.is_multicast() || self.is_unspecified()) } /// Query whether the address is the broadcast address. pub fn is_broadcast(&self) -> bool { - self.0[0..4] == [255; 4] + self.0[0..4] == [255; ADDR_SIZE] } /// Query whether the address is a multicast address. - pub fn is_multicast(&self) -> bool { + pub const fn is_multicast(&self) -> bool { self.0[0] & 0xf0 == 224 } /// Query whether the address falls into the "unspecified" range. - pub fn is_unspecified(&self) -> bool { + pub const fn is_unspecified(&self) -> bool { self.0[0] == 0 } @@ -86,9 +98,16 @@ impl Address { } /// Query whether the address falls into the "loopback" range. - pub fn is_loopback(&self) -> bool { + pub const fn is_loopback(&self) -> bool { self.0[0] == 127 } + + /// Convert to an `IpAddress`. + /// + /// Same as `.into()`, but works in `const`. + pub const fn into_address(self) -> super::IpAddress { + super::IpAddress::Ipv4(self) + } } #[cfg(feature = "std")] @@ -112,11 +131,25 @@ impl fmt::Display for Address { } } +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + defmt::write!( + f, + "{=u8}.{=u8}.{=u8}.{=u8}", + self.0[0], + self.0[1], + self.0[2], + self.0[3] + ) + } +} + /// A specification of an IPv4 CIDR block, containing an address and a variable-length /// subnet masking prefix length. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] pub struct Cidr { - address: Address, + address: Address, prefix_len: u8, } @@ -125,33 +158,42 @@ impl Cidr { /// /// # Panics /// This function panics if the prefix length is larger than 32. - pub fn new(address: Address, prefix_len: u8) -> Cidr { - assert!(prefix_len <= 32); - Cidr { address, prefix_len } + #[allow(clippy::no_effect)] + pub const fn new(address: Address, prefix_len: u8) -> Cidr { + // Replace with const panic (or assert) when stabilized + // see: https://github.com/rust-lang/rust/issues/51999 + ["Prefix length should be <= 32"][(prefix_len > 32) as usize]; + Cidr { + address, + prefix_len, + } } /// Create an IPv4 CIDR block from the given address and network mask. pub fn from_netmask(addr: Address, netmask: Address) -> Result { let netmask = NetworkEndian::read_u32(&netmask.0[..]); if netmask.leading_zeros() == 0 && netmask.trailing_zeros() == netmask.count_zeros() { - Ok(Cidr { address: addr, prefix_len: netmask.count_ones() as u8 }) + Ok(Cidr { + address: addr, + prefix_len: netmask.count_ones() as u8, + }) } else { - Err(Error::Illegal) + Err(Error) } } /// Return the address of this IPv4 CIDR block. - pub fn address(&self) -> Address { + pub const fn address(&self) -> Address { self.address } /// Return the prefix length of this IPv4 CIDR block. - pub fn prefix_len(&self) -> u8 { + pub const fn prefix_len(&self) -> u8 { self.prefix_len } /// Return the network mask of this IPv4 CIDR. - pub fn netmask(&self) -> Address { + pub const fn netmask(&self) -> Address { if self.prefix_len == 0 { return Address([0, 0, 0, 0]); } @@ -160,8 +202,8 @@ impl Cidr { let data = [ ((number >> 24) & 0xff) as u8, ((number >> 16) & 0xff) as u8, - ((number >> 8) & 0xff) as u8, - ((number >> 0) & 0xff) as u8, + ((number >> 8) & 0xff) as u8, + ((number >> 0) & 0xff) as u8, ]; Address(data) @@ -180,15 +222,15 @@ impl Cidr { let data = [ ((number >> 24) & 0xff) as u8, ((number >> 16) & 0xff) as u8, - ((number >> 8) & 0xff) as u8, - ((number >> 0) & 0xff) as u8, + ((number >> 8) & 0xff) as u8, + ((number >> 0) & 0xff) as u8, ]; Some(Address(data)) } /// Return the network block of this IPv4 CIDR. - pub fn network(&self) -> Cidr { + pub const fn network(&self) -> Cidr { let mask = self.netmask().0; let network = [ self.address.0[0] & mask[0], @@ -196,14 +238,19 @@ impl Cidr { self.address.0[2] & mask[2], self.address.0[3] & mask[3], ]; - Cidr { address: Address(network), prefix_len: self.prefix_len } + Cidr { + address: Address(network), + prefix_len: self.prefix_len, + } } /// Query whether the subnetwork described by this IPv4 CIDR block contains /// the given address. pub fn contains_addr(&self, addr: &Address) -> bool { // right shift by 32 is not legal - if self.prefix_len == 0 { return true } + if self.prefix_len == 0 { + return true; + } let shift = 32 - self.prefix_len; let self_prefix = NetworkEndian::read_u32(self.address.as_bytes()) >> shift; @@ -224,30 +271,40 @@ impl fmt::Display for Cidr { } } +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}/{=u8}", self.address, self.prefix_len); + } +} + /// A read/write wrapper around an Internet Protocol version 4 packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } mod field { - use wire::field::*; + use crate::wire::field::*; - pub const VER_IHL: usize = 0; + pub const VER_IHL: usize = 0; pub const DSCP_ECN: usize = 1; - pub const LENGTH: Field = 2..4; - pub const IDENT: Field = 4..6; - pub const FLG_OFF: Field = 6..8; - pub const TTL: usize = 8; + pub const LENGTH: Field = 2..4; + pub const IDENT: Field = 4..6; + pub const FLG_OFF: Field = 6..8; + pub const TTL: usize = 8; pub const PROTOCOL: usize = 9; pub const CHECKSUM: Field = 10..12; pub const SRC_ADDR: Field = 12..16; pub const DST_ADDR: Field = 16..20; } +pub const HEADER_LEN: usize = field::DST_ADDR.end; + impl> Packet { /// Imbue a raw octet buffer with IPv4 packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -262,8 +319,8 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. - /// Returns `Err(Error::Malformed)` if the header length is greater + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the header length is greater /// than total length. /// /// The result of this check is invalidated by calling [set_header_len] @@ -271,16 +328,17 @@ impl> Packet { /// /// [set_header_len]: #method.set_header_len /// [set_total_len]: #method.set_total_len + #[allow(clippy::if_same_then_else)] pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::DST_ADDR.end { - Err(Error::Truncated) + Err(Error) } else if len < self.header_len() as usize { - Err(Error::Truncated) + Err(Error) } else if self.header_len() as u16 > self.total_len() { - Err(Error::Malformed) + Err(Error) } else if len < self.total_len() as usize { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -359,9 +417,9 @@ impl> Packet { data[field::TTL] } - /// Return the protocol field. + /// Return the next_header (protocol) field. #[inline] - pub fn protocol(&self) -> Protocol { + pub fn next_header(&self) -> Protocol { let data = self.buffer.as_ref(); Protocol::from(data[field::PROTOCOL]) } @@ -392,11 +450,23 @@ impl> Packet { /// # Fuzzing /// This function always returns `true` when fuzzing. pub fn verify_checksum(&self) -> bool { - if cfg!(fuzzing) { return true } + if cfg!(fuzzing) { + return true; + } let data = self.buffer.as_ref(); checksum::data(&data[..self.header_len() as usize]) == !0 } + + /// Returns the key for identifying the packet. + pub fn get_key(&self) -> Key { + Key { + id: self.ident(), + src_addr: self.src_addr(), + dst_addr: self.dst_addr(), + protocol: self.next_header(), + } + } } impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { @@ -493,9 +563,9 @@ impl + AsMut<[u8]>> Packet { data[field::TTL] = value } - /// Set the protocol field. + /// Set the next header (protocol) field. #[inline] - pub fn set_protocol(&mut self, value: Protocol) { + pub fn set_next_header(&mut self, value: Protocol) { let data = self.buffer.as_mut(); data[field::PROTOCOL] = value.into() } @@ -548,48 +618,62 @@ impl> AsRef<[u8]> for Packet { /// A high-level representation of an Internet Protocol version 4 packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Repr { - pub src_addr: Address, - pub dst_addr: Address, - pub protocol: Protocol, + pub src_addr: Address, + pub dst_addr: Address, + pub next_header: Protocol, pub payload_len: usize, - pub hop_limit: u8 + pub hop_limit: u8, } impl Repr { /// Parse an Internet Protocol version 4 packet and return a high-level representation. - pub fn parse + ?Sized>(packet: &Packet<&T>, - checksum_caps: &ChecksumCapabilities) -> Result { + pub fn parse + ?Sized>( + packet: &Packet<&T>, + checksum_caps: &ChecksumCapabilities, + ) -> Result { // Version 4 is expected. - if packet.version() != 4 { return Err(Error::Malformed) } + if packet.version() != 4 { + return Err(Error); + } // Valid checksum is expected. - if checksum_caps.ipv4.rx() && !packet.verify_checksum() { return Err(Error::Checksum) } + if checksum_caps.ipv4.rx() && !packet.verify_checksum() { + return Err(Error); + } + + #[cfg(not(feature = "proto-ipv4-fragmentation"))] // We do not support fragmentation. - if packet.more_frags() || packet.frag_offset() != 0 { return Err(Error::Fragmented) } - // Since the packet is not fragmented, it must include the entire payload. + if packet.more_frags() || packet.frag_offset() != 0 { + return Err(Error); + } + let payload_len = packet.total_len() as usize - packet.header_len() as usize; - if packet.payload().len() < payload_len { return Err(Error::Truncated) } // All DSCP values are acceptable, since they are of no concern to receiving endpoint. // All ECN values are acceptable, since ECN requires opt-in from both endpoints. // All TTL values are acceptable, since we do not perform routing. Ok(Repr { - src_addr: packet.src_addr(), - dst_addr: packet.dst_addr(), - protocol: packet.protocol(), - payload_len: payload_len, - hop_limit: packet.hop_limit() + src_addr: packet.src_addr(), + dst_addr: packet.dst_addr(), + next_header: packet.next_header(), + payload_len, + hop_limit: packet.hop_limit(), }) } /// Return the length of a header that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { // We never emit any options. field::DST_ADDR.end } /// Emit a high-level representation into an Internet Protocol version 4 packet. - pub fn emit + AsMut<[u8]>>(&self, packet: &mut Packet, checksum_caps: &ChecksumCapabilities) { + pub fn emit + AsMut<[u8]>>( + &self, + packet: &mut Packet, + checksum_caps: &ChecksumCapabilities, + ) { packet.set_version(4); packet.set_header_len(field::DST_ADDR.end as u8); packet.set_dscp(0); @@ -602,7 +686,7 @@ impl Repr { packet.set_dont_frag(true); packet.set_frag_offset(0); packet.set_hop_limit(self.hop_limit); - packet.set_protocol(self.protocol); + packet.set_next_header(self.next_header); packet.set_src_addr(self.src_addr); packet.set_dst_addr(self.dst_addr); @@ -619,11 +703,17 @@ impl Repr { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self, &ChecksumCapabilities::ignored()) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "IPv4 ({})", err)?; - write!(f, " src={} dst={} proto={} hop_limit={}", - self.src_addr(), self.dst_addr(), self.protocol(), self.hop_limit())?; + write!(f, "IPv4 ({err})")?; + write!( + f, + " src={} dst={} proto={} hop_limit={}", + self.src_addr(), + self.dst_addr(), + self.next_header(), + self.hop_limit() + )?; if self.version() != 4 { write!(f, " ver={}", self.version())?; } @@ -657,32 +747,47 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "IPv4 src={} dst={} proto={}", - self.src_addr, self.dst_addr, self.protocol) + write!( + f, + "IPv4 src={} dst={} proto={}", + self.src_addr, self.dst_addr, self.next_header + ) } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { - use wire::ip::checksum::format_checksum; + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { + use crate::wire::ip::checksum::format_checksum; let checksum_caps = ChecksumCapabilities::ignored(); let (ip_repr, payload) = match Packet::new_checked(buffer) { - Err(err) => return write!(f, "{}({})", indent, err), - Ok(ip_packet) => { - match Repr::parse(&ip_packet, &checksum_caps) { - Err(_) => return Ok(()), - Ok(ip_repr) => { - write!(f, "{}{}", indent, ip_repr)?; + Err(err) => return write!(f, "{indent}({err})"), + Ok(ip_packet) => match Repr::parse(&ip_packet, &checksum_caps) { + Err(_) => return Ok(()), + Ok(ip_repr) => { + if ip_packet.more_frags() || ip_packet.frag_offset() != 0 { + write!( + f, + "{}IPv4 Fragment more_frags={} offset={}", + indent, + ip_packet.more_frags(), + ip_packet.frag_offset() + )?; + return Ok(()); + } else { + write!(f, "{indent}{ip_repr}")?; format_checksum(f, ip_packet.verify_checksum())?; (ip_repr, ip_packet.payload()) } } - } + }, }; pretty_print_ip_payload(f, indent, ip_repr, payload) @@ -693,20 +798,12 @@ impl> PrettyPrint for Packet { mod test { use super::*; - static PACKET_BYTES: [u8; 30] = - [0x45, 0x00, 0x00, 0x1e, - 0x01, 0x02, 0x62, 0x03, - 0x1a, 0x01, 0xd5, 0x6e, - 0x11, 0x12, 0x13, 0x14, - 0x21, 0x22, 0x23, 0x24, - 0xaa, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0xff]; - - static PAYLOAD_BYTES: [u8; 10] = - [0xaa, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0xff]; + static PACKET_BYTES: [u8; 30] = [ + 0x45, 0x00, 0x00, 0x1e, 0x01, 0x02, 0x62, 0x03, 0x1a, 0x01, 0xd5, 0x6e, 0x11, 0x12, 0x13, + 0x14, 0x21, 0x22, 0x23, 0x24, 0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, + ]; + + static PAYLOAD_BYTES: [u8; 10] = [0xaa, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff]; #[test] fn test_deconstruct() { @@ -717,15 +814,15 @@ mod test { assert_eq!(packet.ecn(), 0); assert_eq!(packet.total_len(), 30); assert_eq!(packet.ident(), 0x102); - assert_eq!(packet.more_frags(), true); - assert_eq!(packet.dont_frag(), true); + assert!(packet.more_frags()); + assert!(packet.dont_frag()); assert_eq!(packet.frag_offset(), 0x203 * 8); assert_eq!(packet.hop_limit(), 0x1a); - assert_eq!(packet.protocol(), Protocol::Icmp); + assert_eq!(packet.next_header(), Protocol::Icmp); assert_eq!(packet.checksum(), 0xd56e); assert_eq!(packet.src_addr(), Address([0x11, 0x12, 0x13, 0x14])); assert_eq!(packet.dst_addr(), Address([0x21, 0x22, 0x23, 0x24])); - assert_eq!(packet.verify_checksum(), true); + assert!(packet.verify_checksum()); assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); } @@ -744,12 +841,12 @@ mod test { packet.set_dont_frag(true); packet.set_frag_offset(0x203 * 8); packet.set_hop_limit(0x1a); - packet.set_protocol(Protocol::Icmp); + packet.set_next_header(Protocol::Icmp); packet.set_src_addr(Address([0x11, 0x12, 0x13, 0x14])); packet.set_dst_addr(Address([0x21, 0x22, 0x23, 0x24])); packet.fill_checksum(); packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } #[test] @@ -758,10 +855,14 @@ mod test { bytes.extend(&PACKET_BYTES[..]); bytes.push(0); - assert_eq!(Packet::new_unchecked(&bytes).payload().len(), - PAYLOAD_BYTES.len()); - assert_eq!(Packet::new_unchecked(&mut bytes).payload_mut().len(), - PAYLOAD_BYTES.len()); + assert_eq!( + Packet::new_unchecked(&bytes).payload().len(), + PAYLOAD_BYTES.len() + ); + assert_eq!( + Packet::new_unchecked(&mut bytes).payload_mut().len(), + PAYLOAD_BYTES.len() + ); } #[test] @@ -770,28 +871,23 @@ mod test { bytes.extend(&PACKET_BYTES[..]); Packet::new_unchecked(&mut bytes).set_total_len(128); - assert_eq!(Packet::new_checked(&bytes).unwrap_err(), - Error::Truncated); + assert_eq!(Packet::new_checked(&bytes).unwrap_err(), Error); } - static REPR_PACKET_BYTES: [u8; 24] = - [0x45, 0x00, 0x00, 0x18, - 0x00, 0x00, 0x40, 0x00, - 0x40, 0x01, 0xd2, 0x79, - 0x11, 0x12, 0x13, 0x14, - 0x21, 0x22, 0x23, 0x24, - 0xaa, 0x00, 0x00, 0xff]; + static REPR_PACKET_BYTES: [u8; 24] = [ + 0x45, 0x00, 0x00, 0x18, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0xd2, 0x79, 0x11, 0x12, 0x13, + 0x14, 0x21, 0x22, 0x23, 0x24, 0xaa, 0x00, 0x00, 0xff, + ]; - static REPR_PAYLOAD_BYTES: [u8; 4] = - [0xaa, 0x00, 0x00, 0xff]; + static REPR_PAYLOAD_BYTES: [u8; ADDR_SIZE] = [0xaa, 0x00, 0x00, 0xff]; - fn packet_repr() -> Repr { + const fn packet_repr() -> Repr { Repr { - src_addr: Address([0x11, 0x12, 0x13, 0x14]), - dst_addr: Address([0x21, 0x22, 0x23, 0x24]), - protocol: Protocol::Icmp, + src_addr: Address([0x11, 0x12, 0x13, 0x14]), + dst_addr: Address([0x21, 0x22, 0x23, 0x24]), + next_header: Protocol::Icmp, payload_len: 4, - hop_limit: 64 + hop_limit: 64, } } @@ -810,14 +906,17 @@ mod test { packet.set_version(6); packet.fill_checksum(); let packet = Packet::new_unchecked(&*packet.into_inner()); - assert_eq!(Repr::parse(&packet, &ChecksumCapabilities::default()), Err(Error::Malformed)); + assert_eq!( + Repr::parse(&packet, &ChecksumCapabilities::default()), + Err(Error) + ); } #[test] fn test_parse_total_len_less_than_header_len() { let mut bytes = vec![0; 40]; bytes[0] = 0x09; - assert_eq!(Packet::new_checked(&mut bytes), Err(Error::Malformed)); + assert_eq!(Packet::new_checked(&mut bytes), Err(Error)); } #[test] @@ -827,7 +926,7 @@ mod test { let mut packet = Packet::new_unchecked(&mut bytes); repr.emit(&mut packet, &ChecksumCapabilities::default()); packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES); - assert_eq!(&packet.into_inner()[..], &REPR_PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &REPR_PACKET_BYTES[..]); } #[test] @@ -853,28 +952,34 @@ mod test { let cidr = Cidr::new(Address::new(192, 168, 1, 10), 24); let inside_subnet = [ - [192, 168, 1, 0], [192, 168, 1, 1], - [192, 168, 1, 2], [192, 168, 1, 10], - [192, 168, 1, 127], [192, 168, 1, 255], + [192, 168, 1, 0], + [192, 168, 1, 1], + [192, 168, 1, 2], + [192, 168, 1, 10], + [192, 168, 1, 127], + [192, 168, 1, 255], ]; let outside_subnet = [ - [192, 168, 0, 0], [127, 0, 0, 1], - [192, 168, 2, 0], [192, 168, 0, 255], - [ 0, 0, 0, 0], [255, 255, 255, 255], + [192, 168, 0, 0], + [127, 0, 0, 1], + [192, 168, 2, 0], + [192, 168, 0, 255], + [0, 0, 0, 0], + [255, 255, 255, 255], ]; let subnets = [ - ([192, 168, 1, 0], 32), - ([192, 168, 1, 255], 24), - ([192, 168, 1, 10], 30), + ([192, 168, 1, 0], 32), + ([192, 168, 1, 255], 24), + ([192, 168, 1, 10], 30), ]; let not_subnets = [ - ([192, 168, 1, 10], 23), - ([127, 0, 0, 1], 8), - ([192, 168, 1, 0], 0), - ([192, 168, 0, 255], 32), + ([192, 168, 1, 10], 23), + ([127, 0, 0, 1], 8), + ([192, 168, 1, 0], 0), + ([192, 168, 0, 255], 32), ]; for addr in inside_subnet.iter().map(|a| Address::from_bytes(a)) { @@ -885,13 +990,17 @@ mod test { assert!(!cidr.contains_addr(&addr)); } - for subnet in subnets.iter().map( - |&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) { + for subnet in subnets + .iter() + .map(|&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) + { assert!(cidr.contains_subnet(&subnet)); } - for subnet in not_subnets.iter().map( - |&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) { + for subnet in not_subnets + .iter() + .map(|&(a, p)| Cidr::new(Address::new(a[0], a[1], a[2], a[3]), p)) + { assert!(!cidr.contains_subnet(&subnet)); } @@ -901,94 +1010,169 @@ mod test { #[test] fn test_cidr_from_netmask() { - assert_eq!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([1, 0, 2, 0])).is_err(), - true); - assert_eq!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([0, 0, 0, 0])).is_err(), - true); - assert_eq!(Cidr::from_netmask(Address([0, 0, 0, 1]), Address([255, 255, 255, 0])).unwrap(), - Cidr::new(Address([0, 0, 0, 1]), 24)); - assert_eq!(Cidr::from_netmask(Address([192, 168, 0, 1]), Address([255, 255, 0, 0])).unwrap(), - Cidr::new(Address([192, 168, 0, 1]), 16)); - assert_eq!(Cidr::from_netmask(Address([172, 16, 0, 1]), Address([255, 240, 0, 0])).unwrap(), - Cidr::new(Address([172, 16, 0, 1]), 12)); - assert_eq!(Cidr::from_netmask(Address([255, 255, 255, 1]), Address([255, 255, 255, 0])).unwrap(), - Cidr::new(Address([255, 255, 255, 1]), 24)); - assert_eq!(Cidr::from_netmask(Address([255, 255, 255, 255]), Address([255, 255, 255, 255])).unwrap(), - Cidr::new(Address([255, 255, 255, 255]), 32)); + assert!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([1, 0, 2, 0])).is_err()); + assert!(Cidr::from_netmask(Address([0, 0, 0, 0]), Address([0, 0, 0, 0])).is_err()); + assert_eq!( + Cidr::from_netmask(Address([0, 0, 0, 1]), Address([255, 255, 255, 0])).unwrap(), + Cidr::new(Address([0, 0, 0, 1]), 24) + ); + assert_eq!( + Cidr::from_netmask(Address([192, 168, 0, 1]), Address([255, 255, 0, 0])).unwrap(), + Cidr::new(Address([192, 168, 0, 1]), 16) + ); + assert_eq!( + Cidr::from_netmask(Address([172, 16, 0, 1]), Address([255, 240, 0, 0])).unwrap(), + Cidr::new(Address([172, 16, 0, 1]), 12) + ); + assert_eq!( + Cidr::from_netmask(Address([255, 255, 255, 1]), Address([255, 255, 255, 0])).unwrap(), + Cidr::new(Address([255, 255, 255, 1]), 24) + ); + assert_eq!( + Cidr::from_netmask(Address([255, 255, 255, 255]), Address([255, 255, 255, 255])) + .unwrap(), + Cidr::new(Address([255, 255, 255, 255]), 32) + ); } #[test] fn test_cidr_netmask() { - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 0).netmask(), - Address([0, 0, 0, 0])); - assert_eq!(Cidr::new(Address([0, 0, 0, 1]), 24).netmask(), - Address([255, 255, 255, 0])); - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 32).netmask(), - Address([255, 255, 255, 255])); - assert_eq!(Cidr::new(Address([127, 0, 0, 0]), 8).netmask(), - Address([255, 0, 0, 0])); - assert_eq!(Cidr::new(Address([192, 168, 0, 0]), 16).netmask(), - Address([255, 255, 0, 0])); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 16).netmask(), - Address([255, 255, 0, 0])); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 17).netmask(), - Address([255, 255, 128, 0])); - assert_eq!(Cidr::new(Address([172, 16, 0, 0]), 12).netmask(), - Address([255, 240, 0, 0])); - assert_eq!(Cidr::new(Address([255, 255, 255, 1]), 24).netmask(), - Address([255, 255, 255, 0])); - assert_eq!(Cidr::new(Address([255, 255, 255, 255]), 32).netmask(), - Address([255, 255, 255, 255])); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).netmask(), + Address([0, 0, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).netmask(), + Address([255, 255, 255, 0]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 32).netmask(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).netmask(), + Address([255, 0, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16).netmask(), + Address([255, 255, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16).netmask(), + Address([255, 255, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17).netmask(), + Address([255, 255, 128, 0]) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 0]), 12).netmask(), + Address([255, 240, 0, 0]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24).netmask(), + Address([255, 255, 255, 0]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).netmask(), + Address([255, 255, 255, 255]) + ); } #[test] fn test_cidr_broadcast() { - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 0).broadcast().unwrap(), - Address([255, 255, 255, 255])); - assert_eq!(Cidr::new(Address([0, 0, 0, 1]), 24).broadcast().unwrap(), - Address([0, 0, 0, 255])); - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 32).broadcast(), - None); - assert_eq!(Cidr::new(Address([127, 0, 0, 0]), 8).broadcast().unwrap(), - Address([127, 255, 255, 255])); - assert_eq!(Cidr::new(Address([192, 168, 0, 0]), 16).broadcast().unwrap(), - Address([192, 168, 255, 255])); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 16).broadcast().unwrap(), - Address([192, 168, 255, 255])); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 17).broadcast().unwrap(), - Address([192, 168, 127, 255])); - assert_eq!(Cidr::new(Address([172, 16, 0, 1]), 12).broadcast().unwrap(), - Address([172, 31, 255, 255])); - assert_eq!(Cidr::new(Address([255, 255, 255, 1]), 24).broadcast().unwrap(), - Address([255, 255, 255, 255])); - assert_eq!(Cidr::new(Address([255, 255, 255, 254]), 31).broadcast(), - None); - assert_eq!(Cidr::new(Address([255, 255, 255, 255]), 32).broadcast(), - None); - + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).broadcast().unwrap(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).broadcast().unwrap(), + Address([0, 0, 0, 255]) + ); + assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 32).broadcast(), None); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).broadcast().unwrap(), + Address([127, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16) + .broadcast() + .unwrap(), + Address([192, 168, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16) + .broadcast() + .unwrap(), + Address([192, 168, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17) + .broadcast() + .unwrap(), + Address([192, 168, 127, 255]) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 1]), 12).broadcast().unwrap(), + Address([172, 31, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24) + .broadcast() + .unwrap(), + Address([255, 255, 255, 255]) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 254]), 31).broadcast(), + None + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).broadcast(), + None + ); } #[test] fn test_cidr_network() { - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 0).network(), - Cidr::new(Address([0, 0, 0, 0]), 0)); - assert_eq!(Cidr::new(Address([0, 0, 0, 1]), 24).network(), - Cidr::new(Address([0, 0, 0, 0]), 24)); - assert_eq!(Cidr::new(Address([0, 0, 0, 0]), 32).network(), - Cidr::new(Address([0, 0, 0, 0]), 32)); - assert_eq!(Cidr::new(Address([127, 0, 0, 0]), 8).network(), - Cidr::new(Address([127, 0, 0, 0]), 8)); - assert_eq!(Cidr::new(Address([192, 168, 0, 0]), 16).network(), - Cidr::new(Address([192, 168, 0, 0]), 16)); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 16).network(), - Cidr::new(Address([192, 168, 0, 0]), 16)); - assert_eq!(Cidr::new(Address([192, 168, 1, 1]), 17).network(), - Cidr::new(Address([192, 168, 0, 0]), 17)); - assert_eq!(Cidr::new(Address([172, 16, 0, 1]), 12).network(), - Cidr::new(Address([172, 16, 0, 0]), 12)); - assert_eq!(Cidr::new(Address([255, 255, 255, 1]), 24).network(), - Cidr::new(Address([255, 255, 255, 0]), 24)); - assert_eq!(Cidr::new(Address([255, 255, 255, 255]), 32).network(), - Cidr::new(Address([255, 255, 255, 255]), 32)); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 0).network(), + Cidr::new(Address([0, 0, 0, 0]), 0) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 1]), 24).network(), + Cidr::new(Address([0, 0, 0, 0]), 24) + ); + assert_eq!( + Cidr::new(Address([0, 0, 0, 0]), 32).network(), + Cidr::new(Address([0, 0, 0, 0]), 32) + ); + assert_eq!( + Cidr::new(Address([127, 0, 0, 0]), 8).network(), + Cidr::new(Address([127, 0, 0, 0]), 8) + ); + assert_eq!( + Cidr::new(Address([192, 168, 0, 0]), 16).network(), + Cidr::new(Address([192, 168, 0, 0]), 16) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 16).network(), + Cidr::new(Address([192, 168, 0, 0]), 16) + ); + assert_eq!( + Cidr::new(Address([192, 168, 1, 1]), 17).network(), + Cidr::new(Address([192, 168, 0, 0]), 17) + ); + assert_eq!( + Cidr::new(Address([172, 16, 0, 1]), 12).network(), + Cidr::new(Address([172, 16, 0, 0]), 12) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 1]), 24).network(), + Cidr::new(Address([255, 255, 255, 0]), 24) + ); + assert_eq!( + Cidr::new(Address([255, 255, 255, 255]), 32).network(), + Cidr::new(Address([255, 255, 255, 255]), 32) + ); } } diff --git a/src/wire/ipv6.rs b/src/wire/ipv6.rs index 065b239f6..d624f246d 100644 --- a/src/wire/ipv6.rs +++ b/src/wire/ipv6.rs @@ -1,61 +1,100 @@ #![deny(missing_docs)] -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; + +use super::{Error, Result}; +use crate::wire::ip::pretty_print_ip_payload; +#[cfg(feature = "proto-ipv4")] +use crate::wire::ipv4; -use {Error, Result}; pub use super::IpProtocol as Protocol; -use super::ip::pretty_print_ip_payload; /// Minimum MTU required of all links supporting IPv6. See [RFC 8200 § 5]. /// /// [RFC 8200 § 5]: https://tools.ietf.org/html/rfc8200#section-5 pub const MIN_MTU: usize = 1280; +/// Size of IPv6 adderess in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc4291#section-2 +pub const ADDR_SIZE: usize = 16; + +/// Size of IPv4-mapping prefix in octets. +/// +/// [RFC 8200 § 2]: https://www.rfc-editor.org/rfc/rfc4291#section-2 +pub const IPV4_MAPPED_PREFIX_SIZE: usize = ADDR_SIZE - 4; // 4 == ipv4::ADDR_SIZE , cannot DRY here because of dependency on a IPv4 module which is behind the feature + /// A sixteen-octet IPv6 address. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] -pub struct Address(pub [u8; 16]); +pub struct Address(pub [u8; ADDR_SIZE]); impl Address { /// The [unspecified address]. /// /// [unspecified address]: https://tools.ietf.org/html/rfc4291#section-2.5.2 - pub const UNSPECIFIED: Address = Address([0x00; 16]); - - /// The link-local [all routers multicast address]. - /// - /// [all routers multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7.1 - pub const LINK_LOCAL_ALL_NODES: Address = - Address([0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]); + pub const UNSPECIFIED: Address = Address([0x00; ADDR_SIZE]); /// The link-local [all nodes multicast address]. /// /// [all nodes multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7.1 - pub const LINK_LOCAL_ALL_ROUTERS: Address = - Address([0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02]); + pub const LINK_LOCAL_ALL_NODES: Address = Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + ]); + + /// The link-local [all routers multicast address]. + /// + /// [all routers multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7.1 + pub const LINK_LOCAL_ALL_ROUTERS: Address = Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, + ]); /// The [loopback address]. /// /// [loopback address]: https://tools.ietf.org/html/rfc4291#section-2.5.3 - pub const LOOPBACK: Address = - Address([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01]); + pub const LOOPBACK: Address = Address([ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + ]); + + /// The prefix used in [IPv4-mapped addresses]. + /// + /// [IPv4-mapped addresses]: https://www.rfc-editor.org/rfc/rfc4291#section-2.5.5.2 + pub const IPV4_MAPPED_PREFIX: [u8; IPV4_MAPPED_PREFIX_SIZE] = + [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff]; /// Construct an IPv6 address from parts. - pub fn new(a0: u16, a1: u16, a2: u16, a3: u16, - a4: u16, a5: u16, a6: u16, a7: u16) -> Address { - let mut addr = [0u8; 16]; - NetworkEndian::write_u16(&mut addr[0..2], a0); - NetworkEndian::write_u16(&mut addr[2..4], a1); - NetworkEndian::write_u16(&mut addr[4..6], a2); - NetworkEndian::write_u16(&mut addr[6..8], a3); - NetworkEndian::write_u16(&mut addr[8..10], a4); - NetworkEndian::write_u16(&mut addr[10..12], a5); - NetworkEndian::write_u16(&mut addr[12..14], a6); - NetworkEndian::write_u16(&mut addr[14..16], a7); - Address(addr) + #[allow(clippy::too_many_arguments)] + pub const fn new( + a0: u16, + a1: u16, + a2: u16, + a3: u16, + a4: u16, + a5: u16, + a6: u16, + a7: u16, + ) -> Address { + Address([ + (a0 >> 8) as u8, + a0 as u8, + (a1 >> 8) as u8, + a1 as u8, + (a2 >> 8) as u8, + a2 as u8, + (a3 >> 8) as u8, + a3 as u8, + (a4 >> 8) as u8, + a4 as u8, + (a5 >> 8) as u8, + a5 as u8, + (a6 >> 8) as u8, + a6 as u8, + (a7 >> 8) as u8, + a7 as u8, + ]) } /// Construct an IPv6 address from a sequence of octets, in big-endian. @@ -63,7 +102,7 @@ impl Address { /// # Panics /// The function panics if `data` is not sixteen octets long. pub fn from_bytes(data: &[u8]) -> Address { - let mut bytes = [0; 16]; + let mut bytes = [0; ADDR_SIZE]; bytes.copy_from_slice(data); Address(bytes) } @@ -74,10 +113,9 @@ impl Address { /// The function panics if `data` is not 8 words long. pub fn from_parts(data: &[u16]) -> Address { assert!(data.len() >= 8); - let mut bytes = [0; 16]; - for word_idx in 0..8 { - let byte_idx = word_idx * 2; - NetworkEndian::write_u16(&mut bytes[byte_idx..(byte_idx + 2)], data[word_idx]); + let mut bytes = [0; ADDR_SIZE]; + for (word_idx, chunk) in bytes.chunks_mut(2).enumerate() { + NetworkEndian::write_u16(chunk, data[word_idx]); } Address(bytes) } @@ -88,14 +126,13 @@ impl Address { /// The function panics if `data` is not 8 words long. pub fn write_parts(&self, data: &mut [u16]) { assert!(data.len() >= 8); - for i in 0..8 { - let byte_idx = i * 2; - data[i] = NetworkEndian::read_u16(&self.0[byte_idx..(byte_idx + 2)]); + for (i, chunk) in self.0.chunks(2).enumerate() { + data[i] = NetworkEndian::read_u16(chunk); } } /// Return an IPv6 address as a sequence of octets, in big-endian. - pub fn as_bytes(&self) -> &[u8] { + pub const fn as_bytes(&self) -> &[u8] { &self.0 } @@ -109,7 +146,7 @@ impl Address { /// Query whether the IPv6 address is a [multicast address]. /// /// [multicast address]: https://tools.ietf.org/html/rfc4291#section-2.7 - pub fn is_multicast(&self) -> bool { + pub const fn is_multicast(&self) -> bool { self.0[0] == 0xff } @@ -117,15 +154,14 @@ impl Address { /// /// [unspecified address]: https://tools.ietf.org/html/rfc4291#section-2.5.2 pub fn is_unspecified(&self) -> bool { - self.0 == [0x00; 16] + self.0 == [0x00; ADDR_SIZE] } /// Query whether the IPv6 address is in the [link-local] scope. /// /// [link-local]: https://tools.ietf.org/html/rfc4291#section-2.5.6 pub fn is_link_local(&self) -> bool { - self.0[0..8] == [0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00] + self.0[0..8] == [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00] } /// Query whether the IPv6 address is the [loopback address]. @@ -139,31 +175,33 @@ impl Address { /// /// [IPv4 mapped IPv6 address]: https://tools.ietf.org/html/rfc4291#section-2.5.5.2 pub fn is_ipv4_mapped(&self) -> bool { - self.0[0..12] == [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff] + self.0[..IPV4_MAPPED_PREFIX_SIZE] == Self::IPV4_MAPPED_PREFIX } #[cfg(feature = "proto-ipv4")] /// Convert an IPv4 mapped IPv6 address to an IPv4 address. - pub fn as_ipv4(&self) -> Option<::wire::ipv4::Address> { + pub fn as_ipv4(&self) -> Option { if self.is_ipv4_mapped() { - Some(::wire::ipv4::Address::new(self.0[12], self.0[13], self.0[14], self.0[15])) + Some(ipv4::Address::from_bytes( + &self.0[IPV4_MAPPED_PREFIX_SIZE..], + )) } else { None } } - /// Helper function used to mask an addres given a prefix. + /// Helper function used to mask an address given a prefix. /// /// # Panics /// This function panics if `mask` is greater than 128. - pub(super) fn mask(&self, mask: u8) -> [u8; 16] { + pub(super) fn mask(&self, mask: u8) -> [u8; ADDR_SIZE] { assert!(mask <= 128); - let mut bytes = [0u8; 16]; + let mut bytes = [0u8; ADDR_SIZE]; let idx = (mask as usize) / 8; let modulus = (mask as usize) % 8; let (first, second) = self.0.split_at(idx); - bytes[0..idx].copy_from_slice(&first); - if idx < 16 { + bytes[0..idx].copy_from_slice(first); + if idx < ADDR_SIZE { let part = second[0]; bytes[idx] = part & (!(0xff >> modulus) as u8); } @@ -177,10 +215,17 @@ impl Address { /// unicast. pub fn solicited_node(&self) -> Address { assert!(self.is_unicast()); - let mut bytes = [0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; - bytes[14..].copy_from_slice(&self.0[14..]); - Address(bytes) + Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xFF, + self.0[13], self.0[14], self.0[15], + ]) + } + + /// Convert to an `IpAddress`. + /// + /// Same as `.into()`, but works in `const`. + pub const fn into_address(self) -> super::IpAddress { + super::IpAddress::Ipv6(self) } } @@ -201,7 +246,14 @@ impl From
for ::std::net::Ipv6Addr { impl fmt::Display for Address { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { if self.is_ipv4_mapped() { - return write!(f, "::ffff:{}.{}.{}.{}", self.0[12], self.0[13], self.0[14], self.0[15]) + return write!( + f, + "::ffff:{}.{}.{}.{}", + self.0[IPV4_MAPPED_PREFIX_SIZE + 0], + self.0[IPV4_MAPPED_PREFIX_SIZE + 1], + self.0[IPV4_MAPPED_PREFIX_SIZE + 2], + self.0[IPV4_MAPPED_PREFIX_SIZE + 3] + ); } // The string representation of an IPv6 address should @@ -214,7 +266,7 @@ impl fmt::Display for Address { Head, HeadBody, Tail, - TailBody + TailBody, } let mut words = [0u16; 8]; self.write_parts(&mut words); @@ -226,24 +278,24 @@ impl fmt::Display for Address { (0, &State::Head) | (0, &State::HeadBody) => { write!(f, "::")?; State::Tail - }, + } // Continue iterating without writing any characters until - // we hit anothing non-zero value. + // we hit a non-zero value. (0, &State::Tail) => State::Tail, // When the state is Head or Tail write a u16 in hexadecimal // without the leading colon if the value is not 0. (_, &State::Head) => { - write!(f, "{:x}", word)?; + write!(f, "{word:x}")?; State::HeadBody - }, + } (_, &State::Tail) => { - write!(f, "{:x}", word)?; + write!(f, "{word:x}")?; State::TailBody - }, + } // Write the u16 with a leading colon when parsing a value // that isn't the first in a section (_, &State::HeadBody) | (_, &State::TailBody) => { - write!(f, ":{:x}", word)?; + write!(f, ":{word:x}")?; state } } @@ -252,13 +304,75 @@ impl fmt::Display for Address { } } +#[cfg(feature = "defmt")] +impl defmt::Format for Address { + fn format(&self, f: defmt::Formatter) { + if self.is_ipv4_mapped() { + return defmt::write!( + f, + "::ffff:{}.{}.{}.{}", + self.0[IPV4_MAPPED_PREFIX_SIZE + 0], + self.0[IPV4_MAPPED_PREFIX_SIZE + 1], + self.0[IPV4_MAPPED_PREFIX_SIZE + 2], + self.0[IPV4_MAPPED_PREFIX_SIZE + 3] + ); + } + + // The string representation of an IPv6 address should + // collapse a series of 16 bit sections that evaluate + // to 0 to "::" + // + // See https://tools.ietf.org/html/rfc4291#section-2.2 + // for details. + enum State { + Head, + HeadBody, + Tail, + TailBody, + } + let mut words = [0u16; 8]; + self.write_parts(&mut words); + let mut state = State::Head; + for word in words.iter() { + state = match (*word, &state) { + // Once a u16 equal to zero write a double colon and + // skip to the next non-zero u16. + (0, &State::Head) | (0, &State::HeadBody) => { + defmt::write!(f, "::"); + State::Tail + } + // Continue iterating without writing any characters until + // we hit a non-zero value. + (0, &State::Tail) => State::Tail, + // When the state is Head or Tail write a u16 in hexadecimal + // without the leading colon if the value is not 0. + (_, &State::Head) => { + defmt::write!(f, "{:x}", word); + State::HeadBody + } + (_, &State::Tail) => { + defmt::write!(f, "{:x}", word); + State::TailBody + } + // Write the u16 with a leading colon when parsing a value + // that isn't the first in a section + (_, &State::HeadBody) | (_, &State::TailBody) => { + defmt::write!(f, ":{:x}", word); + state + } + } + } + } +} + #[cfg(feature = "proto-ipv4")] /// Convert the given IPv4 address into a IPv4-mapped IPv6 address -impl From<::wire::ipv4::Address> for Address { - fn from(address: ::wire::ipv4::Address) -> Self { - let octets = address.0; - Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, - octets[0], octets[1], octets[2], octets[3]]) +impl From for Address { + fn from(address: ipv4::Address) -> Self { + let mut b = [0_u8; ADDR_SIZE]; + b[..Self::IPV4_MAPPED_PREFIX.len()].copy_from_slice(&Self::IPV4_MAPPED_PREFIX); + b[Self::IPV4_MAPPED_PREFIX.len()..].copy_from_slice(&address.0); + Self(b) } } @@ -266,7 +380,7 @@ impl From<::wire::ipv4::Address> for Address { /// subnet masking prefix length. #[derive(Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, Copy, Default)] pub struct Cidr { - address: Address, + address: Address, prefix_len: u8, } @@ -274,29 +388,33 @@ impl Cidr { /// The [solicited node prefix]. /// /// [solicited node prefix]: https://tools.ietf.org/html/rfc4291#section-2.7.1 - pub const SOLICITED_NODE_PREFIX: Cidr = - Cidr { - address: Address([0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, 0xff, 0x00, 0x00, 0x00]), - prefix_len: 104 - }; + pub const SOLICITED_NODE_PREFIX: Cidr = Cidr { + address: Address([ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x00, + 0x00, 0x00, + ]), + prefix_len: 104, + }; /// Create an IPv6 CIDR block from the given address and prefix length. /// /// # Panics /// This function panics if the prefix length is larger than 128. - pub fn new(address: Address, prefix_len: u8) -> Cidr { + pub const fn new(address: Address, prefix_len: u8) -> Cidr { assert!(prefix_len <= 128); - Cidr { address, prefix_len } + Cidr { + address, + prefix_len, + } } /// Return the address of this IPv6 CIDR block. - pub fn address(&self) -> Address { + pub const fn address(&self) -> Address { self.address } /// Return the prefix length of this IPv6 CIDR block. - pub fn prefix_len(&self) -> u8 { + pub const fn prefix_len(&self) -> u8 { self.prefix_len } @@ -304,10 +422,11 @@ impl Cidr { /// the given address. pub fn contains_addr(&self, addr: &Address) -> bool { // right shift by 128 is not legal - if self.prefix_len == 0 { return true } + if self.prefix_len == 0 { + return true; + } - let shift = 128 - self.prefix_len; - self.address.mask(shift) == addr.mask(shift) + self.address.mask(self.prefix_len) == addr.mask(self.prefix_len) } /// Query whether the subnetwork described by this IPV6 CIDR block contains @@ -324,10 +443,18 @@ impl fmt::Display for Cidr { } } +#[cfg(feature = "defmt")] +impl defmt::Format for Cidr { + fn format(&self, f: defmt::Formatter) { + defmt::write!(f, "{}/{=u8}", self.address, self.prefix_len); + } +} + /// A read/write wrapper around an Internet Protocol version 6 packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } // Ranges and constants describing the IPv6 header @@ -356,29 +483,32 @@ pub struct Packet> { // // See https://tools.ietf.org/html/rfc2460#section-3 for details. mod field { - use wire::field::*; + use crate::wire::field::*; // 4-bit version number, 8-bit traffic class, and the // 20-bit flow label. pub const VER_TC_FLOW: Field = 0..4; // 16-bit value representing the length of the payload. // Note: Options are included in this length. - pub const LENGTH: Field = 4..6; + pub const LENGTH: Field = 4..6; // 8-bit value identifying the type of header following this // one. Note: The same numbers are used in IPv4. - pub const NXT_HDR: usize = 6; + pub const NXT_HDR: usize = 6; // 8-bit value decremented by each node that forwards this // packet. The packet is discarded when the value is 0. - pub const HOP_LIMIT: usize = 7; + pub const HOP_LIMIT: usize = 7; // IPv6 address of the source node. - pub const SRC_ADDR: Field = 8..24; + pub const SRC_ADDR: Field = 8..24; // IPv6 address of the destination node. - pub const DST_ADDR: Field = 24..40; + pub const DST_ADDR: Field = 24..40; } +/// Length of an IPv6 header. +pub const HEADER_LEN: usize = field::DST_ADDR.end; + impl> Packet { /// Create a raw octet buffer with an IPv6 packet structure. #[inline] - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -394,7 +524,7 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_payload_len]. /// @@ -403,7 +533,7 @@ impl> Packet { pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::DST_ADDR.end || len < self.total_len() { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -417,7 +547,7 @@ impl> Packet { /// Return the header length. #[inline] - pub fn header_len(&self) -> usize { + pub const fn header_len(&self) -> usize { // This is not a strictly necessary function, but it makes // code more readable. field::DST_ADDR.end @@ -574,9 +704,9 @@ impl + AsMut<[u8]>> Packet { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "IPv6 ({})", err)?; + write!(f, "IPv6 ({err})")?; Ok(()) } } @@ -593,15 +723,15 @@ impl> AsRef<[u8]> for Packet { #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct Repr { /// IPv6 address of the source node. - pub src_addr: Address, + pub src_addr: Address, /// IPv6 address of the destination node. - pub dst_addr: Address, + pub dst_addr: Address, /// Protocol contained in the next header. pub next_header: Protocol, /// Length of the payload including the extension headers. pub payload_len: usize, /// The 8-bit hop limit field. - pub hop_limit: u8 + pub hop_limit: u8, } impl Repr { @@ -609,18 +739,20 @@ impl Repr { pub fn parse + ?Sized>(packet: &Packet<&T>) -> Result { // Ensure basic accessors will work packet.check_len()?; - if packet.version() != 6 { return Err(Error::Malformed); } + if packet.version() != 6 { + return Err(Error); + } Ok(Repr { - src_addr: packet.src_addr(), - dst_addr: packet.dst_addr(), + src_addr: packet.src_addr(), + dst_addr: packet.dst_addr(), next_header: packet.next_header(), payload_len: packet.payload_len() as usize, - hop_limit: packet.hop_limit() + hop_limit: packet.hop_limit(), }) } /// Return the length of a header that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { // This function is not strictly necessary, but it can make client code more readable. field::DST_ADDR.end } @@ -642,29 +774,47 @@ impl Repr { impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "IPv6 src={} dst={} nxt_hdr={} hop_limit={}", - self.src_addr, self.dst_addr, self.next_header, self.hop_limit) + write!( + f, + "IPv6 src={} dst={} nxt_hdr={} hop_limit={}", + self.src_addr, self.dst_addr, self.next_header, self.hop_limit + ) } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "IPv6 src={} dst={} nxt_hdr={} hop_limit={}", + self.src_addr, + self.dst_addr, + self.next_header, + self.hop_limit + ) + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; // TODO: This is very similar to the implementation for IPv4. Make // a way to have less copy and pasted code here. impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { let (ip_repr, payload) = match Packet::new_checked(buffer) { - Err(err) => return write!(f, "{}({})", indent, err), - Ok(ip_packet) => { - match Repr::parse(&ip_packet) { - Err(_) => return Ok(()), - Ok(ip_repr) => { - write!(f, "{}{}", indent, ip_repr)?; - (ip_repr, ip_packet.payload()) - } + Err(err) => return write!(f, "{indent}({err})"), + Ok(ip_packet) => match Repr::parse(&ip_packet) { + Err(_) => return Ok(()), + Ok(ip_repr) => { + write!(f, "{indent}{ip_repr}")?; + (ip_repr, ip_packet.payload()) } - } + }, }; pretty_print_ip_payload(f, indent, ip_repr, payload) @@ -673,18 +823,18 @@ impl> PrettyPrint for Packet { #[cfg(test)] mod test { - use Error; + use super::Error; use super::{Address, Cidr}; use super::{Packet, Protocol, Repr}; - use wire::pretty_print::{PrettyPrinter}; + use crate::wire::pretty_print::PrettyPrinter; #[cfg(feature = "proto-ipv4")] - use wire::ipv4::Address as Ipv4Address; + use crate::wire::ipv4::Address as Ipv4Address; - static LINK_LOCAL_ADDR: Address = Address([0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01]); + static LINK_LOCAL_ADDR: Address = Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x01, + ]); #[test] fn test_basic_multicast() { assert!(!Address::LINK_LOCAL_ALL_ROUTERS.is_unspecified()); @@ -715,48 +865,62 @@ mod test { #[test] fn test_address_format() { - assert_eq!("ff02::1", - format!("{}", Address::LINK_LOCAL_ALL_NODES)); - assert_eq!("fe80::1", - format!("{}", LINK_LOCAL_ADDR)); - assert_eq!("fe80::7f00:0:1", - format!("{}", Address::new(0xfe80, 0, 0, 0, 0, 0x7f00, 0x0000, 0x0001))); - assert_eq!("::", - format!("{}", Address::UNSPECIFIED)); - assert_eq!("::1", - format!("{}", Address::LOOPBACK)); + assert_eq!("ff02::1", format!("{}", Address::LINK_LOCAL_ALL_NODES)); + assert_eq!("fe80::1", format!("{LINK_LOCAL_ADDR}")); + assert_eq!( + "fe80::7f00:0:1", + format!( + "{}", + Address::new(0xfe80, 0, 0, 0, 0, 0x7f00, 0x0000, 0x0001) + ) + ); + assert_eq!("::", format!("{}", Address::UNSPECIFIED)); + assert_eq!("::1", format!("{}", Address::LOOPBACK)); #[cfg(feature = "proto-ipv4")] - assert_eq!("::ffff:192.168.1.1", - format!("{}", Address::from(Ipv4Address::new(192, 168, 1, 1)))); + assert_eq!( + "::ffff:192.168.1.1", + format!("{}", Address::from(Ipv4Address::new(192, 168, 1, 1))) + ); } #[test] fn test_new() { - assert_eq!(Address::new(0xff02, 0, 0, 0, 0, 0, 0, 1), - Address::LINK_LOCAL_ALL_NODES); - assert_eq!(Address::new(0xff02, 0, 0, 0, 0, 0, 0, 2), - Address::LINK_LOCAL_ALL_ROUTERS); - assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 1), - Address::LOOPBACK); - assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 0), - Address::UNSPECIFIED); - assert_eq!(Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), - LINK_LOCAL_ADDR); + assert_eq!( + Address::new(0xff02, 0, 0, 0, 0, 0, 0, 1), + Address::LINK_LOCAL_ALL_NODES + ); + assert_eq!( + Address::new(0xff02, 0, 0, 0, 0, 0, 0, 2), + Address::LINK_LOCAL_ALL_ROUTERS + ); + assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 1), Address::LOOPBACK); + assert_eq!(Address::new(0, 0, 0, 0, 0, 0, 0, 0), Address::UNSPECIFIED); + assert_eq!(Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), LINK_LOCAL_ADDR); } #[test] fn test_from_parts() { - assert_eq!(Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 1]), - Address::LINK_LOCAL_ALL_NODES); - assert_eq!(Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 2]), - Address::LINK_LOCAL_ALL_ROUTERS); - assert_eq!(Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 1]), - Address::LOOPBACK); - assert_eq!(Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 0]), - Address::UNSPECIFIED); - assert_eq!(Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 1]), - LINK_LOCAL_ADDR); + assert_eq!( + Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 1]), + Address::LINK_LOCAL_ALL_NODES + ); + assert_eq!( + Address::from_parts(&[0xff02, 0, 0, 0, 0, 0, 0, 2]), + Address::LINK_LOCAL_ALL_ROUTERS + ); + assert_eq!( + Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 1]), + Address::LOOPBACK + ); + assert_eq!( + Address::from_parts(&[0, 0, 0, 0, 0, 0, 0, 0]), + Address::UNSPECIFIED + ); + assert_eq!( + Address::from_parts(&[0xfe80, 0, 0, 0, 0, 0, 0, 1]), + LINK_LOCAL_ADDR + ); } #[test] @@ -779,18 +943,33 @@ mod test { #[test] fn test_mask() { let addr = Address::new(0x0123, 0x4567, 0x89ab, 0, 0, 0, 0, 1); - assert_eq!(addr.mask(11), [0x01, 0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - assert_eq!(addr.mask(15), [0x01, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - assert_eq!(addr.mask(26), [0x01, 0x23, 0x45, 0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); - assert_eq!(addr.mask(128), [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1]); - assert_eq!(addr.mask(127), [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]); + assert_eq!( + addr.mask(11), + [0x01, 0x20, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(15), + [0x01, 0x22, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(26), + [0x01, 0x23, 0x45, 0x40, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); + assert_eq!( + addr.mask(128), + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1] + ); + assert_eq!( + addr.mask(127), + [0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0] + ); } #[cfg(feature = "proto-ipv4")] #[test] fn test_is_ipv4_mapped() { - assert_eq!(false, Address::UNSPECIFIED.is_ipv4_mapped()); - assert_eq!(true, Address::from(Ipv4Address::new(192, 168, 1, 1)).is_ipv4_mapped()); + assert!(!Address::UNSPECIFIED.is_ipv4_mapped()); + assert!(Address::from(Ipv4Address::new(192, 168, 1, 1)).is_ipv4_mapped()); } #[cfg(feature = "proto-ipv4")] @@ -805,63 +984,129 @@ mod test { #[cfg(feature = "proto-ipv4")] #[test] fn test_from_ipv4_address() { - assert_eq!(Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]), - Address::from(Ipv4Address::new(192, 168, 1, 1))); - assert_eq!(Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 222, 1, 41, 90]), - Address::from(Ipv4Address::new(222, 1, 41, 90))); + assert_eq!( + Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 192, 168, 1, 1]), + Address::from(Ipv4Address::new(192, 168, 1, 1)) + ); + assert_eq!( + Address([0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0xff, 0xff, 222, 1, 41, 90]), + Address::from(Ipv4Address::new(222, 1, 41, 90)) + ); } #[test] fn test_cidr() { - let cidr = Cidr::new(LINK_LOCAL_ADDR, 64); + // fe80::1/56 + // 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + // 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + let cidr = Cidr::new(LINK_LOCAL_ADDR, 56); let inside_subnet = [ - [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02], - [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, 0x77, 0x88], - [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00], - [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff] + // fe80::2 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ], + // fe80::1122:3344:5566:7788 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x11, 0x22, 0x33, 0x44, 0x55, 0x66, + 0x77, 0x88, + ], + // fe80::ff00:0:0:0 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ], + // fe80::ff + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0xff, + ], ]; let outside_subnet = [ - [0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], - [0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], - [0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], - [0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02] + // fe80:0:0:101::1 + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ::1 + [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ff02::1 + [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ], + // ff02::2 + [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x02, + ], ]; let subnets = [ - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], - 65), - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], - 128), - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x12, 0x34, 0x56, 0x78], - 96) + // fe80::ffff:ffff:ffff:ffff/65 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 65, + ), + // fe80::1/128 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ], + 128, + ), + // fe80::1234:5678/96 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x12, + 0x34, 0x56, 0x78, + ], + 96, + ), ]; let not_subnets = [ - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], - 63), - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], - 64), - ([0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff], - 65), - ([0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01], - 128) + // fe80::101:ffff:ffff:ffff:ffff/55 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 55, + ), + // fe80::101:ffff:ffff:ffff:ffff/56 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 56, + ), + // fe80::101:ffff:ffff:ffff:ffff/57 + ( + [ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x01, 0x01, 0xff, 0xff, 0xff, 0xff, 0xff, + 0xff, 0xff, 0xff, + ], + 57, + ), + // ::1/128 + ( + [ + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x01, + ], + 128, + ), ]; for addr in inside_subnet.iter().map(|a| Address::from_bytes(a)) { @@ -872,13 +1117,11 @@ mod test { assert!(!cidr.contains_addr(&addr)); } - for subnet in subnets.iter().map( - |&(a, p)| Cidr::new(Address(a), p)) { + for subnet in subnets.iter().map(|&(a, p)| Cidr::new(Address(a), p)) { assert!(cidr.contains_subnet(&subnet)); } - for subnet in not_subnets.iter().map( - |&(a, p)| Cidr::new(Address(a), p)) { + for subnet in not_subnets.iter().map(|&(a, p)| Cidr::new(Address(a), p)) { assert!(!cidr.contains_subnet(&subnet)); } @@ -887,7 +1130,7 @@ mod test { } #[test] - #[should_panic(expected = "destination and source slices have different lengths")] + #[should_panic(expected = "length")] fn test_from_bytes_too_long() { let _ = Address::from_bytes(&[0u8; 15]); } @@ -898,33 +1141,26 @@ mod test { let _ = Address::from_parts(&[0u16; 7]); } - static REPR_PACKET_BYTES: [u8; 52] = [0x60, 0x00, 0x00, 0x00, - 0x00, 0x0c, 0x11, 0x40, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0x00, 0x01, 0x00, 0x02, - 0x00, 0x0c, 0x02, 0x4e, - 0xff, 0xff, 0xff, 0xff]; - static REPR_PAYLOAD_BYTES: [u8; 12] = [0x00, 0x01, 0x00, 0x02, - 0x00, 0x0c, 0x02, 0x4e, - 0xff, 0xff, 0xff, 0xff]; - - fn packet_repr() -> Repr { + static REPR_PACKET_BYTES: [u8; 52] = [ + 0x60, 0x00, 0x00, 0x00, 0x00, 0x0c, 0x11, 0x40, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x01, 0x00, 0x02, 0x00, + 0x0c, 0x02, 0x4e, 0xff, 0xff, 0xff, 0xff, + ]; + static REPR_PAYLOAD_BYTES: [u8; 12] = [ + 0x00, 0x01, 0x00, 0x02, 0x00, 0x0c, 0x02, 0x4e, 0xff, 0xff, 0xff, 0xff, + ]; + + const fn packet_repr() -> Repr { Repr { - src_addr: Address([0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01]), - dst_addr: Address::LINK_LOCAL_ALL_NODES, + src_addr: Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]), + dst_addr: Address::LINK_LOCAL_ALL_NODES, next_header: Protocol::Udp, payload_len: 12, - hop_limit: 64 + hop_limit: 64, } } @@ -939,10 +1175,13 @@ mod test { assert_eq!(packet.payload_len() as usize, REPR_PAYLOAD_BYTES.len()); assert_eq!(packet.next_header(), Protocol::Udp); assert_eq!(packet.hop_limit(), 0x40); - assert_eq!(packet.src_addr(), Address([0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01])); + assert_eq!( + packet.src_addr(), + Address([ + 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01 + ]) + ); assert_eq!(packet.dst_addr(), Address::LINK_LOCAL_ALL_NODES); assert_eq!(packet.payload(), &REPR_PAYLOAD_BYTES[..]); } @@ -967,20 +1206,19 @@ mod test { packet.set_hop_limit(0xfe); packet.set_src_addr(Address::LINK_LOCAL_ALL_ROUTERS); packet.set_dst_addr(Address::LINK_LOCAL_ALL_NODES); - packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES[..]); + packet + .payload_mut() + .copy_from_slice(&REPR_PAYLOAD_BYTES[..]); let mut expected_bytes = [ - 0x69, 0x95, 0x43, 0x21, 0x00, 0x0c, 0x11, 0xfe, - 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, - 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 + 0x69, 0x95, 0x43, 0x21, 0x00, 0x0c, 0x11, 0xfe, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0x02, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, ]; let start = expected_bytes.len() - REPR_PAYLOAD_BYTES.len(); expected_bytes[start..].copy_from_slice(&REPR_PAYLOAD_BYTES[..]); assert_eq!(packet.check_len(), Ok(())); - assert_eq!(&packet.into_inner()[..], &expected_bytes[..]); + assert_eq!(&*packet.into_inner(), &expected_bytes[..]); } #[test] @@ -989,10 +1227,14 @@ mod test { bytes.extend(&REPR_PACKET_BYTES[..]); bytes.push(0); - assert_eq!(Packet::new_unchecked(&bytes).payload().len(), - REPR_PAYLOAD_BYTES.len()); - assert_eq!(Packet::new_unchecked(&mut bytes).payload_mut().len(), - REPR_PAYLOAD_BYTES.len()); + assert_eq!( + Packet::new_unchecked(&bytes).payload().len(), + REPR_PAYLOAD_BYTES.len() + ); + assert_eq!( + Packet::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PAYLOAD_BYTES.len() + ); } #[test] @@ -1001,8 +1243,7 @@ mod test { bytes.extend(&REPR_PACKET_BYTES[..]); Packet::new_unchecked(&mut bytes).set_payload_len(0x80); - assert_eq!(Packet::new_checked(&bytes).unwrap_err(), - Error::Truncated); + assert_eq!(Packet::new_checked(&bytes).unwrap_err(), Error); } #[test] @@ -1019,7 +1260,7 @@ mod test { packet.set_version(4); packet.set_payload_len(0); let packet = Packet::new_unchecked(&*packet.into_inner()); - assert_eq!(Repr::parse(&packet), Err(Error::Malformed)); + assert_eq!(Repr::parse(&packet), Err(Error)); } #[test] @@ -1029,7 +1270,7 @@ mod test { packet.set_version(6); packet.set_payload_len(39); let packet = Packet::new_unchecked(&*packet.into_inner()); - assert_eq!(Repr::parse(&packet), Err(Error::Truncated)); + assert_eq!(Repr::parse(&packet), Err(Error)); } #[test] @@ -1039,7 +1280,7 @@ mod test { packet.set_version(6); packet.set_payload_len(1); let packet = Packet::new_unchecked(&*packet.into_inner()); - assert_eq!(Repr::parse(&packet), Err(Error::Truncated)); + assert_eq!(Repr::parse(&packet), Err(Error)); } #[test] @@ -1049,12 +1290,17 @@ mod test { let mut packet = Packet::new_unchecked(&mut bytes); repr.emit(&mut packet); packet.payload_mut().copy_from_slice(&REPR_PAYLOAD_BYTES); - assert_eq!(&packet.into_inner()[..], &REPR_PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &REPR_PACKET_BYTES[..]); } #[test] fn test_pretty_print() { - assert_eq!(format!("{}", PrettyPrinter::>::new("\n", &&REPR_PACKET_BYTES[..])), - "\nIPv6 src=fe80::1 dst=ff02::1 nxt_hdr=UDP hop_limit=64\n \\ UDP src=1 dst=2 len=4"); + assert_eq!( + format!( + "{}", + PrettyPrinter::>::new("\n", &&REPR_PACKET_BYTES[..]) + ), + "\nIPv6 src=fe80::1 dst=ff02::1 nxt_hdr=UDP hop_limit=64\n \\ UDP src=1 dst=2 len=4" + ); } } diff --git a/src/wire/ipv6ext_header.rs b/src/wire/ipv6ext_header.rs new file mode 100644 index 000000000..bc8ef879f --- /dev/null +++ b/src/wire/ipv6ext_header.rs @@ -0,0 +1,305 @@ +#![allow(unused)] + +use super::IpProtocol; +use super::{Error, Result}; + +mod field { + #![allow(non_snake_case)] + + use crate::wire::field::*; + + pub const MIN_HEADER_SIZE: usize = 8; + + pub const NXT_HDR: usize = 0; + pub const LENGTH: usize = 1; + // Variable-length field. + // + // Length of the header is in 8-octet units, not including the first 8 octets. + // The first two octets are the next header type and the header length. + pub const fn PAYLOAD(length_field: u8) -> Field { + let bytes = length_field as usize * 8 + 8; + 2..bytes + } +} + +/// A read/write wrapper around an IPv6 Extension Header buffer. +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Header> { + buffer: T, +} + +/// Core getter methods relevant to any IPv6 extension header. +impl> Header { + /// Create a raw octet buffer with an IPv6 Extension Header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Header { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result { + let header = Self::new_unchecked(buffer); + header.check_len()?; + Ok(header) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + /// + /// The result of this check is invalidated by calling [set_header_len]. + /// + /// [set_header_len]: #method.set_header_len + pub fn check_len(&self) -> Result<()> { + let data = self.buffer.as_ref(); + + let len = data.len(); + if len < field::MIN_HEADER_SIZE { + return Err(Error); + } + + let of = field::PAYLOAD(data[field::LENGTH]); + if len < of.end { + return Err(Error); + } + + Ok(()) + } + + /// Consume the header, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the next header field. + pub fn next_header(&self) -> IpProtocol { + let data = self.buffer.as_ref(); + IpProtocol::from(data[field::NXT_HDR]) + } + + /// Return the header length field. + pub fn header_len(&self) -> u8 { + let data = self.buffer.as_ref(); + data[field::LENGTH] + } +} + +impl<'h, T: AsRef<[u8]> + ?Sized> Header<&'h T> { + /// Return the payload of the IPv6 extension header. + pub fn payload(&self) -> &'h [u8] { + let data = self.buffer.as_ref(); + &data[field::PAYLOAD(data[field::LENGTH])] + } +} + +impl + AsMut<[u8]>> Header { + /// Set the next header field. + #[inline] + pub fn set_next_header(&mut self, value: IpProtocol) { + let data = self.buffer.as_mut(); + data[field::NXT_HDR] = value.into(); + } + + /// Set the extension header data length. The length of the header is + /// in 8-octet units, not including the first 8 octets. + #[inline] + pub fn set_header_len(&mut self, value: u8) { + let data = self.buffer.as_mut(); + data[field::LENGTH] = value; + } +} + +impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Header<&'a mut T> { + /// Return a mutable pointer to the payload data. + #[inline] + pub fn payload_mut(&mut self) -> &mut [u8] { + let data = self.buffer.as_mut(); + let len = data[field::LENGTH]; + &mut data[field::PAYLOAD(len)] + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Repr<'a> { + pub next_header: IpProtocol, + pub length: u8, + pub data: &'a [u8], +} + +impl<'a> Repr<'a> { + /// Parse an IPv6 Extension Header Header and return a high-level representation. + pub fn parse(header: &Header<&'a T>) -> Result + where + T: AsRef<[u8]> + ?Sized, + { + Ok(Self { + next_header: header.next_header(), + length: header.header_len(), + data: header.payload(), + }) + } + + /// Return the length, in bytes, of a header that will be emitted from this high-level + /// representation. + pub const fn header_len(&self) -> usize { + 2 + } + + /// Emit a high-level representation into an IPv6 Extension Header. + pub fn emit + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { + header.set_next_header(self.next_header); + header.set_header_len(self.length); + } +} + +#[cfg(test)] +mod test { + use super::*; + + // A Hop-by-Hop Option header with a PadN option of option data length 4. + static REPR_PACKET_PAD4: [u8; 8] = [0x6, 0x0, 0x1, 0x4, 0x0, 0x0, 0x0, 0x0]; + + // A Hop-by-Hop Option header with a PadN option of option data length 12. + static REPR_PACKET_PAD12: [u8; 16] = [ + 0x06, 0x1, 0x1, 0x0C, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; + + #[test] + fn test_check_len() { + // zero byte buffer + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..0]).check_len() + ); + // no length field + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..1]).check_len() + ); + // less than 8 bytes + assert_eq!( + Err(Error), + Header::new_unchecked(&REPR_PACKET_PAD4[..7]).check_len() + ); + // valid + assert_eq!(Ok(()), Header::new_unchecked(&REPR_PACKET_PAD4).check_len()); + // valid + assert_eq!( + Ok(()), + Header::new_unchecked(&REPR_PACKET_PAD12).check_len() + ); + // length field value greater than number of bytes + let header: [u8; 8] = [0x06, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0]; + assert_eq!(Err(Error), Header::new_unchecked(&header).check_len()); + } + + #[test] + fn test_header_deconstruct() { + let header = Header::new_unchecked(&REPR_PACKET_PAD4); + assert_eq!(header.next_header(), IpProtocol::Tcp); + assert_eq!(header.header_len(), 0); + assert_eq!(header.payload(), &REPR_PACKET_PAD4[2..]); + + let header = Header::new_unchecked(&REPR_PACKET_PAD12); + assert_eq!(header.next_header(), IpProtocol::Tcp); + assert_eq!(header.header_len(), 1); + assert_eq!(header.payload(), &REPR_PACKET_PAD12[2..]); + } + + #[test] + fn test_overlong() { + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_PAD4[..]); + bytes.push(0); + + assert_eq!( + Header::new_unchecked(&bytes).payload().len(), + REPR_PACKET_PAD4[2..].len() + ); + assert_eq!( + Header::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PACKET_PAD4[2..].len() + ); + + let mut bytes = vec![]; + bytes.extend(&REPR_PACKET_PAD12[..]); + bytes.push(0); + + assert_eq!( + Header::new_unchecked(&bytes).payload().len(), + REPR_PACKET_PAD12[2..].len() + ); + assert_eq!( + Header::new_unchecked(&mut bytes).payload_mut().len(), + REPR_PACKET_PAD12[2..].len() + ); + } + + #[test] + fn test_header_len_overflow() { + let mut bytes = vec![]; + bytes.extend(REPR_PACKET_PAD4); + let len = bytes.len() as u8; + Header::new_unchecked(&mut bytes).set_header_len(len + 1); + + assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error); + + let mut bytes = vec![]; + bytes.extend(REPR_PACKET_PAD12); + let len = bytes.len() as u8; + Header::new_unchecked(&mut bytes).set_header_len(len + 1); + + assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error); + } + + #[test] + fn test_repr_parse_valid() { + let header = Header::new_unchecked(&REPR_PACKET_PAD4); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + next_header: IpProtocol::Tcp, + length: 0, + data: &REPR_PACKET_PAD4[2..] + } + ); + + let header = Header::new_unchecked(&REPR_PACKET_PAD12); + let repr = Repr::parse(&header).unwrap(); + assert_eq!( + repr, + Repr { + next_header: IpProtocol::Tcp, + length: 1, + data: &REPR_PACKET_PAD12[2..] + } + ); + } + + #[test] + fn test_repr_emit() { + let repr = Repr { + next_header: IpProtocol::Tcp, + length: 0, + data: &REPR_PACKET_PAD4[2..], + }; + let mut bytes = [0u8; 2]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &REPR_PACKET_PAD4[..2]); + + let repr = Repr { + next_header: IpProtocol::Tcp, + length: 1, + data: &REPR_PACKET_PAD12[2..], + }; + let mut bytes = [0u8; 2]; + let mut header = Header::new_unchecked(&mut bytes); + repr.emit(&mut header); + assert_eq!(header.into_inner(), &REPR_PACKET_PAD12[..2]); + } +} diff --git a/src/wire/ipv6fragment.rs b/src/wire/ipv6fragment.rs index b5b641fe2..46246deb1 100644 --- a/src/wire/ipv6fragment.rs +++ b/src/wire/ipv6fragment.rs @@ -1,14 +1,15 @@ +use super::{Error, Result}; use core::fmt; -use {Error, Result}; use byteorder::{ByteOrder, NetworkEndian}; pub use super::IpProtocol as Protocol; /// A read/write wrapper around an IPv6 Fragment Header. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Header> { - buffer: T + buffer: T, } // Format of the Fragment Header @@ -20,22 +21,22 @@ pub struct Header> { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // // See https://tools.ietf.org/html/rfc8200#section-4.5 for details. +// +// **NOTE**: The fields start counting after the header length field. mod field { - use wire::field::*; + use crate::wire::field::*; - // 8-bit identifier of the header immediately following this header. - pub const NXT_HDR: usize = 0; - // 8-bit reserved field. - pub const RESERVED: usize = 1; // 16-bit field containing the fragment offset, reserved and more fragments values. - pub const FR_OF_M: Field = 2..4; + pub const FR_OF_M: Field = 0..2; // 32-bit field identifying the fragmented packet - pub const IDENT: Field = 4..8; + pub const IDENT: Field = 2..6; + /// 1 bit flag indicating if there are more fragments coming. + pub const M: usize = 1; } impl> Header { /// Create a raw octet buffer with an IPv6 Fragment Header structure. - pub fn new_unchecked(buffer: T) -> Header { + pub const fn new_unchecked(buffer: T) -> Header { Header { buffer } } @@ -50,13 +51,13 @@ impl> Header { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. pub fn check_len(&self) -> Result<()> { let data = self.buffer.as_ref(); let len = data.len(); if len < field::IDENT.end { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -67,13 +68,6 @@ impl> Header { self.buffer } - /// Return the next header field. - #[inline] - pub fn next_header(&self) -> Protocol { - let data = self.buffer.as_ref(); - Protocol::from(data[field::NXT_HDR]) - } - /// Return the fragment offset field. #[inline] pub fn frag_offset(&self) -> u16 { @@ -85,7 +79,7 @@ impl> Header { #[inline] pub fn more_frags(&self) -> bool { let data = self.buffer.as_ref(); - (data[3] & 0x1) == 1 + (data[field::M] & 0x1) == 1 } /// Return the fragment identification value field. @@ -97,13 +91,6 @@ impl> Header { } impl + AsMut<[u8]>> Header { - /// Set the next header field. - #[inline] - pub fn set_next_header(&mut self, value: Protocol) { - let data = self.buffer.as_mut(); - data[field::NXT_HDR] = value.into(); - } - /// Set reserved fields. /// /// Set 8-bit reserved field after the next header field. @@ -111,11 +98,8 @@ impl + AsMut<[u8]>> Header { #[inline] pub fn clear_reserved(&mut self) { let data = self.buffer.as_mut(); - - data[field::RESERVED] = 0; - // Retain the higher order 5 bits and lower order 1 bit - data[3] = data[3] & 0xf9; + data[field::M] &= 0xf9; } /// Set the fragment offset field. @@ -123,7 +107,7 @@ impl + AsMut<[u8]>> Header { pub fn set_frag_offset(&mut self, value: u16) { let data = self.buffer.as_mut(); // Retain the lower order 3 bits - let raw = ((value & 0x1fff) << 3) | ((data[3] & 0x7) as u16); + let raw = ((value & 0x1fff) << 3) | ((data[field::M] & 0x7) as u16); NetworkEndian::write_u16(&mut data[field::FR_OF_M], raw); } @@ -132,8 +116,8 @@ impl + AsMut<[u8]>> Header { pub fn set_more_frags(&mut self, value: bool) { let data = self.buffer.as_mut(); // Retain the high order 7 bits - let raw = (data[3] & 0xfe) | (value as u8 & 0x1); - data[3] = raw; + let raw = (data[field::M] & 0xfe) | (value as u8 & 0x1); + data[field::M] = raw; } /// Set the fragmentation identification field. @@ -147,9 +131,9 @@ impl + AsMut<[u8]>> Header { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "IPv6 Fragment ({})", err)?; + write!(f, "IPv6 Fragment ({err})")?; Ok(()) } } @@ -158,39 +142,38 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { /// A high-level representation of an IPv6 Fragment header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Repr { - /// The type of header immediately following the Fragment header. - pub next_header: Protocol, /// The offset of the data following this header, relative to the start of the Fragmentable /// Part of the original packet. pub frag_offset: u16, - /// Whethere are not there are more fragments following this header + /// When there are more fragments following this header pub more_frags: bool, /// The identification for every packet that is fragmented. pub ident: u32, - } impl Repr { /// Parse an IPv6 Fragment Header and return a high-level representation. - pub fn parse(header: &Header<&T>) -> Result where T: AsRef<[u8]> + ?Sized { + pub fn parse(header: &Header<&T>) -> Result + where + T: AsRef<[u8]> + ?Sized, + { Ok(Repr { - next_header: header.next_header(), frag_offset: header.frag_offset(), more_frags: header.more_frags(), - ident: header.ident() + ident: header.ident(), }) } /// Return the length, in bytes, of a header that will be emitted from this high-level /// representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { field::IDENT.end } /// Emit a high-level representation into an IPv6 Fragment Header. pub fn emit + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { - header.set_next_header(self.next_header); header.clear_reserved(); header.set_frag_offset(self.frag_offset); header.set_more_frags(self.more_frags); @@ -198,10 +181,13 @@ impl Repr { } } -impl<'a> fmt::Display for Repr { +impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "IPv6 Fragment next_hdr={} offset={} more={} ident={}", - self.next_header, self.frag_offset, self.more_frags, self.ident) + write!( + f, + "IPv6 Fragment offset={} more={} ident={}", + self.frag_offset, self.more_frags, self.ident + ) } } @@ -210,35 +196,35 @@ mod test { use super::*; // A Fragment Header with more fragments remaining - static BYTES_HEADER_MORE_FRAG: [u8; 8] = [0x6, 0x0, 0x0, 0x1, - 0x0, 0x0, 0x30, 0x39]; + static BYTES_HEADER_MORE_FRAG: [u8; 6] = [0x0, 0x1, 0x0, 0x0, 0x30, 0x39]; // A Fragment Header with no more fragments remaining - static BYTES_HEADER_LAST_FRAG: [u8; 8] = [0x6, 0x0, 0xa, 0x0, - 0x0, 0x1, 0x9, 0x32]; + static BYTES_HEADER_LAST_FRAG: [u8; 6] = [0xa, 0x0, 0x0, 0x1, 0x9, 0x32]; #[test] fn test_check_len() { - // less than 8 bytes - assert_eq!(Err(Error::Truncated), - Header::new_unchecked(&BYTES_HEADER_MORE_FRAG[..7]).check_len()); + // less than 6 bytes + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_HEADER_MORE_FRAG[..5]).check_len() + ); // valid - assert_eq!(Ok(()), - Header::new_unchecked(&BYTES_HEADER_MORE_FRAG).check_len()); + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_HEADER_MORE_FRAG).check_len() + ); } #[test] fn test_header_deconstruct() { let header = Header::new_unchecked(&BYTES_HEADER_MORE_FRAG); - assert_eq!(header.next_header(), Protocol::Tcp); assert_eq!(header.frag_offset(), 0); - assert_eq!(header.more_frags(), true); + assert!(header.more_frags()); assert_eq!(header.ident(), 12345); let header = Header::new_unchecked(&BYTES_HEADER_LAST_FRAG); - assert_eq!(header.next_header(), Protocol::Tcp); assert_eq!(header.frag_offset(), 320); - assert_eq!(header.more_frags(), false); + assert!(!header.more_frags()); assert_eq!(header.ident(), 67890); } @@ -246,28 +232,48 @@ mod test { fn test_repr_parse_valid() { let header = Header::new_unchecked(&BYTES_HEADER_MORE_FRAG); let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr, - Repr{ next_header: Protocol::Tcp, frag_offset: 0, more_frags: true, ident: 12345 }); + assert_eq!( + repr, + Repr { + frag_offset: 0, + more_frags: true, + ident: 12345 + } + ); let header = Header::new_unchecked(&BYTES_HEADER_LAST_FRAG); let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr, - Repr{ next_header: Protocol::Tcp, frag_offset: 320, more_frags: false, ident: 67890 }); + assert_eq!( + repr, + Repr { + frag_offset: 320, + more_frags: false, + ident: 67890 + } + ); } #[test] fn test_repr_emit() { - let repr = Repr{ next_header: Protocol::Tcp, frag_offset: 0, more_frags: true, ident: 12345 }; - let mut bytes = [0u8; 8]; + let repr = Repr { + frag_offset: 0, + more_frags: true, + ident: 12345, + }; + let mut bytes = [0u8; 6]; let mut header = Header::new_unchecked(&mut bytes); repr.emit(&mut header); - assert_eq!(header.into_inner(), &BYTES_HEADER_MORE_FRAG[0..8]); - - let repr = Repr{ next_header: Protocol::Tcp, frag_offset: 320, more_frags: false, ident: 67890 }; - let mut bytes = [0u8; 8]; + assert_eq!(header.into_inner(), &BYTES_HEADER_MORE_FRAG[0..6]); + + let repr = Repr { + frag_offset: 320, + more_frags: false, + ident: 67890, + }; + let mut bytes = [0u8; 6]; let mut header = Header::new_unchecked(&mut bytes); repr.emit(&mut header); - assert_eq!(header.into_inner(), &BYTES_HEADER_LAST_FRAG[0..8]); + assert_eq!(header.into_inner(), &BYTES_HEADER_LAST_FRAG[0..6]); } #[test] diff --git a/src/wire/ipv6hopbyhop.rs b/src/wire/ipv6hopbyhop.rs deleted file mode 100644 index 3eb115fe9..000000000 --- a/src/wire/ipv6hopbyhop.rs +++ /dev/null @@ -1,332 +0,0 @@ -use core::fmt; -use {Error, Result}; - -use super::ipv6option::Ipv6OptionsIterator; -pub use super::IpProtocol as Protocol; - -/// A read/write wrapper around an IPv6 Hop-by-Hop Options Header. -#[derive(Debug, PartialEq)] -pub struct Header> { - buffer: T -} - -// Format of the Hop-by-Hop Options Header -// -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// | Next Header | Hdr Ext Len | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + -// | | -// . . -// . Options . -// . . -// | | -// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ -// -// -// See https://tools.ietf.org/html/rfc8200#section-4.3 for details. -mod field { - #![allow(non_snake_case)] - - use wire::field::*; - - // Minimum size of the header. - pub const MIN_HEADER_SIZE: usize = 8; - - // 8-bit identifier of the header immediately following this header. - pub const NXT_HDR: usize = 0; - // 8-bit unsigned integer. Length of the OPTIONS field in 8-octet units, - // not including the first 8 octets. - pub const LENGTH: usize = 1; - // Variable-length field. Option-Type-specific data. - // - // Length of the header is in 8-octet units, not including the first 8 octets. The first two - // octets are the next header type and the header length. - pub fn OPTIONS(length_field: u8) -> Field { - let bytes = length_field * 8 + 8; - 2..bytes as usize - } -} - -impl> Header { - /// Create a raw octet buffer with an IPv6 Hop-by-Hop Options Header structure. - pub fn new_unchecked(buffer: T) -> Header { - Header { buffer } - } - - /// Shorthand for a combination of [new_unchecked] and [check_len]. - /// - /// [new_unchecked]: #method.new_unchecked - /// [check_len]: #method.check_len - pub fn new_checked(buffer: T) -> Result> { - let header = Self::new_unchecked(buffer); - header.check_len()?; - Ok(header) - } - - /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. - /// - /// The result of this check is invalidated by calling [set_header_len]. - /// - /// [set_header_len]: #method.set_header_len - pub fn check_len(&self) -> Result<()> { - let data = self.buffer.as_ref(); - let len = data.len(); - - if len < field::MIN_HEADER_SIZE { - return Err(Error::Truncated); - } - - let of = field::OPTIONS(data[field::LENGTH]); - - if len < of.end { - return Err(Error::Truncated); - } - - Ok(()) - } - - /// Consume the header, returning the underlying buffer. - pub fn into_inner(self) -> T { - self.buffer - } - - /// Return the next header field. - #[inline] - pub fn next_header(&self) -> Protocol { - let data = self.buffer.as_ref(); - Protocol::from(data[field::NXT_HDR]) - } - - /// Return length of the Hop-by-Hop Options header in 8-octet units, not including the first - /// 8 octets. - #[inline] - pub fn header_len(&self) -> u8 { - let data = self.buffer.as_ref(); - data[field::LENGTH] - } -} - -impl<'a, T: AsRef<[u8]> + ?Sized> Header<&'a T> { - /// Return the option data. - #[inline] - pub fn options(&self) -> &'a[u8] { - let data = self.buffer.as_ref(); - &data[field::OPTIONS(data[field::LENGTH])] - } -} - -impl + AsMut<[u8]>> Header { - /// Set the next header field. - #[inline] - pub fn set_next_header(&mut self, value: Protocol) { - let data = self.buffer.as_mut(); - data[field::NXT_HDR] = value.into(); - } - - /// Set the option data length. Length of the Hop-by-Hop Options header in 8-octet units, - /// not including the first 8 octets. - #[inline] - pub fn set_header_len(&mut self, value: u8) { - let data = self.buffer.as_mut(); - data[field::LENGTH] = value; - } -} - -impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Header<&'a mut T> { - /// Return a mutable pointer to the option data. - #[inline] - pub fn options_mut(&mut self) -> &mut [u8] { - let data = self.buffer.as_mut(); - let len = data[field::LENGTH]; - &mut data[field::OPTIONS(len)] - } -} - -impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), - Err(err) => { - write!(f, "IPv6 Hop-by-Hop Options ({})", err)?; - Ok(()) - } - } - } -} - -/// A high-level representation of an IPv6 Hop-by-Hop Options header. -#[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct Repr<'a> { - /// The type of header immediately following the Hop-by-Hop Options header. - pub next_header: Protocol, - /// Length of the Hop-by-Hop Options header in 8-octet units, not including the first 8 octets. - pub length: u8, - /// The options contained in the Hop-by-Hop Options header. - pub options: &'a [u8] -} - -impl<'a> Repr<'a> { - /// Parse an IPv6 Hop-by-Hop Options Header and return a high-level representation. - pub fn parse(header: &Header<&'a T>) -> Result> where T: AsRef<[u8]> + ?Sized { - Ok(Repr { - next_header: header.next_header(), - length: header.header_len(), - options: header.options() - }) - } - - /// Return the length, in bytes, of a header that will be emitted from this high-level - /// representation. - pub fn buffer_len(&self) -> usize { - field::OPTIONS(self.length).end - } - - /// Emit a high-level representation into an IPv6 Hop-by-Hop Options Header. - pub fn emit + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { - header.set_next_header(self.next_header); - header.set_header_len(self.length); - header.options_mut().copy_from_slice(self.options); - } - - /// Return an `Iterator` for the contained options. - pub fn options(&self) -> Ipv6OptionsIterator { - Ipv6OptionsIterator::new(self.options, self.buffer_len() - 2) - } -} - -impl<'a> fmt::Display for Repr<'a> { - fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "IPv6 Hop-by-Hop Options next_hdr={} length={} ", self.next_header, self.length) - } -} - -#[cfg(test)] -mod test { - use super::*; - - // A Hop-by-Hop Option header with a PadN option of option data length 4. - static REPR_PACKET_PAD4: [u8; 8] = [0x6, 0x0, 0x1, 0x4, - 0x0, 0x0, 0x0, 0x0]; - - // A Hop-by-Hop Option header with a PadN option of option data length 12. - static REPR_PACKET_PAD12: [u8; 16] = [0x06, 0x1, 0x1, 0x12, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0]; - - #[test] - fn test_check_len() { - // zero byte buffer - assert_eq!(Err(Error::Truncated), - Header::new_unchecked(&REPR_PACKET_PAD4[..0]).check_len()); - // no length field - assert_eq!(Err(Error::Truncated), - Header::new_unchecked(&REPR_PACKET_PAD4[..1]).check_len()); - // less than 8 bytes - assert_eq!(Err(Error::Truncated), - Header::new_unchecked(&REPR_PACKET_PAD4[..7]).check_len()); - // valid - assert_eq!(Ok(()), - Header::new_unchecked(&REPR_PACKET_PAD4).check_len()); - // valid - assert_eq!(Ok(()), - Header::new_unchecked(&REPR_PACKET_PAD12).check_len()); - // length field value greater than number of bytes - let header: [u8; 8] = [0x06, 0x2, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0]; - assert_eq!(Err(Error::Truncated), - Header::new_unchecked(&header).check_len()); - } - - #[test] - fn test_header_deconstruct() { - let header = Header::new_unchecked(&REPR_PACKET_PAD4); - assert_eq!(header.next_header(), Protocol::Tcp); - assert_eq!(header.header_len(), 0); - assert_eq!(header.options(), &REPR_PACKET_PAD4[2..]); - - let header = Header::new_unchecked(&REPR_PACKET_PAD12); - assert_eq!(header.next_header(), Protocol::Tcp); - assert_eq!(header.header_len(), 1); - assert_eq!(header.options(), &REPR_PACKET_PAD12[2..]); - } - - #[test] - fn test_overlong() { - let mut bytes = vec![]; - bytes.extend(&REPR_PACKET_PAD4[..]); - bytes.push(0); - - assert_eq!(Header::new_unchecked(&bytes).options().len(), - REPR_PACKET_PAD4[2..].len()); - assert_eq!(Header::new_unchecked(&mut bytes).options_mut().len(), - REPR_PACKET_PAD4[2..].len()); - - let mut bytes = vec![]; - bytes.extend(&REPR_PACKET_PAD12[..]); - bytes.push(0); - - assert_eq!(Header::new_unchecked(&bytes).options().len(), - REPR_PACKET_PAD12[2..].len()); - assert_eq!(Header::new_unchecked(&mut bytes).options_mut().len(), - REPR_PACKET_PAD12[2..].len()); - } - - #[test] - fn test_header_len_overflow() { - let mut bytes = vec![]; - bytes.extend(&REPR_PACKET_PAD4); - let len = bytes.len() as u8; - Header::new_unchecked(&mut bytes).set_header_len(len + 1); - - assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error::Truncated); - - let mut bytes = vec![]; - bytes.extend(&REPR_PACKET_PAD12); - let len = bytes.len() as u8; - Header::new_unchecked(&mut bytes).set_header_len(len + 1); - - assert_eq!(Header::new_checked(&bytes).unwrap_err(), Error::Truncated); - } - - #[test] - fn test_repr_parse_valid() { - let header = Header::new_unchecked(&REPR_PACKET_PAD4); - let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr, Repr { - next_header: Protocol::Tcp, length: 0, options: &REPR_PACKET_PAD4[2..] - }); - - let header = Header::new_unchecked(&REPR_PACKET_PAD12); - let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr, Repr { - next_header: Protocol::Tcp, length: 1, options: &REPR_PACKET_PAD12[2..] - }); - } - - #[test] - fn test_repr_emit() { - let repr = Repr{ next_header: Protocol::Tcp, length: 0, options: &REPR_PACKET_PAD4[2..] }; - let mut bytes = [0u8; 8]; - let mut header = Header::new_unchecked(&mut bytes); - repr.emit(&mut header); - assert_eq!(header.into_inner(), &REPR_PACKET_PAD4[..]); - - let repr = Repr{ next_header: Protocol::Tcp, length: 1, options: &REPR_PACKET_PAD12[2..] }; - let mut bytes = [0u8; 16]; - let mut header = Header::new_unchecked(&mut bytes); - repr.emit(&mut header); - assert_eq!(header.into_inner(), &REPR_PACKET_PAD12[..]); - } - - #[test] - fn test_buffer_len() { - let header = Header::new_unchecked(&REPR_PACKET_PAD4); - let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr.buffer_len(), REPR_PACKET_PAD4.len()); - - let header = Header::new_unchecked(&REPR_PACKET_PAD12); - let repr = Repr::parse(&header).unwrap(); - assert_eq!(repr.buffer_len(), REPR_PACKET_PAD12.len()); - } -} diff --git a/src/wire/ipv6option.rs b/src/wire/ipv6option.rs index 26ba587b2..dfbd6acad 100644 --- a/src/wire/ipv6option.rs +++ b/src/wire/ipv6option.rs @@ -1,22 +1,28 @@ +use super::{Error, Result}; +#[cfg(feature = "proto-rpl")] +use super::{RplHopByHopPacket, RplHopByHopRepr}; + use core::fmt; -use {Error, Result}; enum_with_unknown! { /// IPv6 Extension Header Option Type - pub doc enum Type(u8) { + pub enum Type(u8) { /// 1 byte of padding - Pad1 = 0, + Pad1 = 0, /// Multiple bytes of padding - PadN = 1 + PadN = 1, + /// RPL Option + Rpl = 0x63, } } impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Type::Pad1 => write!(f, "Pad1"), - &Type::PadN => write!(f, "PadN"), - &Type::Unknown(id) => write!(f, "{}", id) + match *self { + Type::Pad1 => write!(f, "Pad1"), + Type::PadN => write!(f, "PadN"), + Type::Rpl => write!(f, "RPL"), + Type::Unknown(id) => write!(f, "{id}"), } } } @@ -24,7 +30,7 @@ impl fmt::Display for Type { enum_with_unknown! { /// Action required when parsing the given IPv6 Extension /// Header Option Type fails - pub doc enum FailureType(u8) { + pub enum FailureType(u8) { /// Skip this option and continue processing the packet Skip = 0b00000000, /// Discard the containing packet @@ -39,12 +45,12 @@ enum_with_unknown! { impl fmt::Display for FailureType { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &FailureType::Skip => write!(f, "skip"), - &FailureType::Discard => write!(f, "discard"), - &FailureType::DiscardSendAll => write!(f, "discard and send error"), - &FailureType::DiscardSendUnicast => write!(f, "discard and send error if unicast"), - &FailureType::Unknown(id) => write!(f, "Unknown({})", id), + match *self { + FailureType::Skip => write!(f, "skip"), + FailureType::Discard => write!(f, "discard"), + FailureType::DiscardSendAll => write!(f, "discard and send error"), + FailureType::DiscardSendUnicast => write!(f, "discard and send error if unicast"), + FailureType::Unknown(id) => write!(f, "Unknown({id})"), } } } @@ -57,9 +63,10 @@ impl From for FailureType { } /// A read/write wrapper around an IPv6 Extension Header Option. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Ipv6Option> { - buffer: T + buffer: T, } // Format of Option @@ -73,21 +80,21 @@ pub struct Ipv6Option> { mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; // 8-bit identifier of the type of option. - pub const TYPE: usize = 0; + pub const TYPE: usize = 0; // 8-bit unsigned integer. Length of the DATA field of this option, in octets. - pub const LENGTH: usize = 1; + pub const LENGTH: usize = 1; // Variable-length field. Option-Type-specific data. - pub fn DATA(length: u8) -> Field { + pub const fn DATA(length: u8) -> Field { 2..length as usize + 2 } } impl> Ipv6Option { /// Create a raw octet buffer with an IPv6 Extension Header Option structure. - pub fn new_unchecked(buffer: T) -> Ipv6Option { + pub const fn new_unchecked(buffer: T) -> Ipv6Option { Ipv6Option { buffer } } @@ -102,7 +109,7 @@ impl> Ipv6Option { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_data_len]. /// @@ -112,7 +119,7 @@ impl> Ipv6Option { let len = data.len(); if len < field::LENGTH { - return Err(Error::Truncated); + return Err(Error); } if self.option_type() == Type::Pad1 { @@ -120,13 +127,13 @@ impl> Ipv6Option { } if len == field::LENGTH { - return Err(Error::Truncated); + return Err(Error); } let df = field::DATA(data[field::LENGTH]); if len < df.end { - return Err(Error::Truncated); + return Err(Error); } Ok(()) @@ -203,9 +210,9 @@ impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Ipv6Option<&'a mut T> { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Ipv6Option<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "IPv6 Extension Option ({})", err)?; + write!(f, "IPv6 Extension Option ({err})")?; Ok(()) } } @@ -214,56 +221,65 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Ipv6Option<&'a T> { /// A high-level representation of an IPv6 Extension Header Option. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub enum Repr<'a> { Pad1, PadN(u8), + #[cfg(feature = "proto-rpl")] + Rpl(RplHopByHopRepr), Unknown { - type_: Type, + type_: Type, length: u8, - data: &'a [u8] + data: &'a [u8], }, - - #[doc(hidden)] - __Nonexhaustive } impl<'a> Repr<'a> { /// Parse an IPv6 Extension Header Option and return a high-level representation. - pub fn parse(opt: &Ipv6Option<&'a T>) -> Result> where T: AsRef<[u8]> + ?Sized { + pub fn parse(opt: &Ipv6Option<&'a T>) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { match opt.option_type() { - Type::Pad1 => - Ok(Repr::Pad1), - Type::PadN => - Ok(Repr::PadN(opt.data_len())), - unknown_type @ Type::Unknown(_) => { - Ok(Repr::Unknown { - type_: unknown_type, - length: opt.data_len(), - data: opt.data(), - }) - } + Type::Pad1 => Ok(Repr::Pad1), + Type::PadN => Ok(Repr::PadN(opt.data_len())), + + #[cfg(feature = "proto-rpl")] + Type::Rpl => Ok(Repr::Rpl(RplHopByHopRepr::parse( + &RplHopByHopPacket::new_checked(opt.data())?, + ))), + #[cfg(not(feature = "proto-rpl"))] + Type::Rpl => Ok(Repr::Unknown { + type_: Type::Rpl, + length: opt.data_len(), + data: opt.data(), + }), + + unknown_type @ Type::Unknown(_) => Ok(Repr::Unknown { + type_: unknown_type, + length: opt.data_len(), + data: opt.data(), + }), } } /// Return the length of a header that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { - match self { - &Repr::Pad1 => 1, - &Repr::PadN(length) => - field::DATA(length).end, - &Repr::Unknown{ length, .. } => - field::DATA(length).end, - - &Repr::__Nonexhaustive => unreachable!() + pub const fn buffer_len(&self) -> usize { + match *self { + Repr::Pad1 => 1, + Repr::PadN(length) => field::DATA(length).end, + #[cfg(feature = "proto-rpl")] + Repr::Rpl(opt) => field::DATA(opt.buffer_len() as u8).end, + Repr::Unknown { length, .. } => field::DATA(length).end, } } /// Emit a high-level representation into an IPv6 Extension Header Option. pub fn emit + AsMut<[u8]> + ?Sized>(&self, opt: &mut Ipv6Option<&'a mut T>) { - match self { - &Repr::Pad1 => - opt.set_option_type(Type::Pad1), - &Repr::PadN(len) => { + match *self { + Repr::Pad1 => opt.set_option_type(Type::Pad1), + Repr::PadN(len) => { opt.set_option_type(Type::PadN); opt.set_data_len(len); // Ensure all padding bytes are set to zero. @@ -271,50 +287,50 @@ impl<'a> Repr<'a> { *x = 0 } } - &Repr::Unknown{ type_, length, data } => { + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => { + opt.set_option_type(Type::Rpl); + opt.set_data_len(4); + rpl.emit(&mut crate::wire::RplHopByHopPacket::new_unchecked( + opt.data_mut(), + )); + } + Repr::Unknown { + type_, + length, + data, + } => { opt.set_option_type(type_); opt.set_data_len(length); opt.data_mut().copy_from_slice(&data[..length as usize]); } - - &Repr::__Nonexhaustive => unreachable!() } } } /// A iterator for IPv6 options. #[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Ipv6OptionsIterator<'a> { pos: usize, length: usize, data: &'a [u8], - hit_error: bool + hit_error: bool, } impl<'a> Ipv6OptionsIterator<'a> { /// Create a new `Ipv6OptionsIterator`, used to iterate over the /// options contained in a IPv6 Extension Header (e.g. the Hop-by-Hop /// header). - /// - /// # Panics - /// This function panics if the `length` provided is larger than the - /// length of the `data` buffer. - pub fn new(data: &'a [u8], length: usize) -> Ipv6OptionsIterator<'a> { - assert!(length <= data.len()); + pub fn new(data: &'a [u8]) -> Ipv6OptionsIterator<'a> { + let length = data.len(); Ipv6OptionsIterator { pos: 0, hit_error: false, - length, data + length, + data, } } - - /// Helper function to return an error in the implementation - /// of `Iterator`. - #[inline] - fn return_err(&mut self, err: Error) -> Option>> { - self.hit_error = true; - Some(Err(err)) - } } impl<'a> Iterator for Ipv6OptionsIterator<'a> { @@ -325,19 +341,19 @@ impl<'a> Iterator for Ipv6OptionsIterator<'a> { // If we still have data to parse and we have not previously // hit an error, attempt to parse the next option. match Ipv6Option::new_checked(&self.data[self.pos..]) { - Ok(hdr) => { - match Repr::parse(&hdr) { - Ok(repr) => { - self.pos += repr.buffer_len(); - Some(Ok(repr)) - } - Err(e) => { - self.return_err(e) - } + Ok(hdr) => match Repr::parse(&hdr) { + Ok(repr) => { + self.pos += repr.buffer_len(); + Some(Ok(repr)) } - } + Err(e) => { + self.hit_error = true; + Some(Err(e)) + } + }, Err(e) => { - self.return_err(e) + self.hit_error = true; + Some(Err(e)) } } } else { @@ -351,15 +367,12 @@ impl<'a> Iterator for Ipv6OptionsIterator<'a> { impl<'a> fmt::Display for Repr<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "IPv6 Option ")?; - match self { - &Repr::Pad1 => - write!(f, "{} ", Type::Pad1), - &Repr::PadN(len) => - write!(f, "{} length={} ", Type::PadN, len), - &Repr::Unknown{ type_, length, .. } => - write!(f, "{} length={} ", type_, length), - - &Repr::__Nonexhaustive => unreachable!() + match *self { + Repr::Pad1 => write!(f, "{} ", Type::Pad1), + Repr::PadN(len) => write!(f, "{} length={} ", Type::PadN, len), + #[cfg(feature = "proto-rpl")] + Repr::Rpl(rpl) => write!(f, "{} {rpl}", Type::Rpl), + Repr::Unknown { type_, length, .. } => write!(f, "{type_} length={length} "), } } } @@ -368,35 +381,59 @@ impl<'a> fmt::Display for Repr<'a> { mod test { use super::*; - static IPV6OPTION_BYTES_PAD1: [u8; 1] = [0x0]; - static IPV6OPTION_BYTES_PADN: [u8; 3] = [0x1, 0x1, 0x0]; + static IPV6OPTION_BYTES_PAD1: [u8; 1] = [0x0]; + static IPV6OPTION_BYTES_PADN: [u8; 3] = [0x1, 0x1, 0x0]; static IPV6OPTION_BYTES_UNKNOWN: [u8; 5] = [0xff, 0x3, 0x0, 0x0, 0x0]; + #[cfg(feature = "proto-rpl")] + static IPV6OPTION_BYTES_RPL: [u8; 6] = [0x63, 0x04, 0x00, 0x1e, 0x08, 0x00]; #[test] fn test_check_len() { let bytes = [0u8]; // zero byte buffer - assert_eq!(Err(Error::Truncated), - Ipv6Option::new_unchecked(&bytes[..0]).check_len()); + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&bytes[..0]).check_len() + ); // pad1 - assert_eq!(Ok(()), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1).check_len()); + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PAD1).check_len() + ); // padn with truncated data - assert_eq!(Err(Error::Truncated), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN[..2]).check_len()); + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN[..2]).check_len() + ); // padn - assert_eq!(Ok(()), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN).check_len()); + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_PADN).check_len() + ); // unknown option type with truncated data - assert_eq!(Err(Error::Truncated), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..4]).check_len()); - assert_eq!(Err(Error::Truncated), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..1]).check_len()); + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..4]).check_len() + ); + assert_eq!( + Err(Error), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN[..1]).check_len() + ); // unknown type - assert_eq!(Ok(()), - Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN).check_len()); + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN).check_len() + ); + + #[cfg(feature = "proto-rpl")] + { + assert_eq!( + Ok(()), + Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL).check_len() + ); + } } #[test] @@ -413,7 +450,7 @@ mod test { assert_eq!(opt.option_type(), Type::Pad1); // two octets of padding - let bytes: [u8; 2] = [0x1, 0x0]; + let bytes: [u8; 2] = [0x1, 0x0]; let opt = Ipv6Option::new_unchecked(&bytes); assert_eq!(opt.option_type(), Type::PadN); assert_eq!(opt.data_len(), 0); @@ -425,19 +462,27 @@ mod test { assert_eq!(opt.data(), &[0]); // extra bytes in buffer - let bytes: [u8; 10] = [0x1, 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff]; + let bytes: [u8; 10] = [0x1, 0x7, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0xff]; let opt = Ipv6Option::new_unchecked(&bytes); assert_eq!(opt.option_type(), Type::PadN); assert_eq!(opt.data_len(), 7); assert_eq!(opt.data(), &[0, 0, 0, 0, 0, 0, 0]); // unrecognized option - let bytes: [u8; 1] = [0xff]; + let bytes: [u8; 1] = [0xff]; let opt = Ipv6Option::new_unchecked(&bytes); assert_eq!(opt.option_type(), Type::Unknown(255)); // unrecognized option without length and data - assert_eq!(Ipv6Option::new_checked(&bytes), Err(Error::Truncated)); + assert_eq!(Ipv6Option::new_checked(&bytes), Err(Error)); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + assert_eq!(opt.option_type(), Type::Rpl); + assert_eq!(opt.data_len(), 4); + assert_eq!(opt.data(), &[0x00, 0x1e, 0x08, 0x00]); + } } #[test] @@ -458,7 +503,31 @@ mod test { let data = [0u8; 3]; let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_UNKNOWN); let unknown = Repr::parse(&opt).unwrap(); - assert_eq!(unknown, Repr::Unknown { type_: Type::Unknown(255), length: 3, data: &data }); + assert_eq!( + unknown, + Repr::Unknown { + type_: Type::Unknown(255), + length: 3, + data: &data + } + ); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + let rpl = Repr::parse(&opt).unwrap(); + + assert_eq!( + rpl, + Repr::Rpl(crate::wire::RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: crate::wire::RplInstanceId::from(0x1e), + sender_rank: 0x0800, + }) + ); + } } #[test] @@ -476,11 +545,25 @@ mod test { assert_eq!(opt.into_inner(), &IPV6OPTION_BYTES_PADN); let data = [0u8; 3]; - let repr = Repr::Unknown { type_: Type::Unknown(255), length: 3, data: &data }; + let repr = Repr::Unknown { + type_: Type::Unknown(255), + length: 3, + data: &data, + }; let mut bytes = [254u8; 5]; // don't assume bytes are initialized to zero let mut opt = Ipv6Option::new_unchecked(&mut bytes); repr.emit(&mut opt); assert_eq!(opt.into_inner(), &IPV6OPTION_BYTES_UNKNOWN); + + #[cfg(feature = "proto-rpl")] + { + let opt = Ipv6Option::new_unchecked(&IPV6OPTION_BYTES_RPL); + let rpl = Repr::parse(&opt).unwrap(); + let mut bytes = [0u8; 6]; + rpl.emit(&mut Ipv6Option::new_unchecked(&mut bytes)); + + assert_eq!(&bytes, &IPV6OPTION_BYTES_RPL); + } } #[test] @@ -499,15 +582,12 @@ mod test { #[test] fn test_options_iter() { - let options = [0x00, 0x01, 0x01, 0x00, - 0x01, 0x02, 0x00, 0x00, - 0x01, 0x00, 0x00, 0x11, - 0x00, 0x01, 0x08, 0x00]; - - let mut iterator = Ipv6OptionsIterator::new(&options, 0); - assert_eq!(iterator.next(), None); + let options = [ + 0x00, 0x01, 0x01, 0x00, 0x01, 0x02, 0x00, 0x00, 0x01, 0x00, 0x00, 0x11, 0x00, 0x01, + 0x08, 0x00, + ]; - iterator = Ipv6OptionsIterator::new(&options, 16); + let iterator = Ipv6OptionsIterator::new(&options); for (i, opt) in iterator.enumerate() { match (i, opt) { (0, Ok(Repr::Pad1)) => continue, @@ -515,18 +595,17 @@ mod test { (2, Ok(Repr::PadN(2))) => continue, (3, Ok(Repr::PadN(0))) => continue, (4, Ok(Repr::Pad1)) => continue, - (5, Ok(Repr::Unknown { type_: Type::Unknown(0x11), length: 0, .. })) => - continue, - (6, Err(Error::Truncated)) => continue, - (i, res) => panic!("Unexpected option `{:?}` at index {}", res, i), + ( + 5, + Ok(Repr::Unknown { + type_: Type::Unknown(0x11), + length: 0, + .. + }), + ) => continue, + (6, Err(Error)) => continue, + (i, res) => panic!("Unexpected option `{res:?}` at index {i}"), } } } - - #[test] - #[should_panic(expected = "length <= data.len()")] - fn test_options_iter_truncated() { - let options = [0x01, 0x02, 0x00, 0x00]; - let _ = Ipv6OptionsIterator::new(&options, 5); - } } diff --git a/src/wire/ipv6routing.rs b/src/wire/ipv6routing.rs index befbd2c8a..f5f9c4138 100644 --- a/src/wire/ipv6routing.rs +++ b/src/wire/ipv6routing.rs @@ -1,12 +1,11 @@ +use super::{Error, Result}; use core::fmt; -use {Error, Result}; -use super::IpProtocol as Protocol; -use super::Ipv6Address as Address; +use crate::wire::Ipv6Address as Address; enum_with_unknown! { /// IPv6 Extension Routing Header Routing Type - pub doc enum Type(u8) { + pub enum Type(u8) { /// Source Route (DEPRECATED) /// /// See https://tools.ietf.org/html/rfc5095 for details. @@ -36,23 +35,24 @@ enum_with_unknown! { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Type::Type0 => write!(f, "Type0"), - &Type::Nimrod => write!(f, "Nimrod"), - &Type::Type2 => write!(f, "Type2"), - &Type::Rpl => write!(f, "Rpl"), - &Type::Experiment1 => write!(f, "Experiment1"), - &Type::Experiment2 => write!(f, "Experiment2"), - &Type::Reserved => write!(f, "Reserved"), - &Type::Unknown(id) => write!(f, "{}", id) + match *self { + Type::Type0 => write!(f, "Type0"), + Type::Nimrod => write!(f, "Nimrod"), + Type::Type2 => write!(f, "Type2"), + Type::Rpl => write!(f, "Rpl"), + Type::Experiment1 => write!(f, "Experiment1"), + Type::Experiment2 => write!(f, "Experiment2"), + Type::Reserved => write!(f, "Reserved"), + Type::Unknown(id) => write!(f, "{id}"), } } } /// A read/write wrapper around an IPv6 Routing Header buffer. -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Header> { - buffer: T + buffer: T, } // Format of the Routing Header @@ -69,31 +69,20 @@ pub struct Header> { // // // See https://tools.ietf.org/html/rfc8200#section-4.4 for details. +// +// **NOTE**: The fields start counting after the header length field. mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; // Minimum size of the header. - pub const MIN_HEADER_SIZE: usize = 4; + pub const MIN_HEADER_SIZE: usize = 2; - // 8-bit identifier of the header immediately following this header. - pub const NXT_HDR: usize = 0; - // 8-bit unsigned integer. Length of the DATA field in 8-octet units, - // not including the first 8 octets. - pub const LENGTH: usize = 1; // 8-bit identifier of a particular Routing header variant. - pub const TYPE: usize = 2; + pub const TYPE: usize = 0; // 8-bit unsigned integer. The number of route segments remaining. - pub const SEG_LEFT: usize = 3; - // Variable-length field. Routing-Type-specific data. - // - // Length of the header is in 8-octet units, not including the first 8 octets. The first four - // octets are the next header type, the header length, routing type and segments left. - pub fn DATA(length_field: u8) -> Field { - let bytes = length_field * 8 + 8; - 4..bytes as usize - } + pub const SEG_LEFT: usize = 1; // The Type 2 Routing Header has the following format: // @@ -112,7 +101,7 @@ mod field { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // 16-byte field containing the home address of the destination mobile node. - pub const HOME_ADDRESS: Field = 8..24; + pub const HOME_ADDRESS: Field = 6..22; // The RPL Source Routing Header has the following format: // @@ -129,20 +118,17 @@ mod field { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // 8-bit field containing the CmprI and CmprE values. - pub const CMPR: usize = 4; + pub const CMPR: usize = 2; // 8-bit field containing the Pad value. - pub const PAD: usize = 5; + pub const PAD: usize = 3; // Variable length field containing addresses - pub fn ADDRESSES(length_field: u8) -> Field { - let data = DATA(length_field); - 8..data.end - } + pub const ADDRESSES: usize = 6; } /// Core getter methods relevant to any routing type. impl> Header { /// Create a raw octet buffer with an IPv6 Routing Header structure. - pub fn new(buffer: T) -> Header { + pub const fn new_unchecked(buffer: T) -> Header { Header { buffer } } @@ -151,13 +137,13 @@ impl> Header { /// [new_unchecked]: #method.new_unchecked /// [check_len]: #method.check_len pub fn new_checked(buffer: T) -> Result> { - let header = Self::new(buffer); + let header = Self::new_unchecked(buffer); header.check_len()?; Ok(header) } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_header_len]. /// @@ -165,11 +151,13 @@ impl> Header { pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::MIN_HEADER_SIZE { - return Err(Error::Truncated); + return Err(Error); } - if len < field::DATA(self.header_len()).end as usize { - return Err(Error::Truncated); + match self.routing_type() { + Type::Type2 if len < field::HOME_ADDRESS.end => return Err(Error), + Type::Rpl if len < field::ADDRESSES => return Err(Error), + _ => (), } Ok(()) @@ -180,21 +168,6 @@ impl> Header { self.buffer } - /// Return the next header field. - #[inline] - pub fn next_header(&self) -> Protocol { - let data = self.buffer.as_ref(); - Protocol::from(data[field::NXT_HDR]) - } - - /// Return the header length field. Length of the Routing header in 8-octet units, - /// not including the first 8 octets. - #[inline] - pub fn header_len(&self) -> u8 { - let data = self.buffer.as_ref(); - data[field::LENGTH] - } - /// Return the routing type field. #[inline] pub fn routing_type(&self) -> Type { @@ -224,7 +197,7 @@ impl> Header { /// Getter methods for the RPL Source Routing Header routing type. impl> Header { - /// Return the number of prefix octects elided from addresses[1..n-1]. + /// Return the number of prefix octets elided from addresses[1..n-1]. /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -233,8 +206,7 @@ impl> Header { data[field::CMPR] >> 4 } - - /// Return the number of prefix octects elided from the last address (`addresses[n]`). + /// Return the number of prefix octets elided from the last address (`addresses[n]`). /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -243,7 +215,7 @@ impl> Header { data[field::CMPR] & 0xf } - /// Return the number of octects used for padding after `addresses[n]`. + /// Return the number of octets used for padding after `addresses[n]`. /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -258,26 +230,12 @@ impl> Header { /// This function may panic if this header is not the RPL Source Routing Header routing type. pub fn addresses(&self) -> &[u8] { let data = self.buffer.as_ref(); - &data[field::ADDRESSES(data[field::LENGTH])] + &data[field::ADDRESSES..] } } /// Core setter methods relevant to any routing type. impl + AsMut<[u8]>> Header { - /// Set the next header field. - #[inline] - pub fn set_next_header(&mut self, value: Protocol) { - let data = self.buffer.as_mut(); - data[field::NXT_HDR] = value.into(); - } - - /// Set the option data length. Length of the Routing header in 8-octet units. - #[inline] - pub fn set_header_len(&mut self, value: u8) { - let data = self.buffer.as_mut(); - data[field::LENGTH] = value; - } - /// Set the routing type. #[inline] pub fn set_routing_type(&mut self, value: Type) { @@ -310,12 +268,12 @@ impl + AsMut<[u8]>> Header { } Type::Rpl => { // Retain the higher order 4 bits of the padding field - data[field::PAD] = data[field::PAD] & 0xF0; + data[field::PAD] &= 0xF0; data[6] = 0; data[7] = 0; } - _ => panic!("Unrecognized routing type when clearing reserved fields.") + _ => panic!("Unrecognized routing type when clearing reserved fields."), } } } @@ -334,7 +292,7 @@ impl + AsMut<[u8]>> Header { /// Setter methods for the RPL Source Routing Header routing type. impl + AsMut<[u8]>> Header { - /// Set the number of prefix octects elided from addresses[1..n-1]. + /// Set the number of prefix octets elided from addresses[1..n-1]. /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -344,7 +302,7 @@ impl + AsMut<[u8]>> Header { data[field::CMPR] = raw; } - /// Set the number of prefix octects elided from the last address (`addresses[n]`). + /// Set the number of prefix octets elided from the last address (`addresses[n]`). /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -354,7 +312,7 @@ impl + AsMut<[u8]>> Header { data[field::CMPR] = raw; } - /// Set the number of octects used for padding after `addresses[n]`. + /// Set the number of octets used for padding after `addresses[n]`. /// /// # Panics /// This function may panic if this header is not the RPL Source Routing Header routing type. @@ -369,8 +327,7 @@ impl + AsMut<[u8]>> Header { /// This function may panic if this header is not the RPL Source Routing Header routing type. pub fn set_addresses(&mut self, value: &[u8]) { let data = self.buffer.as_mut(); - let len = data[field::LENGTH]; - let addresses = &mut data[field::ADDRESSES(len)]; + let addresses = &mut data[field::ADDRESSES..]; addresses.copy_from_slice(value); } } @@ -378,9 +335,9 @@ impl + AsMut<[u8]>> Header { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "IPv6 Routing ({})", err)?; + write!(f, "IPv6 Routing ({err})")?; Ok(()) } } @@ -389,94 +346,82 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Header<&'a T> { /// A high-level representation of an IPv6 Routing Header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[non_exhaustive] pub enum Repr<'a> { Type2 { - /// The type of header immediately following the Routing header. - next_header: Protocol, - /// Length of the Routing header in 8-octet units, not including the first 8 octets. - length: u8, /// Number of route segments remaining. - segments_left: u8, + segments_left: u8, /// The home address of the destination mobile node. home_address: Address, }, Rpl { - /// The type of header immediately following the Routing header. - next_header: Protocol, - /// Length of the Routing header in 8-octet units, not including the first 8 octets. - length: u8, /// Number of route segments remaining. - segments_left: u8, + segments_left: u8, /// Number of prefix octets from each segment, except the last segment, that are elided. - cmpr_i: u8, + cmpr_i: u8, /// Number of prefix octets from the last segment that are elided. - cmpr_e: u8, + cmpr_e: u8, /// Number of octets that are used for padding after `address[n]` at the end of the /// RPL Source Route Header. - pad: u8, + pad: u8, /// Vector of addresses, numbered 1 to `n`. - addresses: &'a[u8], + addresses: &'a [u8], }, - - #[doc(hidden)] - __Nonexhaustive } - impl<'a> Repr<'a> { /// Parse an IPv6 Routing Header and return a high-level representation. - pub fn parse(header: &'a Header<&'a T>) -> Result> where T: AsRef<[u8]> + ?Sized { + pub fn parse(header: &'a Header<&'a T>) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { match header.routing_type() { - Type::Type2 => { - Ok(Repr::Type2 { - next_header: header.next_header(), - length: header.header_len(), - segments_left: header.segments_left(), - home_address: header.home_address(), - }) - } - Type::Rpl => { - Ok(Repr::Rpl { - next_header: header.next_header(), - length: header.header_len(), - segments_left: header.segments_left(), - cmpr_i: header.cmpr_i(), - cmpr_e: header.cmpr_e(), - pad: header.pad(), - addresses: header.addresses(), - }) - } - - _ => Err(Error::Unrecognized) + Type::Type2 => Ok(Repr::Type2 { + segments_left: header.segments_left(), + home_address: header.home_address(), + }), + Type::Rpl => Ok(Repr::Rpl { + segments_left: header.segments_left(), + cmpr_i: header.cmpr_i(), + cmpr_e: header.cmpr_e(), + pad: header.pad(), + addresses: header.addresses(), + }), + + _ => Err(Error), } } /// Return the length, in bytes, of a header that will be emitted from this high-level /// representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { match self { - &Repr::Rpl { length, .. } | &Repr::Type2 { length, .. } => { - field::DATA(length).end - } - - &Repr::__Nonexhaustive => unreachable!() + // Routing Type + Segments Left + Reserved + Home Address + Repr::Type2 { home_address, .. } => 2 + 4 + home_address.as_bytes().len(), + Repr::Rpl { addresses, .. } => 2 + 4 + addresses.len(), } } /// Emit a high-level representation into an IPv6 Routing Header. pub fn emit + AsMut<[u8]> + ?Sized>(&self, header: &mut Header<&mut T>) { - match self { - &Repr::Type2 { next_header, length, segments_left, home_address } => { - header.set_next_header(next_header); - header.set_header_len(length); + match *self { + Repr::Type2 { + segments_left, + home_address, + } => { header.set_routing_type(Type::Type2); header.set_segments_left(segments_left); header.clear_reserved(); header.set_home_address(home_address); } - &Repr::Rpl { next_header, length, segments_left, cmpr_i, cmpr_e, pad, addresses } => { - header.set_next_header(next_header); - header.set_header_len(length); + Repr::Rpl { + segments_left, + cmpr_i, + cmpr_e, + pad, + addresses, + } => { header.set_routing_type(Type::Rpl); header.set_segments_left(segments_left); header.set_cmpr_i(cmpr_i); @@ -485,25 +430,42 @@ impl<'a> Repr<'a> { header.clear_reserved(); header.set_addresses(addresses); } - - &Repr::__Nonexhaustive => unreachable!(), } } } impl<'a> fmt::Display for Repr<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - match self { - &Repr::Type2 { next_header, length, segments_left, home_address } => { - write!(f, "IPv6 Routing next_hdr={} length={} type={} seg_left={} home_address={}", - next_header, length, Type::Type2, segments_left, home_address) + match *self { + Repr::Type2 { + segments_left, + home_address, + } => { + write!( + f, + "IPv6 Routing type={} seg_left={} home_address={}", + Type::Type2, + segments_left, + home_address + ) } - &Repr::Rpl { next_header, length, segments_left, cmpr_i, cmpr_e, pad, .. } => { - write!(f, "IPv6 Routing next_hdr={} length={} type={} seg_left={} cmpr_i={} cmpr_e={} pad={}", - next_header, length, Type::Rpl, segments_left, cmpr_i, cmpr_e, pad) + Repr::Rpl { + segments_left, + cmpr_i, + cmpr_e, + pad, + .. + } => { + write!( + f, + "IPv6 Routing type={} seg_left={} cmpr_i={} cmpr_e={} pad={}", + Type::Rpl, + segments_left, + cmpr_i, + cmpr_e, + pad + ) } - - &Repr::__Nonexhaustive => unreachable!(), } } } @@ -513,107 +475,93 @@ mod test { use super::*; // A Type 2 Routing Header - static BYTES_TYPE2: [u8; 24] = [0x6, 0x2, 0x2, 0x1, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x1]; + static BYTES_TYPE2: [u8; 22] = [ + 0x2, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x1, + ]; // A representation of a Type 2 Routing header static REPR_TYPE2: Repr = Repr::Type2 { - next_header: Protocol::Tcp, - length: 2, segments_left: 1, home_address: Address::LOOPBACK, }; // A Source Routing Header with full IPv6 addresses in bytes - static BYTES_SRH_FULL: [u8; 40] = [0x6, 0x4, 0x3, 0x2, - 0x0, 0x0, 0x0, 0x0, - 0xfd, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x2, - 0xfd, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x3, 0x1]; + static BYTES_SRH_FULL: [u8; 38] = [ + 0x3, 0x2, 0x0, 0x0, 0x0, 0x0, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x0, 0x0, 0x0, 0x2, 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, + 0x3, 0x1, + ]; // A representation of a Source Routing Header with full IPv6 addresses static REPR_SRH_FULL: Repr = Repr::Rpl { - next_header: Protocol::Tcp, - length: 4, segments_left: 2, cmpr_i: 0, cmpr_e: 0, pad: 0, - addresses: &[0xfd, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x2, - 0xfd, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x0, 0x0, - 0x0, 0x0, 0x3, 0x1] + addresses: &[ + 0xfd, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x2, 0xfd, + 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x0, 0x3, 0x1, + ], }; // A Source Routing Header with elided IPv6 addresses in bytes - static BYTES_SRH_ELIDED: [u8; 16] = [0x6, 0x1, 0x3, 0x2, - 0xfe, 0x50, 0x0, 0x0, - 0x2, 0x3, 0x1, 0x0, - 0x0, 0x0, 0x0, 0x0]; + static BYTES_SRH_ELIDED: [u8; 14] = [ + 0x3, 0x2, 0xfe, 0x50, 0x0, 0x0, 0x2, 0x3, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0, + ]; // A representation of a Source Routing Header with elided IPv6 addresses static REPR_SRH_ELIDED: Repr = Repr::Rpl { - next_header: Protocol::Tcp, - length: 1, segments_left: 2, cmpr_i: 15, cmpr_e: 14, pad: 5, - addresses: &[0x2, 0x3, 0x1, 0x0, - 0x0, 0x0, 0x0, 0x0] + addresses: &[0x2, 0x3, 0x1, 0x0, 0x0, 0x0, 0x0, 0x0], }; #[test] fn test_check_len() { // less than min header size - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_TYPE2[..3]).check_len()); - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_SRH_FULL[..3]).check_len()); - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_SRH_ELIDED[..3]).check_len()); - // less than specfied length field - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_TYPE2[..23]).check_len()); - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_SRH_FULL[..39]).check_len()); - assert_eq!(Err(Error::Truncated), Header::new(&BYTES_SRH_ELIDED[..11]).check_len()); + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_TYPE2[..3]).check_len() + ); + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_SRH_FULL[..3]).check_len() + ); + assert_eq!( + Err(Error), + Header::new_unchecked(&BYTES_SRH_ELIDED[..3]).check_len() + ); // valid - assert_eq!(Ok(()), Header::new(&BYTES_TYPE2[..]).check_len()); - assert_eq!(Ok(()), Header::new(&BYTES_SRH_FULL[..]).check_len()); - assert_eq!(Ok(()), Header::new(&BYTES_SRH_ELIDED[..]).check_len()); + assert_eq!(Ok(()), Header::new_unchecked(&BYTES_TYPE2[..]).check_len()); + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_SRH_FULL[..]).check_len() + ); + assert_eq!( + Ok(()), + Header::new_unchecked(&BYTES_SRH_ELIDED[..]).check_len() + ); } #[test] fn test_header_deconstruct() { - let header = Header::new(&BYTES_TYPE2[..]); - assert_eq!(header.next_header(), Protocol::Tcp); - assert_eq!(header.header_len(), 2); + let header = Header::new_unchecked(&BYTES_TYPE2[..]); assert_eq!(header.routing_type(), Type::Type2); assert_eq!(header.segments_left(), 1); assert_eq!(header.home_address(), Address::LOOPBACK); - let header = Header::new(&BYTES_SRH_FULL[..]); - assert_eq!(header.next_header(), Protocol::Tcp); - assert_eq!(header.header_len(), 4); + let header = Header::new_unchecked(&BYTES_SRH_FULL[..]); assert_eq!(header.routing_type(), Type::Rpl); assert_eq!(header.segments_left(), 2); - assert_eq!(header.addresses(), &BYTES_SRH_FULL[8..]); + assert_eq!(header.addresses(), &BYTES_SRH_FULL[6..]); - let header = Header::new(&BYTES_SRH_ELIDED[..]); - assert_eq!(header.next_header(), Protocol::Tcp); - assert_eq!(header.header_len(), 1); + let header = Header::new_unchecked(&BYTES_SRH_ELIDED[..]); assert_eq!(header.routing_type(), Type::Rpl); assert_eq!(header.segments_left(), 2); - assert_eq!(header.addresses(), &BYTES_SRH_ELIDED[8..]); + assert_eq!(header.addresses(), &BYTES_SRH_ELIDED[6..]); } #[test] @@ -633,26 +581,26 @@ mod test { #[test] fn test_repr_emit() { - let mut bytes = [0u8; 24]; - let mut header = Header::new(&mut bytes[..]); + let mut bytes = [0u8; 22]; + let mut header = Header::new_unchecked(&mut bytes[..]); REPR_TYPE2.emit(&mut header); assert_eq!(header.into_inner(), &BYTES_TYPE2[..]); - let mut bytes = [0u8; 40]; - let mut header = Header::new(&mut bytes[..]); + let mut bytes = [0u8; 38]; + let mut header = Header::new_unchecked(&mut bytes[..]); REPR_SRH_FULL.emit(&mut header); assert_eq!(header.into_inner(), &BYTES_SRH_FULL[..]); - let mut bytes = [0u8; 16]; - let mut header = Header::new(&mut bytes[..]); + let mut bytes = [0u8; 14]; + let mut header = Header::new_unchecked(&mut bytes[..]); REPR_SRH_ELIDED.emit(&mut header); assert_eq!(header.into_inner(), &BYTES_SRH_ELIDED[..]); } #[test] fn test_buffer_len() { - assert_eq!(REPR_TYPE2.buffer_len(), 24); - assert_eq!(REPR_SRH_FULL.buffer_len(), 40); - assert_eq!(REPR_SRH_ELIDED.buffer_len(), 16); + assert_eq!(REPR_TYPE2.buffer_len(), 22); + assert_eq!(REPR_SRH_FULL.buffer_len(), 38); + assert_eq!(REPR_SRH_ELIDED.buffer_len(), 14); } } diff --git a/src/wire/mld.rs b/src/wire/mld.rs index 926829399..18872b502 100644 --- a/src/wire/mld.rs +++ b/src/wire/mld.rs @@ -6,16 +6,16 @@ use byteorder::{ByteOrder, NetworkEndian}; -use {Error, Result}; -use super::icmpv6::{field, Message, Packet}; -use super::Ipv6Address; +use super::{Error, Result}; +use crate::wire::icmpv6::{field, Message, Packet}; +use crate::wire::Ipv6Address; enum_with_unknown! { /// MLDv2 Multicast Listener Report Record Type. See [RFC 3810 § 5.2.12] for /// more details. /// /// [RFC 3810 § 5.2.12]: https://tools.ietf.org/html/rfc3010#section-5.2.12 - pub doc enum RecordType(u8) { + pub enum RecordType(u8) { /// Interface has a filter mode of INCLUDE for the specified multicast address. ModeIsInclude = 0x01, /// Interface has a filter mode of EXCLUDE for the specified multicast address. @@ -125,7 +125,7 @@ impl + AsMut<[u8]>> Packet { #[inline] pub fn clear_s_flag(&mut self) { let data = self.buffer.as_mut(); - data[field::SQRV] = data[field::SQRV] & 0x7; + data[field::SQRV] &= 0x7; } /// Set the Querier's Robustness Variable. @@ -165,14 +165,15 @@ impl + AsMut<[u8]>> Packet { } /// A read/write wrapper around an MLDv2 Listener Report Message Address Record. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct AddressRecord> { - buffer: T + buffer: T, } impl> AddressRecord { /// Imbue a raw octet buffer with a Address Record structure. - pub fn new_unchecked(buffer: T) -> Self { + pub const fn new_unchecked(buffer: T) -> Self { Self { buffer } } @@ -191,7 +192,7 @@ impl> AddressRecord { pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::RECORD_MCAST_ADDR.end { - Err(Error::Truncated) + Err(Error) } else { Ok(()) } @@ -215,7 +216,7 @@ impl> AddressRecord { RecordType::from(data[field::RECORD_TYPE]) } - /// Return the length of the auxilary data. + /// Return the length of the auxiliary data. #[inline] pub fn aux_data_len(&self) -> u8 { let data = self.buffer.as_ref(); @@ -258,7 +259,7 @@ impl + AsRef<[u8]>> AddressRecord { data[field::RECORD_TYPE] = rty.into(); } - /// Return the length of the auxilary data. + /// Return the length of the auxiliary data. #[inline] pub fn set_aux_data_len(&mut self, len: u8) { let data = self.buffer.as_mut(); @@ -295,6 +296,7 @@ impl + AsMut<[u8]>> AddressRecord { /// A high-level representation of an MLDv2 packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr<'a> { Query { max_resp_code: u16, @@ -303,58 +305,61 @@ pub enum Repr<'a> { qrv: u8, qqic: u8, num_srcs: u16, - data: &'a [u8] + data: &'a [u8], }, Report { nr_mcast_addr_rcrds: u16, - data: &'a [u8] - } + data: &'a [u8], + }, } impl<'a> Repr<'a> { /// Parse an MLDv2 packet and return a high-level representation. pub fn parse(packet: &Packet<&'a T>) -> Result> - where T: AsRef<[u8]> + ?Sized { + where + T: AsRef<[u8]> + ?Sized, + { match packet.msg_type() { - Message::MldQuery => { - Ok(Repr::Query { - max_resp_code: packet.max_resp_code(), - mcast_addr: packet.mcast_addr(), - s_flag: packet.s_flag(), - qrv: packet.qrv(), - qqic: packet.qqic(), - num_srcs: packet.num_srcs(), - data: packet.payload() - }) - }, - Message::MldReport => { - Ok(Repr::Report { - nr_mcast_addr_rcrds: packet.nr_mcast_addr_rcrds(), - data: packet.payload() - }) - }, - _ => Err(Error::Unrecognized) + Message::MldQuery => Ok(Repr::Query { + max_resp_code: packet.max_resp_code(), + mcast_addr: packet.mcast_addr(), + s_flag: packet.s_flag(), + qrv: packet.qrv(), + qqic: packet.qqic(), + num_srcs: packet.num_srcs(), + data: packet.payload(), + }), + Message::MldReport => Ok(Repr::Report { + nr_mcast_addr_rcrds: packet.nr_mcast_addr_rcrds(), + data: packet.payload(), + }), + _ => Err(Error), } } /// Return the length of a packet that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { match self { - Repr::Query { .. } => { - field::QUERY_NUM_SRCS.end - } - Repr::Report { .. } => { - field::NR_MCAST_RCRDS.end - } + Repr::Query { data, .. } => field::QUERY_NUM_SRCS.end + data.len(), + Repr::Report { data, .. } => field::NR_MCAST_RCRDS.end + data.len(), } } /// Emit a high-level representation into an MLDv2 packet. pub fn emit(&self, packet: &mut Packet<&mut T>) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { match self { - Repr::Query { max_resp_code, mcast_addr, s_flag, - qrv, qqic, num_srcs, data } => { + Repr::Query { + max_resp_code, + mcast_addr, + s_flag, + qrv, + qqic, + num_srcs, + data, + } => { packet.set_msg_type(Message::MldQuery); packet.set_msg_code(0); packet.clear_reserved(); @@ -369,8 +374,11 @@ impl<'a> Repr<'a> { packet.set_qqic(*qqic); packet.set_num_srcs(*num_srcs); packet.payload_mut().copy_from_slice(&data[..]); - }, - Repr::Report { nr_mcast_addr_rcrds, data } => { + } + Repr::Report { + nr_mcast_addr_rcrds, + data, + } => { packet.set_msg_type(Message::MldReport); packet.set_msg_code(0); packet.clear_reserved(); @@ -383,74 +391,49 @@ impl<'a> Repr<'a> { #[cfg(test)] mod test { - use phy::ChecksumCapabilities; - use wire::Icmpv6Repr; - use wire::icmpv6::Message; use super::*; - - static QUERY_PACKET_BYTES: [u8; 44] = - [0x82, 0x00, 0x73, 0x74, - 0x04, 0x00, 0x00, 0x00, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0x0a, 0x12, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02]; - - static QUERY_PACKET_PAYLOAD: [u8; 16] = - [0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02]; - - static REPORT_PACKET_BYTES: [u8; 44] = - [0x8f, 0x00, 0x73, 0x85, - 0x00, 0x00, 0x00, 0x01, - 0x01, 0x00, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02]; - - static REPORT_PACKET_PAYLOAD: [u8; 36] = - [0x01, 0x00, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01, - 0xff, 0x02, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x02]; - + use crate::phy::ChecksumCapabilities; + use crate::wire::icmpv6::Message; + use crate::wire::Icmpv6Repr; + + static QUERY_PACKET_BYTES: [u8; 44] = [ + 0x82, 0x00, 0x73, 0x74, 0x04, 0x00, 0x00, 0x00, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0x0a, 0x12, 0x00, 0x01, 0xff, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; + + static QUERY_PACKET_PAYLOAD: [u8; 16] = [ + 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, + ]; + + static REPORT_PACKET_BYTES: [u8; 44] = [ + 0x8f, 0x00, 0x73, 0x85, 0x00, 0x00, 0x00, 0x01, 0x01, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; + + static REPORT_PACKET_PAYLOAD: [u8; 36] = [ + 0x01, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x01, 0xff, 0x02, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, + ]; fn create_repr<'a>(ty: Message) -> Icmpv6Repr<'a> { match ty { - Message::MldQuery => { - Icmpv6Repr::Mld(Repr::Query { - max_resp_code: 0x400, - mcast_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, - s_flag: true, - qrv: 0x02, - qqic: 0x12, - num_srcs: 0x01, - data: &QUERY_PACKET_PAYLOAD - }) - }, - Message::MldReport => { - Icmpv6Repr::Mld(Repr::Report { - nr_mcast_addr_rcrds: 1, - data: &REPORT_PACKET_PAYLOAD - }) - }, + Message::MldQuery => Icmpv6Repr::Mld(Repr::Query { + max_resp_code: 0x400, + mcast_addr: Ipv6Address::LINK_LOCAL_ALL_NODES, + s_flag: true, + qrv: 0x02, + qqic: 0x12, + num_srcs: 0x01, + data: &QUERY_PACKET_PAYLOAD, + }), + Message::MldReport => Icmpv6Repr::Mld(Repr::Report { + nr_mcast_addr_rcrds: 1, + data: &REPORT_PACKET_PAYLOAD, + }), _ => { panic!("Message type must be a MLDv2 message type"); } @@ -465,12 +448,14 @@ mod test { assert_eq!(packet.checksum(), 0x7374); assert_eq!(packet.max_resp_code(), 0x0400); assert_eq!(packet.mcast_addr(), Ipv6Address::LINK_LOCAL_ALL_NODES); - assert_eq!(packet.s_flag(), true); + assert!(packet.s_flag()); assert_eq!(packet.qrv(), 0x02); assert_eq!(packet.qqic(), 0x12); assert_eq!(packet.num_srcs(), 0x01); - assert_eq!(Ipv6Address::from_bytes(packet.payload()), - Ipv6Address::LINK_LOCAL_ALL_ROUTERS); + assert_eq!( + Ipv6Address::from_bytes(packet.payload()), + Ipv6Address::LINK_LOCAL_ALL_ROUTERS + ); } #[test] @@ -485,11 +470,15 @@ mod test { packet.set_qrv(0x02); packet.set_qqic(0x12); packet.set_num_srcs(0x01); - packet.payload_mut().copy_from_slice(Ipv6Address::LINK_LOCAL_ALL_ROUTERS.as_bytes()); + packet + .payload_mut() + .copy_from_slice(Ipv6Address::LINK_LOCAL_ALL_ROUTERS.as_bytes()); packet.clear_reserved(); - packet.fill_checksum(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into()); - assert_eq!(&packet.into_inner()[..], &QUERY_PACKET_BYTES[..]); + packet.fill_checksum( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + ); + assert_eq!(&*packet.into_inner(), &QUERY_PACKET_BYTES[..]); } #[test] @@ -504,8 +493,10 @@ mod test { assert_eq!(addr_rcrd.aux_data_len(), 0x00); assert_eq!(addr_rcrd.num_srcs(), 0x01); assert_eq!(addr_rcrd.mcast_addr(), Ipv6Address::LINK_LOCAL_ALL_NODES); - assert_eq!(Ipv6Address::from_bytes(addr_rcrd.payload()), - Ipv6Address::LINK_LOCAL_ALL_ROUTERS); + assert_eq!( + Ipv6Address::from_bytes(addr_rcrd.payload()), + Ipv6Address::LINK_LOCAL_ALL_ROUTERS + ); } #[test] @@ -522,31 +513,38 @@ mod test { addr_rcrd.set_aux_data_len(0); addr_rcrd.set_num_srcs(1); addr_rcrd.set_mcast_addr(Ipv6Address::LINK_LOCAL_ALL_NODES); - addr_rcrd.payload_mut() + addr_rcrd + .payload_mut() .copy_from_slice(Ipv6Address::LINK_LOCAL_ALL_ROUTERS.as_bytes()); } - packet.fill_checksum(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into()); - assert_eq!(&packet.into_inner()[..], &REPORT_PACKET_BYTES[..]); + packet.fill_checksum( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + ); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); } #[test] fn test_query_repr_parse() { let packet = Packet::new_unchecked(&QUERY_PACKET_BYTES[..]); - let repr = Icmpv6Repr::parse(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), - &packet, - &ChecksumCapabilities::default()); + let repr = Icmpv6Repr::parse( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &packet, + &ChecksumCapabilities::default(), + ); assert_eq!(repr, Ok(create_repr(Message::MldQuery))); } #[test] fn test_report_repr_parse() { let packet = Packet::new_unchecked(&REPORT_PACKET_BYTES[..]); - let repr = Icmpv6Repr::parse(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), - &packet, - &ChecksumCapabilities::default()); + let repr = Icmpv6Repr::parse( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &packet, + &ChecksumCapabilities::default(), + ); assert_eq!(repr, Ok(create_repr(Message::MldReport))); } @@ -555,11 +553,13 @@ mod test { let mut bytes = [0x2a; 44]; let mut packet = Packet::new_unchecked(&mut bytes[..]); let repr = create_repr(Message::MldQuery); - repr.emit(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), - &mut packet, - &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &QUERY_PACKET_BYTES[..]); + repr.emit( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &QUERY_PACKET_BYTES[..]); } #[test] @@ -567,10 +567,12 @@ mod test { let mut bytes = [0x2a; 44]; let mut packet = Packet::new_unchecked(&mut bytes[..]); let repr = create_repr(Message::MldReport); - repr.emit(&Ipv6Address::LINK_LOCAL_ALL_NODES.into(), - &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), - &mut packet, - &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &REPORT_PACKET_BYTES[..]); + repr.emit( + &Ipv6Address::LINK_LOCAL_ALL_NODES.into(), + &Ipv6Address::LINK_LOCAL_ALL_ROUTERS.into(), + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &REPORT_PACKET_BYTES[..]); } } diff --git a/src/wire/mod.rs b/src/wire/mod.rs index 8641f1438..85f7d8b34 100644 --- a/src/wire/mod.rs +++ b/src/wire/mod.rs @@ -50,9 +50,9 @@ use smoltcp::wire::*; let repr = Ipv4Repr { src_addr: Ipv4Address::new(10, 0, 0, 1), dst_addr: Ipv4Address::new(10, 0, 0, 2), - protocol: IpProtocol::Tcp, + next_header: IpProtocol::Tcp, payload_len: 10, - hop_limit: 64 + hop_limit: 64, }; let mut buffer = vec![0; repr.buffer_len() + repr.payload_len]; { // emission @@ -72,153 +72,441 @@ let mut buffer = vec![0; repr.buffer_len() + repr.payload_len]; mod field { pub type Field = ::core::ops::Range; - pub type Rest = ::core::ops::RangeFrom; + pub type Rest = ::core::ops::RangeFrom; } pub mod pretty_print; -#[cfg(feature = "ethernet")] -mod ethernet; -#[cfg(all(feature = "proto-ipv4", feature = "ethernet"))] +#[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] mod arp; +#[cfg(feature = "proto-dhcpv4")] +pub(crate) mod dhcpv4; +#[cfg(feature = "proto-dns")] +pub(crate) mod dns; +#[cfg(feature = "medium-ethernet")] +mod ethernet; +#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] +mod icmp; +#[cfg(feature = "proto-ipv4")] +mod icmpv4; +#[cfg(feature = "proto-ipv6")] +mod icmpv6; +#[cfg(feature = "medium-ieee802154")] +pub mod ieee802154; +#[cfg(feature = "proto-igmp")] +mod igmp; pub(crate) mod ip; #[cfg(feature = "proto-ipv4")] mod ipv4; #[cfg(feature = "proto-ipv6")] mod ipv6; #[cfg(feature = "proto-ipv6")] -mod ipv6option; -#[cfg(feature = "proto-ipv6")] -mod ipv6hopbyhop; +mod ipv6ext_header; #[cfg(feature = "proto-ipv6")] mod ipv6fragment; #[cfg(feature = "proto-ipv6")] +mod ipv6option; +#[cfg(feature = "proto-ipv6")] mod ipv6routing; -#[cfg(feature = "proto-ipv4")] -mod icmpv4; #[cfg(feature = "proto-ipv6")] -mod icmpv6; -#[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] -mod icmp; -#[cfg(feature = "proto-igmp")] -mod igmp; -#[cfg(all(feature = "proto-ipv6", feature = "ethernet"))] +mod mld; +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] mod ndisc; -#[cfg(all(feature = "proto-ipv6", feature = "ethernet"))] +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] mod ndiscoption; -#[cfg(feature = "proto-ipv6")] -mod mld; -mod udp; +#[cfg(feature = "proto-rpl")] +mod rpl; +#[cfg(all(feature = "proto-sixlowpan", feature = "medium-ieee802154"))] +mod sixlowpan; mod tcp; -#[cfg(feature = "proto-dhcpv4")] -pub(crate) mod dhcpv4; +mod udp; + +use core::fmt; + +use crate::phy::Medium; pub use self::pretty_print::PrettyPrinter; -#[cfg(feature = "ethernet")] -pub use self::ethernet::{EtherType as EthernetProtocol, - Address as EthernetAddress, - Frame as EthernetFrame, - Repr as EthernetRepr}; - -#[cfg(all(feature = "proto-ipv4", feature = "ethernet"))] -pub use self::arp::{Hardware as ArpHardware, - Operation as ArpOperation, - Packet as ArpPacket, - Repr as ArpRepr}; - -pub use self::ip::{Version as IpVersion, - Protocol as IpProtocol, - Address as IpAddress, - Endpoint as IpEndpoint, - Repr as IpRepr, - Cidr as IpCidr}; +#[cfg(feature = "medium-ethernet")] +pub use self::ethernet::{ + Address as EthernetAddress, EtherType as EthernetProtocol, Frame as EthernetFrame, + Repr as EthernetRepr, HEADER_LEN as ETHERNET_HEADER_LEN, +}; + +#[cfg(all(feature = "proto-ipv4", feature = "medium-ethernet"))] +pub use self::arp::{ + Hardware as ArpHardware, Operation as ArpOperation, Packet as ArpPacket, Repr as ArpRepr, +}; + +#[cfg(feature = "proto-rpl")] +pub use self::rpl::{ + data::HopByHopOption as RplHopByHopRepr, data::Packet as RplHopByHopPacket, + options::Packet as RplOptionPacket, options::Repr as RplOptionRepr, + InstanceId as RplInstanceId, Repr as RplRepr, +}; + +#[cfg(all(feature = "proto-sixlowpan", feature = "medium-ieee802154"))] +pub use self::sixlowpan::{ + frag::{Key as SixlowpanFragKey, Packet as SixlowpanFragPacket, Repr as SixlowpanFragRepr}, + iphc::{Packet as SixlowpanIphcPacket, Repr as SixlowpanIphcRepr}, + nhc::{ + ExtHeaderPacket as SixlowpanExtHeaderPacket, ExtHeaderRepr as SixlowpanExtHeaderRepr, + NhcPacket as SixlowpanNhcPacket, UdpNhcPacket as SixlowpanUdpNhcPacket, + UdpNhcRepr as SixlowpanUdpNhcRepr, + }, + AddressContext as SixlowpanAddressContext, NextHeader as SixlowpanNextHeader, SixlowpanPacket, +}; + +#[cfg(feature = "medium-ieee802154")] +pub use self::ieee802154::{ + Address as Ieee802154Address, AddressingMode as Ieee802154AddressingMode, + Frame as Ieee802154Frame, FrameType as Ieee802154FrameType, + FrameVersion as Ieee802154FrameVersion, Pan as Ieee802154Pan, Repr as Ieee802154Repr, +}; + +pub use self::ip::{ + Address as IpAddress, Cidr as IpCidr, Endpoint as IpEndpoint, + ListenEndpoint as IpListenEndpoint, Protocol as IpProtocol, Repr as IpRepr, + Version as IpVersion, +}; #[cfg(feature = "proto-ipv4")] -pub use self::ipv4::{Address as Ipv4Address, - Packet as Ipv4Packet, - Repr as Ipv4Repr, - Cidr as Ipv4Cidr, - MIN_MTU as IPV4_MIN_MTU}; +pub use self::ipv4::{ + Address as Ipv4Address, Cidr as Ipv4Cidr, Key as Ipv4FragKey, Packet as Ipv4Packet, + Repr as Ipv4Repr, HEADER_LEN as IPV4_HEADER_LEN, MIN_MTU as IPV4_MIN_MTU, +}; #[cfg(feature = "proto-ipv6")] -pub use self::ipv6::{Address as Ipv6Address, - Packet as Ipv6Packet, - Repr as Ipv6Repr, - Cidr as Ipv6Cidr, - MIN_MTU as IPV6_MIN_MTU}; +pub use self::ipv6::{ + Address as Ipv6Address, Cidr as Ipv6Cidr, Packet as Ipv6Packet, Repr as Ipv6Repr, + HEADER_LEN as IPV6_HEADER_LEN, MIN_MTU as IPV6_MIN_MTU, +}; #[cfg(feature = "proto-ipv6")] -pub use self::ipv6option::{Ipv6Option, - Repr as Ipv6OptionRepr, - Type as Ipv6OptionType, - FailureType as Ipv6OptionFailureType}; +pub use self::ipv6option::{ + FailureType as Ipv6OptionFailureType, Ipv6Option, Ipv6OptionsIterator, Repr as Ipv6OptionRepr, + Type as Ipv6OptionType, +}; #[cfg(feature = "proto-ipv6")] -pub use self::ipv6hopbyhop::{Header as Ipv6HopByHopHeader, - Repr as Ipv6HopByHopRepr}; +pub use self::ipv6ext_header::{Header as Ipv6ExtHeader, Repr as Ipv6ExtHeaderRepr}; #[cfg(feature = "proto-ipv6")] -pub use self::ipv6fragment::{Header as Ipv6FragmentHeader, - Repr as Ipv6FragmentRepr}; +/// A read/write wrapper around an IPv6 Hop-By-Hop header. +pub type Ipv6HopByHopHeader = Ipv6ExtHeader; +#[cfg(feature = "proto-ipv6")] +/// A high-level representation of an IPv6 Hop-By-Hop heade. +pub type Ipv6HopByHopRepr<'a> = Ipv6ExtHeaderRepr<'a>; #[cfg(feature = "proto-ipv6")] -pub use self::ipv6routing::{Header as Ipv6RoutingHeader, - Repr as Ipv6RoutingRepr}; +pub use self::ipv6fragment::{Header as Ipv6FragmentHeader, Repr as Ipv6FragmentRepr}; + +#[cfg(feature = "proto-ipv6")] +pub use self::ipv6routing::{ + Header as Ipv6RoutingHeader, Repr as Ipv6RoutingRepr, Type as Ipv6RoutingType, +}; #[cfg(feature = "proto-ipv4")] -pub use self::icmpv4::{Message as Icmpv4Message, - DstUnreachable as Icmpv4DstUnreachable, - Redirect as Icmpv4Redirect, - TimeExceeded as Icmpv4TimeExceeded, - ParamProblem as Icmpv4ParamProblem, - Packet as Icmpv4Packet, - Repr as Icmpv4Repr}; +pub use self::icmpv4::{ + DstUnreachable as Icmpv4DstUnreachable, Message as Icmpv4Message, Packet as Icmpv4Packet, + ParamProblem as Icmpv4ParamProblem, Redirect as Icmpv4Redirect, Repr as Icmpv4Repr, + TimeExceeded as Icmpv4TimeExceeded, +}; #[cfg(feature = "proto-igmp")] -pub use self::igmp::{Packet as IgmpPacket, - Repr as IgmpRepr, - IgmpVersion}; +pub use self::igmp::{IgmpVersion, Packet as IgmpPacket, Repr as IgmpRepr}; #[cfg(feature = "proto-ipv6")] -pub use self::icmpv6::{Message as Icmpv6Message, - DstUnreachable as Icmpv6DstUnreachable, - TimeExceeded as Icmpv6TimeExceeded, - ParamProblem as Icmpv6ParamProblem, - Packet as Icmpv6Packet, - Repr as Icmpv6Repr}; +pub use self::icmpv6::{ + DstUnreachable as Icmpv6DstUnreachable, Message as Icmpv6Message, Packet as Icmpv6Packet, + ParamProblem as Icmpv6ParamProblem, Repr as Icmpv6Repr, TimeExceeded as Icmpv6TimeExceeded, +}; #[cfg(any(feature = "proto-ipv4", feature = "proto-ipv6"))] pub use self::icmp::Repr as IcmpRepr; +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +pub use self::ndisc::{ + NeighborFlags as NdiscNeighborFlags, Repr as NdiscRepr, RouterFlags as NdiscRouterFlags, +}; -#[cfg(all(feature = "proto-ipv6", feature = "ethernet"))] -pub use self::ndisc::{Repr as NdiscRepr, - RouterFlags as NdiscRouterFlags, - NeighborFlags as NdiscNeighborFlags}; - -#[cfg(all(feature = "proto-ipv6", feature = "ethernet"))] -pub use self::ndiscoption::{NdiscOption, - Repr as NdiscOptionRepr, - Type as NdiscOptionType, - PrefixInformation as NdiscPrefixInformation, - RedirectedHeader as NdiscRedirectedHeader, - PrefixInfoFlags as NdiscPrefixInfoFlags}; +#[cfg(all( + feature = "proto-ipv6", + any(feature = "medium-ethernet", feature = "medium-ieee802154") +))] +pub use self::ndiscoption::{ + NdiscOption, PrefixInfoFlags as NdiscPrefixInfoFlags, + PrefixInformation as NdiscPrefixInformation, RedirectedHeader as NdiscRedirectedHeader, + Repr as NdiscOptionRepr, Type as NdiscOptionType, +}; #[cfg(feature = "proto-ipv6")] -pub use self::mld::{AddressRecord as MldAddressRecord, - Repr as MldRepr}; +pub use self::mld::{AddressRecord as MldAddressRecord, Repr as MldRepr}; -pub use self::udp::{Packet as UdpPacket, - Repr as UdpRepr}; +pub use self::udp::{Packet as UdpPacket, Repr as UdpRepr, HEADER_LEN as UDP_HEADER_LEN}; -pub use self::tcp::{SeqNumber as TcpSeqNumber, - Packet as TcpPacket, - TcpOption, - Repr as TcpRepr, - Control as TcpControl}; +pub use self::tcp::{ + Control as TcpControl, Packet as TcpPacket, Repr as TcpRepr, SeqNumber as TcpSeqNumber, + TcpOption, HEADER_LEN as TCP_HEADER_LEN, +}; #[cfg(feature = "proto-dhcpv4")] -pub use self::dhcpv4::{Packet as DhcpPacket, - Repr as DhcpRepr, - MessageType as DhcpMessageType}; +pub use self::dhcpv4::{ + DhcpOption, DhcpOptionWriter, MessageType as DhcpMessageType, Packet as DhcpPacket, + Repr as DhcpRepr, CLIENT_PORT as DHCP_CLIENT_PORT, + MAX_DNS_SERVER_COUNT as DHCP_MAX_DNS_SERVER_COUNT, SERVER_PORT as DHCP_SERVER_PORT, +}; + +#[cfg(feature = "proto-dns")] +pub use self::dns::{ + Flags as DnsFlags, Opcode as DnsOpcode, Packet as DnsPacket, Rcode as DnsRcode, + Repr as DnsRepr, Type as DnsQueryType, +}; + +/// Parsing a packet failed. +/// +/// Either it is malformed, or it is not supported by smoltcp. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct Error; + +#[cfg(feature = "std")] +impl std::error::Error for Error {} + +impl fmt::Display for Error { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + write!(f, "wire::Error") + } +} + +pub type Result = core::result::Result; + +/// Representation of an hardware address, such as an Ethernet address or an IEEE802.15.4 address. +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum HardwareAddress { + #[cfg(feature = "medium-ip")] + Ip, + #[cfg(feature = "medium-ethernet")] + Ethernet(EthernetAddress), + #[cfg(feature = "medium-ieee802154")] + Ieee802154(Ieee802154Address), +} + +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +impl HardwareAddress { + pub const fn as_bytes(&self) -> &[u8] { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.as_bytes(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.as_bytes(), + } + } + + /// Query wether the address is an unicast address. + pub fn is_unicast(&self) -> bool { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.is_unicast(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.is_unicast(), + } + } + + /// Query wether the address is a broadcast address. + pub fn is_broadcast(&self) -> bool { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => unreachable!(), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => addr.is_broadcast(), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => addr.is_broadcast(), + } + } + + #[cfg(feature = "medium-ethernet")] + pub(crate) fn ethernet_or_panic(&self) -> EthernetAddress { + match self { + HardwareAddress::Ethernet(addr) => *addr, + #[allow(unreachable_patterns)] + _ => panic!("HardwareAddress is not Ethernet."), + } + } + + #[cfg(feature = "medium-ieee802154")] + pub(crate) fn ieee802154_or_panic(&self) -> Ieee802154Address { + match self { + HardwareAddress::Ieee802154(addr) => *addr, + #[allow(unreachable_patterns)] + _ => panic!("HardwareAddress is not Ethernet."), + } + } + + #[inline] + pub(crate) fn medium(&self) -> Medium { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => Medium::Ip, + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(_) => Medium::Ethernet, + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(_) => Medium::Ieee802154, + } + } +} + +#[cfg(any( + feature = "medium-ip", + feature = "medium-ethernet", + feature = "medium-ieee802154" +))] +impl core::fmt::Display for HardwareAddress { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + match self { + #[cfg(feature = "medium-ip")] + HardwareAddress::Ip => write!(f, "no hardware addr"), + #[cfg(feature = "medium-ethernet")] + HardwareAddress::Ethernet(addr) => write!(f, "{addr}"), + #[cfg(feature = "medium-ieee802154")] + HardwareAddress::Ieee802154(addr) => write!(f, "{addr}"), + } + } +} + +#[cfg(feature = "medium-ethernet")] +impl From for HardwareAddress { + fn from(addr: EthernetAddress) -> Self { + HardwareAddress::Ethernet(addr) + } +} + +#[cfg(feature = "medium-ieee802154")] +impl From for HardwareAddress { + fn from(addr: Ieee802154Address) -> Self { + HardwareAddress::Ieee802154(addr) + } +} + +#[cfg(not(feature = "medium-ieee802154"))] +pub const MAX_HARDWARE_ADDRESS_LEN: usize = 6; +#[cfg(feature = "medium-ieee802154")] +pub const MAX_HARDWARE_ADDRESS_LEN: usize = 8; + +/// Unparsed hardware address. +/// +/// Used to make NDISC parsing agnostic of the hardware medium in use. +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct RawHardwareAddress { + len: u8, + data: [u8; MAX_HARDWARE_ADDRESS_LEN], +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl RawHardwareAddress { + pub fn from_bytes(addr: &[u8]) -> Self { + let mut data = [0u8; MAX_HARDWARE_ADDRESS_LEN]; + data[..addr.len()].copy_from_slice(addr); + + Self { + len: addr.len() as u8, + data, + } + } + + pub fn as_bytes(&self) -> &[u8] { + &self.data[..self.len as usize] + } + + pub const fn len(&self) -> usize { + self.len as usize + } + + pub const fn is_empty(&self) -> bool { + self.len == 0 + } + + pub fn parse(&self, medium: Medium) -> Result { + match medium { + #[cfg(feature = "medium-ethernet")] + Medium::Ethernet => { + if self.len() < 6 { + return Err(Error); + } + Ok(HardwareAddress::Ethernet(EthernetAddress::from_bytes( + self.as_bytes(), + ))) + } + #[cfg(feature = "medium-ieee802154")] + Medium::Ieee802154 => { + if self.len() < 8 { + return Err(Error); + } + Ok(HardwareAddress::Ieee802154(Ieee802154Address::from_bytes( + self.as_bytes(), + ))) + } + #[cfg(feature = "medium-ip")] + Medium::Ip => unreachable!(), + } + } +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl core::fmt::Display for RawHardwareAddress { + fn fmt(&self, f: &mut core::fmt::Formatter) -> core::fmt::Result { + for (i, &b) in self.as_bytes().iter().enumerate() { + if i != 0 { + write!(f, ":")?; + } + write!(f, "{b:02x}")?; + } + Ok(()) + } +} + +#[cfg(feature = "medium-ethernet")] +impl From for RawHardwareAddress { + fn from(addr: EthernetAddress) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} + +#[cfg(feature = "medium-ieee802154")] +impl From for RawHardwareAddress { + fn from(addr: Ieee802154Address) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} + +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] +impl From for RawHardwareAddress { + fn from(addr: HardwareAddress) -> Self { + Self::from_bytes(addr.as_bytes()) + } +} diff --git a/src/wire/ndisc.rs b/src/wire/ndisc.rs index 4f3e5c458..7ea92447c 100644 --- a/src/wire/ndisc.rs +++ b/src/wire/ndisc.rs @@ -1,14 +1,16 @@ +use bitflags::bitflags; use byteorder::{ByteOrder, NetworkEndian}; -use {Error, Result}; -use super::icmpv6::{field, Message, Packet}; -use wire::{EthernetAddress, Ipv6Repr, Ipv6Packet}; -use wire::{NdiscOption, NdiscOptionRepr, NdiscOptionType}; -use wire::{NdiscPrefixInformation, NdiscRedirectedHeader}; -use time::Duration; -use super::Ipv6Address; +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::icmpv6::{field, Message, Packet}; +use crate::wire::Ipv6Address; +use crate::wire::RawHardwareAddress; +use crate::wire::{NdiscOption, NdiscOptionRepr}; +use crate::wire::{NdiscPrefixInformation, NdiscRedirectedHeader}; bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct RouterFlags: u8 { const MANAGED = 0b10000000; const OTHER = 0b01000000; @@ -16,6 +18,7 @@ bitflags! { } bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct NeighborFlags: u8 { const ROUTER = 0b10000000; const SOLICITED = 0b01000000; @@ -79,7 +82,6 @@ impl> Packet { } } - /// Getters for the Neighbor Solicitation message header. /// See [RFC 4861 § 4.3]. /// @@ -188,9 +190,10 @@ impl + AsMut<[u8]>> Packet { /// A high-level representation of an Neighbor Discovery packet header. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr<'a> { RouterSolicit { - lladdr: Option + lladdr: Option, }, RouterAdvert { hop_limit: u8, @@ -198,180 +201,200 @@ pub enum Repr<'a> { router_lifetime: Duration, reachable_time: Duration, retrans_time: Duration, - lladdr: Option, + lladdr: Option, mtu: Option, - prefix_info: Option + prefix_info: Option, }, NeighborSolicit { target_addr: Ipv6Address, - lladdr: Option + lladdr: Option, }, NeighborAdvert { flags: NeighborFlags, target_addr: Ipv6Address, - lladdr: Option + lladdr: Option, }, Redirect { target_addr: Ipv6Address, dest_addr: Ipv6Address, - lladdr: Option, - redirected_hdr: Option> - } + lladdr: Option, + redirected_hdr: Option>, + }, } impl<'a> Repr<'a> { /// Parse an NDISC packet and return a high-level representation of the /// packet. - pub fn parse(packet: &Packet<&'a T>) - -> Result> - where T: AsRef<[u8]> + ?Sized { + #[allow(clippy::single_match)] + pub fn parse(packet: &Packet<&'a T>) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { + fn foreach_option<'a>( + payload: &'a [u8], + mut f: impl FnMut(NdiscOptionRepr<'a>) -> Result<()>, + ) -> Result<()> { + let mut offset = 0; + while payload.len() > offset { + let pkt = NdiscOption::new_checked(&payload[offset..])?; + + // If an option doesn't parse, ignore it and still parse the others. + if let Ok(opt) = NdiscOptionRepr::parse(&pkt) { + f(opt)?; + } + + let len = pkt.data_len() as usize * 8; + if len == 0 { + return Err(Error); + } + offset += len; + } + Ok(()) + } + match packet.msg_type() { Message::RouterSolicit => { - let lladdr = if packet.payload().len() > 0 { - let opt = NdiscOption::new_checked(packet.payload())?; - match opt.option_type() { - NdiscOptionType::SourceLinkLayerAddr => Some(opt.link_layer_addr()), - _ => { return Err(Error::Unrecognized); } + let mut lladdr = None; + foreach_option(packet.payload(), |opt| { + match opt { + NdiscOptionRepr::SourceLinkLayerAddr(addr) => lladdr = Some(addr), + _ => {} } - } else { - None - }; + Ok(()) + })?; Ok(Repr::RouterSolicit { lladdr }) - }, + } Message::RouterAdvert => { - let mut offset = 0; let (mut lladdr, mut mtu, mut prefix_info) = (None, None, None); - while packet.payload().len() - offset > 0 { - let pkt = NdiscOption::new_checked(&packet.payload()[offset..])?; - let opt = NdiscOptionRepr::parse(&pkt)?; + foreach_option(packet.payload(), |opt| { match opt { NdiscOptionRepr::SourceLinkLayerAddr(addr) => lladdr = Some(addr), NdiscOptionRepr::Mtu(val) => mtu = Some(val), NdiscOptionRepr::PrefixInformation(info) => prefix_info = Some(info), - _ => { return Err(Error::Unrecognized); } + _ => {} } - offset += opt.buffer_len(); - } + Ok(()) + })?; Ok(Repr::RouterAdvert { hop_limit: packet.current_hop_limit(), flags: packet.router_flags(), router_lifetime: packet.router_lifetime(), reachable_time: packet.reachable_time(), retrans_time: packet.retrans_time(), - lladdr, mtu, prefix_info + lladdr, + mtu, + prefix_info, }) - }, + } Message::NeighborSolicit => { - let lladdr = if packet.payload().len() > 0 { - let opt = NdiscOption::new_checked(packet.payload())?; - match opt.option_type() { - NdiscOptionType::SourceLinkLayerAddr => Some(opt.link_layer_addr()), - _ => { return Err(Error::Unrecognized); } + let mut lladdr = None; + foreach_option(packet.payload(), |opt| { + match opt { + NdiscOptionRepr::SourceLinkLayerAddr(addr) => lladdr = Some(addr), + _ => {} } - } else { - None - }; + Ok(()) + })?; Ok(Repr::NeighborSolicit { - target_addr: packet.target_addr(), lladdr + target_addr: packet.target_addr(), + lladdr, }) - }, + } Message::NeighborAdvert => { - let lladdr = if packet.payload().len() > 0 { - let opt = NdiscOption::new_checked(packet.payload())?; - match opt.option_type() { - NdiscOptionType::TargetLinkLayerAddr => Some(opt.link_layer_addr()), - _ => { return Err(Error::Unrecognized); } + let mut lladdr = None; + foreach_option(packet.payload(), |opt| { + match opt { + NdiscOptionRepr::TargetLinkLayerAddr(addr) => lladdr = Some(addr), + _ => {} } - } else { - None - }; + Ok(()) + })?; Ok(Repr::NeighborAdvert { flags: packet.neighbor_flags(), target_addr: packet.target_addr(), - lladdr + lladdr, }) - }, + } Message::Redirect => { - let mut offset = 0; let (mut lladdr, mut redirected_hdr) = (None, None); - while packet.payload().len() - offset > 0 { - let opt = NdiscOption::new_checked(&packet.payload()[offset..])?; - match opt.option_type() { - NdiscOptionType::SourceLinkLayerAddr => { - lladdr = Some(opt.link_layer_addr()); - offset += 8; - }, - NdiscOptionType::RedirectedHeader => { - if opt.data_len() < 6 { - return Err(Error::Truncated) - } else { - let ip_packet = - Ipv6Packet::new_unchecked(&opt.data()[offset + 8..]); - let ip_repr = Ipv6Repr::parse(&ip_packet)?; - let data = &opt.data()[offset + 8 + ip_repr.buffer_len()..]; - redirected_hdr = Some(NdiscRedirectedHeader { - header: ip_repr, data - }); - offset += 8 + ip_repr.buffer_len() + data.len(); - } - } - _ => { return Err(Error::Unrecognized); } + + foreach_option(packet.payload(), |opt| { + match opt { + NdiscOptionRepr::SourceLinkLayerAddr(addr) => lladdr = Some(addr), + NdiscOptionRepr::RedirectedHeader(rh) => redirected_hdr = Some(rh), + _ => {} } - } + Ok(()) + })?; Ok(Repr::Redirect { target_addr: packet.target_addr(), dest_addr: packet.dest_addr(), - lladdr, redirected_hdr + lladdr, + redirected_hdr, }) - }, - _ => Err(Error::Unrecognized) + } + _ => Err(Error), } } - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { match self { - &Repr::RouterSolicit { lladdr } => { - match lladdr { - Some(_) => field::UNUSED.end + 8, - None => field::UNUSED.end, + &Repr::RouterSolicit { lladdr } => match lladdr { + Some(addr) => { + field::UNUSED.end + { NdiscOptionRepr::SourceLinkLayerAddr(addr).buffer_len() } } + None => field::UNUSED.end, }, - &Repr::RouterAdvert { lladdr, mtu, prefix_info, .. } => { + &Repr::RouterAdvert { + lladdr, + mtu, + prefix_info, + .. + } => { let mut offset = 0; - if lladdr.is_some() { - offset += 8; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len(); } - if mtu.is_some() { - offset += 8; + if let Some(mtu) = mtu { + offset += NdiscOptionRepr::Mtu(mtu).buffer_len(); } - if prefix_info.is_some() { - offset += 32; + if let Some(prefix_info) = prefix_info { + offset += NdiscOptionRepr::PrefixInformation(prefix_info).buffer_len(); } field::RETRANS_TM.end + offset - }, + } &Repr::NeighborSolicit { lladdr, .. } | &Repr::NeighborAdvert { lladdr, .. } => { - match lladdr { - Some(_) => field::TARGET_ADDR.end + 8, - None => field::TARGET_ADDR.end, + let mut offset = field::TARGET_ADDR.end; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::SourceLinkLayerAddr(lladdr).buffer_len(); } - }, - &Repr::Redirect { lladdr, redirected_hdr, .. } => { - let mut offset = 0; - if lladdr.is_some() { - offset += 8; + offset + } + &Repr::Redirect { + lladdr, + redirected_hdr, + .. + } => { + let mut offset = field::DEST_ADDR.end; + if let Some(lladdr) = lladdr { + offset += NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len(); } if let Some(NdiscRedirectedHeader { header, data }) = redirected_hdr { - offset += 8 + header.buffer_len() + data.len(); + offset += + NdiscOptionRepr::RedirectedHeader(NdiscRedirectedHeader { header, data }) + .buffer_len(); } - field::DEST_ADDR.end + offset + offset } } } pub fn emit(&self, packet: &mut Packet<&mut T>) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { - match self { - &Repr::RouterSolicit { lladdr } => { + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match *self { + Repr::RouterSolicit { lladdr } => { packet.set_msg_type(Message::RouterSolicit); packet.set_msg_code(0); packet.clear_reserved(); @@ -379,10 +402,18 @@ impl<'a> Repr<'a> { let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); NdiscOptionRepr::SourceLinkLayerAddr(lladdr).emit(&mut opt_pkt); } - }, + } - &Repr::RouterAdvert { hop_limit, flags, router_lifetime, reachable_time, - retrans_time, lladdr, mtu, prefix_info } => { + Repr::RouterAdvert { + hop_limit, + flags, + router_lifetime, + reachable_time, + retrans_time, + lladdr, + mtu, + prefix_info, + } => { packet.set_msg_type(Message::RouterAdvert); packet.set_msg_code(0); packet.set_current_hop_limit(hop_limit); @@ -392,50 +423,60 @@ impl<'a> Repr<'a> { packet.set_retrans_time(retrans_time); let mut offset = 0; if let Some(lladdr) = lladdr { - let mut opt_pkt = - NdiscOption::new_unchecked(packet.payload_mut()); - NdiscOptionRepr::SourceLinkLayerAddr(lladdr).emit(&mut opt_pkt); - offset += 8; + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); + let opt = NdiscOptionRepr::SourceLinkLayerAddr(lladdr); + opt.emit(&mut opt_pkt); + offset += opt.buffer_len(); } if let Some(mtu) = mtu { let mut opt_pkt = NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); NdiscOptionRepr::Mtu(mtu).emit(&mut opt_pkt); - offset += 8; + offset += NdiscOptionRepr::Mtu(mtu).buffer_len(); } if let Some(prefix_info) = prefix_info { let mut opt_pkt = NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); NdiscOptionRepr::PrefixInformation(prefix_info).emit(&mut opt_pkt) } - }, + } - &Repr::NeighborSolicit { target_addr, lladdr } => { + Repr::NeighborSolicit { + target_addr, + lladdr, + } => { packet.set_msg_type(Message::NeighborSolicit); packet.set_msg_code(0); packet.clear_reserved(); packet.set_target_addr(target_addr); if let Some(lladdr) = lladdr { - let mut opt_pkt = - NdiscOption::new_unchecked(packet.payload_mut()); + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); NdiscOptionRepr::SourceLinkLayerAddr(lladdr).emit(&mut opt_pkt); } - }, + } - &Repr::NeighborAdvert { flags, target_addr, lladdr } => { + Repr::NeighborAdvert { + flags, + target_addr, + lladdr, + } => { packet.set_msg_type(Message::NeighborAdvert); packet.set_msg_code(0); packet.clear_reserved(); packet.set_neighbor_flags(flags); packet.set_target_addr(target_addr); if let Some(lladdr) = lladdr { - let mut opt_pkt = - NdiscOption::new_unchecked(packet.payload_mut()); + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); NdiscOptionRepr::TargetLinkLayerAddr(lladdr).emit(&mut opt_pkt); } - }, + } - &Repr::Redirect { target_addr, dest_addr, lladdr, redirected_hdr } => { + Repr::Redirect { + target_addr, + dest_addr, + lladdr, + redirected_hdr, + } => { packet.set_msg_type(Message::Redirect); packet.set_msg_code(0); packet.clear_reserved(); @@ -443,11 +484,10 @@ impl<'a> Repr<'a> { packet.set_dest_addr(dest_addr); let offset = match lladdr { Some(lladdr) => { - let mut opt_pkt = - NdiscOption::new_unchecked(packet.payload_mut()); + let mut opt_pkt = NdiscOption::new_unchecked(packet.payload_mut()); NdiscOptionRepr::TargetLinkLayerAddr(lladdr).emit(&mut opt_pkt); - 8 - }, + NdiscOptionRepr::TargetLinkLayerAddr(lladdr).buffer_len() + } None => 0, }; if let Some(redirected_hdr) = redirected_hdr { @@ -455,28 +495,25 @@ impl<'a> Repr<'a> { NdiscOption::new_unchecked(&mut packet.payload_mut()[offset..]); NdiscOptionRepr::RedirectedHeader(redirected_hdr).emit(&mut opt_pkt); } - }, + } } } } +#[cfg(feature = "medium-ethernet")] #[cfg(test)] mod test { - use phy::ChecksumCapabilities; use super::*; - use wire::Icmpv6Repr; - use wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; - - static ROUTER_ADVERT_BYTES: [u8; 24] = - [0x86, 0x00, 0xa9, 0xde, - 0x40, 0x80, 0x03, 0x84, - 0x00, 0x00, 0x03, 0x84, - 0x00, 0x00, 0x03, 0x84, - 0x01, 0x01, 0x52, 0x54, - 0x00, 0x12, 0x34, 0x56]; - static SOURCE_LINK_LAYER_OPT: [u8; 8] = - [0x01, 0x01, 0x52, 0x54, - 0x00, 0x12, 0x34, 0x56]; + use crate::phy::ChecksumCapabilities; + use crate::wire::ip::test::{MOCK_IP_ADDR_1, MOCK_IP_ADDR_2}; + use crate::wire::EthernetAddress; + use crate::wire::Icmpv6Repr; + + static ROUTER_ADVERT_BYTES: [u8; 24] = [ + 0x86, 0x00, 0xa9, 0xde, 0x40, 0x80, 0x03, 0x84, 0x00, 0x00, 0x03, 0x84, 0x00, 0x00, 0x03, + 0x84, 0x01, 0x01, 0x52, 0x54, 0x00, 0x12, 0x34, 0x56, + ]; + static SOURCE_LINK_LAYER_OPT: [u8; 8] = [0x01, 0x01, 0x52, 0x54, 0x00, 0x12, 0x34, 0x56]; fn create_repr<'a>() -> Icmpv6Repr<'a> { Icmpv6Repr::Ndisc(Repr::RouterAdvert { @@ -485,9 +522,9 @@ mod test { router_lifetime: Duration::from_secs(900), reachable_time: Duration::from_millis(900), retrans_time: Duration::from_millis(900), - lladdr: Some(EthernetAddress([0x52, 0x54, 0x00, 0x12, 0x34, 0x56])), + lladdr: Some(EthernetAddress([0x52, 0x54, 0x00, 0x12, 0x34, 0x56]).into()), mtu: None, - prefix_info: None + prefix_info: None, }) } @@ -515,25 +552,38 @@ mod test { packet.set_router_lifetime(Duration::from_secs(900)); packet.set_reachable_time(Duration::from_millis(900)); packet.set_retrans_time(Duration::from_millis(900)); - packet.payload_mut().copy_from_slice(&SOURCE_LINK_LAYER_OPT[..]); + packet + .payload_mut() + .copy_from_slice(&SOURCE_LINK_LAYER_OPT[..]); packet.fill_checksum(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2); - assert_eq!(&packet.into_inner()[..], &ROUTER_ADVERT_BYTES[..]); + assert_eq!(&*packet.into_inner(), &ROUTER_ADVERT_BYTES[..]); } #[test] fn test_router_advert_repr_parse() { let packet = Packet::new_unchecked(&ROUTER_ADVERT_BYTES[..]); - assert_eq!(Icmpv6Repr::parse(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &packet, &ChecksumCapabilities::default()).unwrap(), - create_repr()); + assert_eq!( + Icmpv6Repr::parse( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &packet, + &ChecksumCapabilities::default() + ) + .unwrap(), + create_repr() + ); } #[test] fn test_router_advert_repr_emit() { let mut bytes = vec![0x2a; 24]; let mut packet = Packet::new_unchecked(&mut bytes[..]); - create_repr().emit(&MOCK_IP_ADDR_1, &MOCK_IP_ADDR_2, - &mut packet, &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &ROUTER_ADVERT_BYTES[..]); + create_repr().emit( + &MOCK_IP_ADDR_1, + &MOCK_IP_ADDR_2, + &mut packet, + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &ROUTER_ADVERT_BYTES[..]); } } diff --git a/src/wire/ndiscoption.rs b/src/wire/ndiscoption.rs index 51803318b..eff7a93c1 100644 --- a/src/wire/ndiscoption.rs +++ b/src/wire/ndiscoption.rs @@ -1,13 +1,16 @@ +use bitflags::bitflags; +use byteorder::{ByteOrder, NetworkEndian}; use core::fmt; -use byteorder::{NetworkEndian, ByteOrder}; -use {Error, Result}; -use time::Duration; -use wire::{EthernetAddress, Ipv6Address, Ipv6Packet, Ipv6Repr}; +use super::{Error, Result}; +use crate::time::Duration; +use crate::wire::{Ipv6Address, Ipv6Packet, Ipv6Repr, MAX_HARDWARE_ADDRESS_LEN}; + +use crate::wire::RawHardwareAddress; enum_with_unknown! { /// NDISC Option Type - pub doc enum Type(u8) { + pub enum Type(u8) { /// Source Link-layer Address SourceLinkLayerAddr = 0x1, /// Target Link-layer Address @@ -24,17 +27,18 @@ enum_with_unknown! { impl fmt::Display for Type { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match self { - &Type::SourceLinkLayerAddr => write!(f, "source link-layer address"), - &Type::TargetLinkLayerAddr => write!(f, "target link-layer address"), - &Type::PrefixInformation => write!(f, "prefix information"), - &Type::RedirectedHeader => write!(f, "redirected header"), - &Type::Mtu => write!(f, "mtu"), - &Type::Unknown(id) => write!(f, "{}", id) + Type::SourceLinkLayerAddr => write!(f, "source link-layer address"), + Type::TargetLinkLayerAddr => write!(f, "target link-layer address"), + Type::PrefixInformation => write!(f, "prefix information"), + Type::RedirectedHeader => write!(f, "redirected header"), + Type::Mtu => write!(f, "mtu"), + Type::Unknown(id) => write!(f, "{id}"), } } } bitflags! { + #[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct PrefixInfoFlags: u8 { const ON_LINK = 0b10000000; const ADDRCONF = 0b01000000; @@ -44,9 +48,10 @@ bitflags! { /// A read/write wrapper around an [NDISC Option]. /// /// [NDISC Option]: https://tools.ietf.org/html/rfc4861#section-4.6 -#[derive(Debug, PartialEq)] +#[derive(Debug, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct NdiscOption> { - buffer: T + buffer: T, } // Format of an NDISC Option @@ -61,16 +66,16 @@ pub struct NdiscOption> { mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; // 8-bit identifier of the type of option. - pub const TYPE: usize = 0; - // 8-bit unsigned integer. Length of the option, in units of 8 octests. - pub const LENGTH: usize = 1; + pub const TYPE: usize = 0; + // 8-bit unsigned integer. Length of the option, in units of 8 octets. + pub const LENGTH: usize = 1; // Minimum length of an option. - pub const MIN_OPT_LEN: usize = 8; + pub const MIN_OPT_LEN: usize = 8; // Variable-length field. Option-Type-specific data. - pub fn DATA(length: u8) -> Field { + pub const fn DATA(length: u8) -> Field { 2..length as usize * 8 } @@ -79,9 +84,6 @@ mod field { // | Type | Length | Link-Layer Address ... // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ - // Link-Layer Address - pub const LL_ADDR: Field = 2..8; - // Prefix Information Option fields. // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // | Type | Length | Prefix Length |L|A| Reserved1 | @@ -102,17 +104,17 @@ mod field { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // Prefix length. - pub const PREFIX_LEN: usize = 2; + pub const PREFIX_LEN: usize = 2; // Flags field of prefix header. - pub const FLAGS: usize = 3; + pub const FLAGS: usize = 3; // Valid lifetime. - pub const VALID_LT: Field = 4..8; + pub const VALID_LT: Field = 4..8; // Preferred lifetime. - pub const PREF_LT: Field = 8..12; + pub const PREF_LT: Field = 8..12; // Reserved bits pub const PREF_RESERVED: Field = 12..16; // Prefix - pub const PREFIX: Field = 16..32; + pub const PREFIX: Field = 16..32; // Redirected Header Option fields. // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -126,10 +128,8 @@ mod field { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // Reserved bits. - pub const IP_RESERVED: Field = 4..8; - // Redirected header IP header + data. - pub const IP_DATA: usize = 8; - pub const REDIR_MIN_SZ: usize = 48; + pub const REDIRECTED_RESERVED: Field = 2..8; + pub const REDIR_MIN_SZ: usize = 48; // MTU Option fields // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ @@ -139,13 +139,13 @@ mod field { // +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ // MTU - pub const MTU: Field = 4..8; + pub const MTU: Field = 4..8; } /// Core getter methods relevant to any type of NDISC option. impl> NdiscOption { /// Create a raw octet buffer with an NDISC Option structure. - pub fn new_unchecked(buffer: T) -> NdiscOption { + pub const fn new_unchecked(buffer: T) -> NdiscOption { NdiscOption { buffer } } @@ -156,11 +156,17 @@ impl> NdiscOption { pub fn new_checked(buffer: T) -> Result> { let opt = Self::new_unchecked(buffer); opt.check_len()?; + + // A data length field of 0 is invalid. + if opt.data_len() == 0 { + return Err(Error); + } + Ok(opt) } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. + /// Returns `Err(Error)` if the buffer is too short. /// /// The result of this check is invalidated by calling [set_data_len]. /// @@ -170,23 +176,18 @@ impl> NdiscOption { let len = data.len(); if len < field::MIN_OPT_LEN { - Err(Error::Truncated) + Err(Error) } else { let data_range = field::DATA(data[field::LENGTH]); if len < data_range.end { - Err(Error::Truncated) + Err(Error) } else { match self.option_type() { - Type::SourceLinkLayerAddr | Type::TargetLinkLayerAddr | Type::Mtu => - Ok(()), - Type::PrefixInformation if data_range.end >= field::PREFIX.end => - Ok(()), - Type::RedirectedHeader if data_range.end >= field::REDIR_MIN_SZ => - Ok(()), - Type::Unknown(_) => - Ok(()), - _ => - Err(Error::Truncated), + Type::SourceLinkLayerAddr | Type::TargetLinkLayerAddr | Type::Mtu => Ok(()), + Type::PrefixInformation if data_range.end >= field::PREFIX.end => Ok(()), + Type::RedirectedHeader if data_range.end >= field::REDIR_MIN_SZ => Ok(()), + Type::Unknown(_) => Ok(()), + _ => Err(Error), } } } @@ -216,9 +217,10 @@ impl> NdiscOption { impl> NdiscOption { /// Return the Source/Target Link-layer Address. #[inline] - pub fn link_layer_addr(&self) -> EthernetAddress { + pub fn link_layer_addr(&self) -> RawHardwareAddress { + let len = MAX_HARDWARE_ADDRESS_LEN.min(self.data_len() as usize * 8 - 2); let data = self.buffer.as_ref(); - EthernetAddress::from_bytes(&data[field::LL_ADDR]) + RawHardwareAddress::from_bytes(&data[2..len + 2]) } } @@ -299,9 +301,9 @@ impl + AsMut<[u8]>> NdiscOption { impl + AsMut<[u8]>> NdiscOption { /// Set the Source/Target Link-layer Address. #[inline] - pub fn set_link_layer_addr(&mut self, addr: EthernetAddress) { + pub fn set_link_layer_addr(&mut self, addr: RawHardwareAddress) { let data = self.buffer.as_mut(); - data[field::LL_ADDR].copy_from_slice(addr.as_bytes()) + data[2..2 + addr.len()].copy_from_slice(addr.as_bytes()) } } @@ -364,11 +366,10 @@ impl + AsMut<[u8]>> NdiscOption { #[inline] pub fn clear_redirected_reserved(&mut self) { let data = self.buffer.as_mut(); - NetworkEndian::write_u32(&mut data[field::IP_RESERVED], 0); + data[field::REDIRECTED_RESERVED].fill_with(|| 0); } } - impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NdiscOption<&'a mut T> { /// Return a mutable pointer to the option data. #[inline] @@ -382,9 +383,9 @@ impl<'a, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> NdiscOption<&'a mut T> { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for NdiscOption<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { match Repr::parse(self) { - Ok(repr) => write!(f, "{}", repr), + Ok(repr) => write!(f, "{repr}"), Err(err) => { - write!(f, "NDISC Option ({})", err)?; + write!(f, "NDISC Option ({err})")?; Ok(()) } } @@ -392,54 +393,59 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for NdiscOption<&'a T> { } #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct PrefixInformation { pub prefix_len: u8, pub flags: PrefixInfoFlags, pub valid_lifetime: Duration, pub preferred_lifetime: Duration, - pub prefix: Ipv6Address + pub prefix: Ipv6Address, } #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct RedirectedHeader<'a> { pub header: Ipv6Repr, - pub data: &'a [u8] + pub data: &'a [u8], } /// A high-level representation of an NDISC Option. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Repr<'a> { - SourceLinkLayerAddr(EthernetAddress), - TargetLinkLayerAddr(EthernetAddress), + SourceLinkLayerAddr(RawHardwareAddress), + TargetLinkLayerAddr(RawHardwareAddress), PrefixInformation(PrefixInformation), RedirectedHeader(RedirectedHeader<'a>), Mtu(u32), Unknown { - type_: u8, + type_: u8, length: u8, - data: &'a [u8] + data: &'a [u8], }, } impl<'a> Repr<'a> { /// Parse an NDISC Option and return a high-level representation. - pub fn parse(opt: &'a NdiscOption<&'a T>) -> Result> - where T: AsRef<[u8]> + ?Sized { + pub fn parse(opt: &NdiscOption<&'a T>) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { match opt.option_type() { Type::SourceLinkLayerAddr => { - if opt.data_len() == 1 { + if opt.data_len() >= 1 { Ok(Repr::SourceLinkLayerAddr(opt.link_layer_addr())) } else { - Err(Error::Malformed) + Err(Error) } - }, + } Type::TargetLinkLayerAddr => { - if opt.data_len() == 1 { + if opt.data_len() >= 1 { Ok(Repr::TargetLinkLayerAddr(opt.link_layer_addr())) } else { - Err(Error::Malformed) + Err(Error) } - }, + } Type::PrefixInformation => { if opt.data_len() == 4 { Ok(Repr::PrefixInformation(PrefixInformation { @@ -447,77 +453,93 @@ impl<'a> Repr<'a> { flags: opt.prefix_flags(), valid_lifetime: opt.valid_lifetime(), preferred_lifetime: opt.preferred_lifetime(), - prefix: opt.prefix() + prefix: opt.prefix(), })) } else { - Err(Error::Malformed) + Err(Error) } - }, + } Type::RedirectedHeader => { // If the options data length is less than 6, the option // does not have enough data to fill out the IP header // and common option fields. if opt.data_len() < 6 { - Err(Error::Truncated) + Err(Error) } else { - let ip_packet = Ipv6Packet::new_unchecked(&opt.data()[field::IP_DATA..]); + let redirected_packet = &opt.data()[field::REDIRECTED_RESERVED.len()..]; + + let ip_packet = Ipv6Packet::new_checked(redirected_packet)?; let ip_repr = Ipv6Repr::parse(&ip_packet)?; + Ok(Repr::RedirectedHeader(RedirectedHeader { header: ip_repr, - data: &opt.data()[field::IP_DATA + ip_repr.buffer_len()..] + data: &redirected_packet[ip_repr.buffer_len()..][..ip_repr.payload_len], })) } - }, + } Type::Mtu => { if opt.data_len() == 1 { Ok(Repr::Mtu(opt.mtu())) } else { - Err(Error::Malformed) + Err(Error) } - }, + } Type::Unknown(id) => { - Ok(Repr::Unknown { - type_: id, - length: opt.data_len(), - data: opt.data() - }) + // A length of 0 is invalid. + if opt.data_len() != 0 { + Ok(Repr::Unknown { + type_: id, + length: opt.data_len(), + data: opt.data(), + }) + } else { + Err(Error) + } } } } /// Return the length of a header that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { + pub const fn buffer_len(&self) -> usize { match self { - &Repr::SourceLinkLayerAddr(_) | &Repr::TargetLinkLayerAddr(_) => - field::LL_ADDR.end, - &Repr::PrefixInformation(_) => - field::PREFIX.end, - &Repr::RedirectedHeader(RedirectedHeader { header, data }) => - field::IP_DATA + header.buffer_len() + data.len(), - &Repr::Mtu(_) => - field::MTU.end, - &Repr::Unknown { length, .. } => - field::DATA(length).end + &Repr::SourceLinkLayerAddr(addr) | &Repr::TargetLinkLayerAddr(addr) => { + let len = 2 + addr.len(); + // Round up to next multiple of 8 + (len + 7) / 8 * 8 + } + &Repr::PrefixInformation(_) => field::PREFIX.end, + &Repr::RedirectedHeader(RedirectedHeader { header, data }) => { + (8 + header.buffer_len() + data.len() + 7) / 8 * 8 + } + &Repr::Mtu(_) => field::MTU.end, + &Repr::Unknown { length, .. } => field::DATA(length).end, } } /// Emit a high-level representation into an NDISC Option. pub fn emit(&self, opt: &mut NdiscOption<&'a mut T>) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { - match self { - &Repr::SourceLinkLayerAddr(addr) => { + where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { + match *self { + Repr::SourceLinkLayerAddr(addr) => { opt.set_option_type(Type::SourceLinkLayerAddr); - opt.set_data_len(1); + let opt_len = addr.len() + 2; + opt.set_data_len(((opt_len + 7) / 8) as u8); // round to next multiple of 8. opt.set_link_layer_addr(addr); - }, - &Repr::TargetLinkLayerAddr(addr) => { + } + Repr::TargetLinkLayerAddr(addr) => { opt.set_option_type(Type::TargetLinkLayerAddr); - opt.set_data_len(1); + let opt_len = addr.len() + 2; + opt.set_data_len(((opt_len + 7) / 8) as u8); // round to next multiple of 8. opt.set_link_layer_addr(addr); - }, - &Repr::PrefixInformation(PrefixInformation { - prefix_len, flags, valid_lifetime, - preferred_lifetime, prefix + } + Repr::PrefixInformation(PrefixInformation { + prefix_len, + flags, + valid_lifetime, + preferred_lifetime, + prefix, }) => { opt.clear_prefix_reserved(); opt.set_option_type(Type::PrefixInformation); @@ -527,26 +549,28 @@ impl<'a> Repr<'a> { opt.set_valid_lifetime(valid_lifetime); opt.set_preferred_lifetime(preferred_lifetime); opt.set_prefix(prefix); - }, - &Repr::RedirectedHeader(RedirectedHeader { - header, data - }) => { - let data_len = data.len() / 8; + } + Repr::RedirectedHeader(RedirectedHeader { header, data }) => { + // TODO(thvdveld): I think we need to check if the data we are sending is not + // exceeding the MTU. opt.clear_redirected_reserved(); opt.set_option_type(Type::RedirectedHeader); - opt.set_data_len((header.buffer_len() + 1 + data_len) as u8); - let mut ip_packet = - Ipv6Packet::new_unchecked(&mut opt.data_mut()[field::IP_DATA..]); + opt.set_data_len((((8 + header.buffer_len() + data.len()) + 7) / 8) as u8); + let mut packet = &mut opt.data_mut()[field::REDIRECTED_RESERVED.end - 2..]; + let mut ip_packet = Ipv6Packet::new_unchecked(&mut packet); header.emit(&mut ip_packet); - let payload = &mut ip_packet.into_inner()[header.buffer_len()..]; - payload.copy_from_slice(&data[..data_len]); + ip_packet.payload_mut().copy_from_slice(data); } - &Repr::Mtu(mtu) => { + Repr::Mtu(mtu) => { opt.set_option_type(Type::Mtu); opt.set_data_len(1); opt.set_mtu(mtu); } - &Repr::Unknown { type_: id, length, data } => { + Repr::Unknown { + type_: id, + length, + data, + } => { opt.set_option_type(Type::Unknown(id)); opt.set_data_len(length); opt.data_mut().copy_from_slice(data); @@ -558,70 +582,70 @@ impl<'a> Repr<'a> { impl<'a> fmt::Display for Repr<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { write!(f, "NDISC Option: ")?; - match self { - &Repr::SourceLinkLayerAddr(addr) => { - write!(f, "SourceLinkLayer addr={}", addr) - }, - &Repr::TargetLinkLayerAddr(addr) => { - write!(f, "TargetLinkLayer addr={}", addr) - }, - &Repr::PrefixInformation(PrefixInformation { - prefix, prefix_len, - .. - }) => { - write!(f, "PrefixInformation prefix={}/{}", prefix, prefix_len) - }, - &Repr::RedirectedHeader(RedirectedHeader { - header, - .. + match *self { + Repr::SourceLinkLayerAddr(addr) => { + write!(f, "SourceLinkLayer addr={addr}") + } + Repr::TargetLinkLayerAddr(addr) => { + write!(f, "TargetLinkLayer addr={addr}") + } + Repr::PrefixInformation(PrefixInformation { + prefix, prefix_len, .. }) => { - write!(f, "RedirectedHeader header={}", header) - }, - &Repr::Mtu(mtu) => { - write!(f, "MTU mtu={}", mtu) - }, - &Repr::Unknown { type_: id, length, .. } => { - write!(f, "Unknown({}) length={}", id, length) + write!(f, "PrefixInformation prefix={prefix}/{prefix_len}") + } + Repr::RedirectedHeader(RedirectedHeader { header, .. }) => { + write!(f, "RedirectedHeader header={header}") + } + Repr::Mtu(mtu) => { + write!(f, "MTU mtu={mtu}") + } + Repr::Unknown { + type_: id, length, .. + } => { + write!(f, "Unknown({id}) length={length}") } } } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for NdiscOption { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { match NdiscOption::new_checked(buffer) { - Err(err) => return write!(f, "{}({})", indent, err), - Ok(ndisc) => { - match Repr::parse(&ndisc) { - Err(_) => return Ok(()), - Ok(repr) => { - write!(f, "{}{}", indent, repr) - } + Err(err) => write!(f, "{indent}({err})"), + Ok(ndisc) => match Repr::parse(&ndisc) { + Err(_) => Ok(()), + Ok(repr) => { + write!(f, "{indent}{repr}") } - } + }, } } } +#[cfg(any(feature = "medium-ethernet", feature = "medium-ieee802154"))] #[cfg(test)] mod test { - use Error; - use time::Duration; - use wire::{EthernetAddress, Ipv6Address}; - use super::{NdiscOption, Type, PrefixInfoFlags, PrefixInformation, Repr}; + use super::Error; + use super::{NdiscOption, PrefixInfoFlags, PrefixInformation, Repr, Type}; + use crate::time::Duration; + use crate::wire::Ipv6Address; + + #[cfg(feature = "medium-ethernet")] + use crate::wire::EthernetAddress; + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ieee802154"))] + use crate::wire::Ieee802154Address; static PREFIX_OPT_BYTES: [u8; 32] = [ - 0x03, 0x04, 0x40, 0xc0, - 0x00, 0x00, 0x03, 0x84, - 0x00, 0x00, 0x03, 0xe8, - 0x00, 0x00, 0x00, 0x00, - 0xfe, 0x80, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x01 + 0x03, 0x04, 0x40, 0xc0, 0x00, 0x00, 0x03, 0x84, 0x00, 0x00, 0x03, 0xe8, 0x00, 0x00, 0x00, + 0x00, 0xfe, 0x80, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, ]; #[test] @@ -630,7 +654,10 @@ mod test { assert_eq!(opt.option_type(), Type::PrefixInformation); assert_eq!(opt.data_len(), 4); assert_eq!(opt.prefix_len(), 64); - assert_eq!(opt.prefix_flags(), PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF); + assert_eq!( + opt.prefix_flags(), + PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF + ); assert_eq!(opt.valid_lifetime(), Duration::from_secs(900)); assert_eq!(opt.preferred_lifetime(), Duration::from_secs(1000)); assert_eq!(opt.prefix(), Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)); @@ -647,31 +674,56 @@ mod test { opt.set_valid_lifetime(Duration::from_secs(900)); opt.set_preferred_lifetime(Duration::from_secs(1000)); opt.set_prefix(Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1)); - assert_eq!(&PREFIX_OPT_BYTES[..], &opt.into_inner()[..]); + assert_eq!(&PREFIX_OPT_BYTES[..], &*opt.into_inner()); } #[test] fn test_short_packet() { - assert_eq!(NdiscOption::new_checked(&[0x00, 0x00]), Err(Error::Truncated)); - let bytes = [ - 0x03, 0x01, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00 - ]; - assert_eq!(NdiscOption::new_checked(&bytes), Err(Error::Truncated)); + assert_eq!(NdiscOption::new_checked(&[0x00, 0x00]), Err(Error)); + let bytes = [0x03, 0x01, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00]; + assert_eq!(NdiscOption::new_checked(&bytes), Err(Error)); } + #[cfg(feature = "medium-ethernet")] #[test] - fn test_repr_parse_link_layer_opt() { + fn test_repr_parse_link_layer_opt_ethernet() { let mut bytes = [0x01, 0x01, 0x54, 0x52, 0x00, 0x12, 0x23, 0x34]; let addr = EthernetAddress([0x54, 0x52, 0x00, 0x12, 0x23, 0x34]); { - assert_eq!(Repr::parse(&NdiscOption::new_unchecked(&bytes)), - Ok(Repr::SourceLinkLayerAddr(addr))); + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::SourceLinkLayerAddr(addr.into())) + ); + } + bytes[0] = 0x02; + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::TargetLinkLayerAddr(addr.into())) + ); + } + } + + #[cfg(all(not(feature = "medium-ethernet"), feature = "medium-ieee802154"))] + #[test] + fn test_repr_parse_link_layer_opt_ieee802154() { + let mut bytes = [ + 0x01, 0x02, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ]; + let addr = Ieee802154Address::Extended([0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08]); + { + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::SourceLinkLayerAddr(addr.into())) + ); } bytes[0] = 0x02; { - assert_eq!(Repr::parse(&NdiscOption::new_unchecked(&bytes)), - Ok(Repr::TargetLinkLayerAddr(addr))); + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::TargetLinkLayerAddr(addr.into())) + ); } } @@ -682,9 +734,12 @@ mod test { flags: PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF, valid_lifetime: Duration::from_secs(900), preferred_lifetime: Duration::from_secs(1000), - prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1) + prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), }); - assert_eq!(Repr::parse(&NdiscOption::new_unchecked(&PREFIX_OPT_BYTES)), Ok(repr)); + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&PREFIX_OPT_BYTES)), + Ok(repr) + ); } #[test] @@ -695,7 +750,7 @@ mod test { flags: PrefixInfoFlags::ON_LINK | PrefixInfoFlags::ADDRCONF, valid_lifetime: Duration::from_secs(900), preferred_lifetime: Duration::from_secs(1000), - prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1) + prefix: Ipv6Address::new(0xfe80, 0, 0, 0, 0, 0, 0, 1), }); let mut opt = NdiscOption::new_unchecked(&mut bytes); repr.emit(&mut opt); @@ -705,6 +760,9 @@ mod test { #[test] fn test_repr_parse_mtu() { let bytes = [0x05, 0x01, 0x00, 0x00, 0x00, 0x00, 0x05, 0xdc]; - assert_eq!(Repr::parse(&NdiscOption::new_unchecked(&bytes)), Ok(Repr::Mtu(1500))); + assert_eq!( + Repr::parse(&NdiscOption::new_unchecked(&bytes)), + Ok(Repr::Mtu(1500)) + ); } } diff --git a/src/wire/pretty_print.rs b/src/wire/pretty_print.rs index c60c1ff62..fe7d8b892 100644 --- a/src/wire/pretty_print.rs +++ b/src/wire/pretty_print.rs @@ -7,7 +7,7 @@ easily human readable packet listings. A packet can be formatted using the `PrettyPrinter` wrapper: -```rust,ignore +```rust use smoltcp::wire::*; let buffer = vec![ // Ethernet II @@ -15,7 +15,7 @@ let buffer = vec![ 0x11, 0x12, 0x13, 0x14, 0x15, 0x16, 0x08, 0x00, // IPv4 - 0x45, 0x00, 0x00, 0x18, + 0x45, 0x00, 0x00, 0x20, 0x00, 0x00, 0x40, 0x00, 0x40, 0x01, 0xd2, 0x79, 0x11, 0x12, 0x13, 0x14, @@ -25,7 +25,18 @@ let buffer = vec![ 0x12, 0x34, 0xab, 0xcd, 0xaa, 0x00, 0x00, 0xff ]; -print!("{}", PrettyPrinter::>::new("", &buffer)); + +let result = "\ +EthernetII src=11-12-13-14-15-16 dst=01-02-03-04-05-06 type=IPv4\n\ +\\ IPv4 src=17.18.19.20 dst=33.34.35.36 proto=ICMP (checksum incorrect)\n \ + \\ ICMPv4 echo request id=4660 seq=43981 len=4\ +"; + +#[cfg(all(feature = "medium-ethernet", feature = "proto-ipv4"))] +assert_eq!( + result, + &format!("{}", PrettyPrinter::>::new("", &buffer)) +); ``` */ @@ -34,21 +45,22 @@ use core::marker::PhantomData; /// Indentation state. #[derive(Debug)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct PrettyIndent { prefix: &'static str, - level: usize + level: usize, } impl PrettyIndent { /// Create an indentation state. The entire listing will be indented by the width /// of `prefix`, and `prefix` will appear at the start of the first line. pub fn new(prefix: &'static str) -> PrettyIndent { - PrettyIndent { prefix: prefix, level: 0 } + PrettyIndent { prefix, level: 0 } } /// Increase indentation level. pub fn increase(&mut self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "\n")?; + writeln!(f)?; self.level += 1; Ok(()) } @@ -71,24 +83,27 @@ pub trait PrettyPrint { /// /// `pretty_print` accepts a buffer and not a packet wrapper because the packet might /// be truncated, and so it might not be possible to create the packet wrapper. - fn pretty_print(buffer: &dyn AsRef<[u8]>, fmt: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result; + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + fmt: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result; } /// Wrapper for using a `PrettyPrint` where a `Display` is expected. pub struct PrettyPrinter<'a, T: PrettyPrint> { - prefix: &'static str, - buffer: &'a dyn AsRef<[u8]>, - phantom: PhantomData + prefix: &'static str, + buffer: &'a dyn AsRef<[u8]>, + phantom: PhantomData, } impl<'a, T: PrettyPrint> PrettyPrinter<'a, T> { /// Format the listing with the recorded parameters when Display::fmt is called. pub fn new(prefix: &'static str, buffer: &'a dyn AsRef<[u8]>) -> PrettyPrinter<'a, T> { PrettyPrinter { - prefix: prefix, - buffer: buffer, - phantom: PhantomData + prefix: prefix, + buffer: buffer, + phantom: PhantomData, } } } diff --git a/src/wire/rpl.rs b/src/wire/rpl.rs new file mode 100644 index 000000000..7028b923f --- /dev/null +++ b/src/wire/rpl.rs @@ -0,0 +1,2686 @@ +//! Implementation of the RPL packet formats. See [RFC 6550 § 6]. +//! +//! [RFC 6550 § 6]: https://datatracker.ietf.org/doc/html/rfc6550#section-6 + +use byteorder::{ByteOrder, NetworkEndian}; + +use super::{Error, Result}; +use crate::wire::icmpv6::Packet; +use crate::wire::ipv6::Address; + +#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +#[repr(u8)] +pub enum InstanceId { + Global(u8), + Local(u8), +} + +impl From for InstanceId { + fn from(val: u8) -> Self { + const MASK: u8 = 0b0111_1111; + + if ((val >> 7) & 0xb1) == 0b0 { + Self::Global(val & MASK) + } else { + Self::Local(val & MASK) + } + } +} + +impl From for u8 { + fn from(val: InstanceId) -> Self { + match val { + InstanceId::Global(val) => 0b0000_0000 | val, + InstanceId::Local(val) => 0b1000_0000 | val, + } + } +} + +impl InstanceId { + /// Return the real part of the ID. + pub fn id(&self) -> u8 { + match self { + Self::Global(val) => *val, + Self::Local(val) => *val, + } + } + + /// Returns `true` when the DODAG ID is the destination address of the IPv6 packet. + #[inline] + pub fn dodag_is_destination(&self) -> bool { + match self { + Self::Global(_) => false, + Self::Local(val) => ((val >> 6) & 0b1) == 0b1, + } + } + + /// Returns `true` when the DODAG ID is the source address of the IPv6 packet. + /// + /// *NOTE*: this only makes sence when using a local RPL Instance ID and the packet is not a + /// RPL control message. + #[inline] + pub fn dodag_is_source(&self) -> bool { + !self.dodag_is_destination() + } +} + +mod field { + use crate::wire::field::*; + + pub const RPL_INSTANCE_ID: usize = 4; + + // DODAG information solicitation fields (DIS) + pub const DIS_FLAGS: usize = 4; + pub const DIS_RESERVED: usize = 5; + + // DODAG information object fields (DIO) + pub const DIO_VERSION_NUMBER: usize = 5; + pub const DIO_RANK: Field = 6..8; + pub const DIO_GROUNDED: usize = 8; + pub const DIO_MOP: usize = 8; + pub const DIO_PRF: usize = 8; + pub const DIO_DTSN: usize = 9; + //pub const DIO_FLAGS: usize = 10; + //pub const DIO_RESERVED: usize = 11; + pub const DIO_DODAG_ID: Field = 12..12 + 16; + + // Destination advertisment object (DAO) + pub const DAO_K: usize = 5; + pub const DAO_D: usize = 5; + //pub const DAO_FLAGS: usize = 5; + //pub const DAO_RESERVED: usize = 6; + pub const DAO_SEQUENCE: usize = 7; + pub const DAO_DODAG_ID: Field = 8..8 + 16; + + // Destination advertisment object ack (DAO-ACK) + pub const DAO_ACK_D: usize = 5; + //pub const DAO_ACK_RESERVED: usize = 5; + pub const DAO_ACK_SEQUENCE: usize = 6; + pub const DAO_ACK_STATUS: usize = 7; + pub const DAO_ACK_DODAG_ID: Field = 8..8 + 16; +} + +enum_with_unknown! { + /// RPL Control Message subtypes. + pub enum RplControlMessage(u8) { + DodagInformationSolicitation = 0x00, + DodagInformationObject = 0x01, + DestinationAdvertisementObject = 0x02, + DestinationAdvertisementObjectAck = 0x03, + SecureDodagInformationSolicitation = 0x80, + SecureDodagInformationObject = 0x81, + SecureDesintationAdvertismentObject = 0x82, + SecureDestinationAdvertisementObjectAck = 0x83, + ConsistencyCheck = 0x8a, + } +} + +impl core::fmt::Display for RplControlMessage { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + RplControlMessage::DodagInformationSolicitation => { + write!(f, "DODAG information solicitation (DIS)") + } + RplControlMessage::DodagInformationObject => { + write!(f, "DODAG information object (DIO)") + } + RplControlMessage::DestinationAdvertisementObject => { + write!(f, "destination advertisment object (DAO)") + } + RplControlMessage::DestinationAdvertisementObjectAck => write!( + f, + "destination advertisment object acknowledgement (DAO-ACK)" + ), + RplControlMessage::SecureDodagInformationSolicitation => { + write!(f, "secure DODAG information solicitation (DIS)") + } + RplControlMessage::SecureDodagInformationObject => { + write!(f, "secure DODAG information object (DIO)") + } + RplControlMessage::SecureDesintationAdvertismentObject => { + write!(f, "secure destination advertisment object (DAO)") + } + RplControlMessage::SecureDestinationAdvertisementObjectAck => write!( + f, + "secure destination advertisment object acknowledgement (DAO-ACK)" + ), + RplControlMessage::ConsistencyCheck => write!(f, "consistency check (CC)"), + RplControlMessage::Unknown(id) => write!(f, "{}", id), + } + } +} + +impl> Packet { + /// Return the RPL instance ID. + #[inline] + pub fn rpl_instance_id(&self) -> InstanceId { + get!(self.buffer, into: InstanceId, field: field::RPL_INSTANCE_ID) + } +} + +impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return a pointer to the options. + pub fn options(&self) -> Result<&'p [u8]> { + let len = self.buffer.as_ref().len(); + match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation if len < field::DIS_RESERVED + 1 => { + return Err(Error) + } + RplControlMessage::DodagInformationObject if len < field::DIO_DODAG_ID.end => { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObject + if self.dao_dodag_id_present() && len < field::DAO_DODAG_ID.end => + { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObject if len < field::DAO_SEQUENCE + 1 => { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObjectAck + if self.dao_dodag_id_present() && len < field::DAO_ACK_DODAG_ID.end => + { + return Err(Error) + } + RplControlMessage::DestinationAdvertisementObjectAck + if len < field::DAO_ACK_STATUS + 1 => + { + return Err(Error) + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDesintationAdvertismentObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => return Err(Error), + RplControlMessage::Unknown(_) => return Err(Error), + _ => {} + } + + let buffer = &self.buffer.as_ref(); + Ok(match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation => &buffer[field::DIS_RESERVED + 1..], + RplControlMessage::DodagInformationObject => &buffer[field::DIO_DODAG_ID.end..], + RplControlMessage::DestinationAdvertisementObject if self.dao_dodag_id_present() => { + &buffer[field::DAO_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObject => &buffer[field::DAO_SEQUENCE + 1..], + RplControlMessage::DestinationAdvertisementObjectAck if self.dao_dodag_id_present() => { + &buffer[field::DAO_ACK_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObjectAck => { + &buffer[field::DAO_ACK_STATUS + 1..] + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDesintationAdvertismentObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => unreachable!(), + RplControlMessage::Unknown(_) => unreachable!(), + }) + } +} + +impl + AsMut<[u8]>> Packet { + /// Set the RPL Instance ID field. + #[inline] + pub fn set_rpl_instance_id(&mut self, value: u8) { + set!(self.buffer, value, field: field::RPL_INSTANCE_ID) + } +} + +impl<'p, T: AsRef<[u8]> + AsMut<[u8]> + ?Sized> Packet<&'p mut T> { + /// Return a pointer to the options. + pub fn options_mut(&mut self) -> &mut [u8] { + match RplControlMessage::from(self.msg_code()) { + RplControlMessage::DodagInformationSolicitation => { + &mut self.buffer.as_mut()[field::DIS_RESERVED + 1..] + } + RplControlMessage::DodagInformationObject => { + &mut self.buffer.as_mut()[field::DIO_DODAG_ID.end..] + } + RplControlMessage::DestinationAdvertisementObject => { + if self.dao_dodag_id_present() { + &mut self.buffer.as_mut()[field::DAO_DODAG_ID.end..] + } else { + &mut self.buffer.as_mut()[field::DAO_SEQUENCE + 1..] + } + } + RplControlMessage::DestinationAdvertisementObjectAck => { + if self.dao_dodag_id_present() { + &mut self.buffer.as_mut()[field::DAO_ACK_DODAG_ID.end..] + } else { + &mut self.buffer.as_mut()[field::DAO_ACK_STATUS + 1..] + } + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDesintationAdvertismentObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => todo!("Secure messages not supported"), + RplControlMessage::Unknown(_) => todo!(), + } + } +} + +/// Getters for the DODAG information solicitation (DIS) message. +/// +/// ```txt +/// 0 1 2 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Flags | Reserved | Option(s)... +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// ``` +impl> Packet { + /// Return the DIS flags field. + #[inline] + pub fn dis_flags(&self) -> u8 { + get!(self.buffer, field: field::DIS_FLAGS) + } + + /// Return the DIS reserved field. + #[inline] + pub fn dis_reserved(&self) -> u8 { + get!(self.buffer, field: field::DIS_RESERVED) + } +} + +/// Setters for the DODAG information solicitation (DIS) message. +impl + AsMut<[u8]>> Packet { + /// Clear the DIS flags field. + pub fn clear_dis_flags(&mut self) { + self.buffer.as_mut()[field::DIS_FLAGS] = 0; + } + + /// Clear the DIS rserved field. + pub fn clear_dis_reserved(&mut self) { + self.buffer.as_mut()[field::DIS_RESERVED] = 0; + } +} + +enum_with_unknown! { + pub enum ModeOfOperation(u8) { + NoDownwardRoutesMaintained = 0x00, + NonStoringMode = 0x01, + StoringModeWithoutMulticast = 0x02, + StoringModeWithMulticast = 0x03, + } +} + +impl Default for ModeOfOperation { + fn default() -> Self { + Self::StoringModeWithoutMulticast + } +} + +/// Getters for the DODAG information object (DIO) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |Version Number | Rank | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// |G|0| MOP | Prf | DTSN | Flags | Reserved | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl> Packet { + /// Return the Version Number field. + #[inline] + pub fn dio_version_number(&self) -> u8 { + get!(self.buffer, field: field::DIO_VERSION_NUMBER) + } + + /// Return the Rank field. + #[inline] + pub fn dio_rank(&self) -> u16 { + get!(self.buffer, u16, field: field::DIO_RANK) + } + + /// Return the value of the Grounded flag. + #[inline] + pub fn dio_grounded(&self) -> bool { + get!(self.buffer, bool, field: field::DIO_GROUNDED, shift: 7, mask: 0b01) + } + + /// Return the mode of operation field. + #[inline] + pub fn dio_mode_of_operation(&self) -> ModeOfOperation { + get!(self.buffer, into: ModeOfOperation, field: field::DIO_MOP, shift: 3, mask: 0b111) + } + + /// Return the DODAG preference field. + #[inline] + pub fn dio_dodag_preference(&self) -> u8 { + get!(self.buffer, field: field::DIO_PRF, mask: 0b111) + } + + /// Return the destination advertisment trigger sequence number. + #[inline] + pub fn dio_dest_adv_trigger_seq_number(&self) -> u8 { + get!(self.buffer, field: field::DIO_DTSN) + } + + /// Return the DODAG id, which is an IPv6 address. + #[inline] + pub fn dio_dodag_id(&self) -> Address { + get!( + self.buffer, + into: Address, + fun: from_bytes, + field: field::DIO_DODAG_ID + ) + } +} + +/// Setters for the DODAG information object (DIO) message. +impl + AsMut<[u8]>> Packet { + /// Set the Version Number field. + #[inline] + pub fn set_dio_version_number(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_VERSION_NUMBER) + } + + /// Set the Rank field. + #[inline] + pub fn set_dio_rank(&mut self, value: u16) { + set!(self.buffer, value, u16, field: field::DIO_RANK) + } + + /// Set the value of the Grounded flag. + #[inline] + pub fn set_dio_grounded(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DIO_GROUNDED, shift: 7, mask: 0b01) + } + + /// Set the mode of operation field. + #[inline] + pub fn set_dio_mode_of_operation(&mut self, mode: ModeOfOperation) { + let raw = (self.buffer.as_ref()[field::DIO_MOP] & !(0b111 << 3)) | (u8::from(mode) << 3); + self.buffer.as_mut()[field::DIO_MOP] = raw; + } + + /// Set the DODAG preference field. + #[inline] + pub fn set_dio_dodag_preference(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_PRF, mask: 0b111) + } + + /// Set the destination advertisment trigger sequence number. + #[inline] + pub fn set_dio_dest_adv_trigger_seq_number(&mut self, value: u8) { + set!(self.buffer, value, field: field::DIO_DTSN) + } + + /// Set the DODAG id, which is an IPv6 address. + #[inline] + pub fn set_dio_dodag_id(&mut self, address: Address) { + set!(self.buffer, address: address, field: field::DIO_DODAG_ID) + } +} + +/// Getters for the Destination Advertisment Object (DAO) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |K|D| Flags | Reserved | DAOSequence | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID* + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl> Packet { + /// Returns the Expect DAO-ACK flag. + #[inline] + pub fn dao_ack_request(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_K, shift: 7, mask: 0b1) + } + + /// Returns the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn dao_dodag_id_present(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_D, shift: 6, mask: 0b1) + } + + /// Returns the DODAG sequence flag. + #[inline] + pub fn dao_dodag_sequence(&self) -> u8 { + get!(self.buffer, field: field::DAO_SEQUENCE) + } + + /// Returns the DODAG ID, an IPv6 address, when it is present. + #[inline] + pub fn dao_dodag_id(&self) -> Option
{ + if self.dao_dodag_id_present() { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::DAO_DODAG_ID], + )) + } else { + None + } + } +} + +/// Setters for the Destination Advertisment Object (DAO) message. +impl + AsMut<[u8]>> Packet { + /// Set the Expect DAO-ACK flag. + #[inline] + pub fn set_dao_ack_request(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_K, shift: 7, mask: 0b1,) + } + + /// Set the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn set_dao_dodag_id_present(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_D, shift: 6, mask: 0b1) + } + + /// Set the DODAG sequence flag. + #[inline] + pub fn set_dao_dodag_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_SEQUENCE) + } + + /// Set the DODAG ID. + #[inline] + pub fn set_dao_dodag_id(&mut self, address: Option
) { + match address { + Some(address) => { + self.buffer.as_mut()[field::DAO_DODAG_ID].copy_from_slice(address.as_bytes()); + self.set_dao_dodag_id_present(true); + } + None => { + self.set_dao_dodag_id_present(false); + } + } + } +} + +/// Getters for the Destination Advertisment Object acknowledgement (DAO-ACK) message. +/// +/// ```txt +/// 0 1 2 3 +/// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | RPLInstanceID |D| Reserved | DAOSequence | Status | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | | +/// + + +/// | | +/// + DODAGID* + +/// | | +/// + + +/// | | +/// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ +/// | Option(s)... +/// +-+-+-+-+-+-+-+-+ +/// ``` +impl> Packet { + /// Returns the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn dao_ack_dodag_id_present(&self) -> bool { + get!(self.buffer, bool, field: field::DAO_ACK_D, shift: 6, mask: 0b1) + } + + /// Return the DODAG sequence number. + #[inline] + pub fn dao_ack_sequence(&self) -> u8 { + get!(self.buffer, field: field::DAO_ACK_SEQUENCE) + } + + /// Return the DOA status field. + #[inline] + pub fn dao_ack_status(&self) -> u8 { + get!(self.buffer, field: field::DAO_ACK_STATUS) + } + + /// Returns the DODAG ID, an IPv6 address, when it is present. + #[inline] + pub fn dao_ack_dodag_id(&self) -> Option
{ + if self.dao_ack_dodag_id_present() { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::DAO_ACK_DODAG_ID], + )) + } else { + None + } + } +} + +/// Setters for the Destination Advertisment Object acknowledgement (DAO-ACK) message. +impl + AsMut<[u8]>> Packet { + /// Set the flag indicating that the DODAG ID is present or not. + #[inline] + pub fn set_dao_ack_dodag_id_present(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::DAO_ACK_D, shift: 6, mask: 0b1) + } + + /// Set the DODAG sequence number. + #[inline] + pub fn set_dao_ack_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_ACK_SEQUENCE) + } + + /// Set the DOA status field. + #[inline] + pub fn set_dao_ack_status(&mut self, value: u8) { + set!(self.buffer, value, field: field::DAO_ACK_STATUS) + } + + /// Set the DODAG ID. + #[inline] + pub fn set_dao_ack_dodag_id(&mut self, address: Option
) { + match address { + Some(address) => { + self.buffer.as_mut()[field::DAO_ACK_DODAG_ID].copy_from_slice(address.as_bytes()); + self.set_dao_ack_dodag_id_present(true); + } + None => { + self.set_dao_ack_dodag_id_present(false); + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum Repr<'p> { + DodagInformationSolicitation { + options: &'p [u8], + }, + DodagInformationObject { + rpl_instance_id: InstanceId, + version_number: u8, + rank: u16, + grounded: bool, + mode_of_operation: ModeOfOperation, + dodag_preference: u8, + dtsn: u8, + dodag_id: Address, + options: &'p [u8], + }, + DestinationAdvertisementObject { + rpl_instance_id: InstanceId, + expect_ack: bool, + sequence: u8, + dodag_id: Option
, + options: &'p [u8], + }, + DestinationAdvertisementObjectAck { + rpl_instance_id: InstanceId, + sequence: u8, + status: u8, + dodag_id: Option
, + }, +} + +impl core::fmt::Display for Repr<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::DodagInformationSolicitation { .. } => { + write!(f, "DIS")?; + } + Repr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + write!( + f, + "DIO \ + IID={rpl_instance_id:?} \ + V={version_number} \ + R={rank} \ + G={grounded} \ + MOP={mode_of_operation:?} \ + Pref={dodag_preference} \ + DTSN={dtsn} \ + DODAGID={dodag_id}" + )?; + } + Repr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + write!( + f, + "DAO \ + IID={rpl_instance_id:?} \ + Ack={expect_ack} \ + Seq={sequence} \ + DODAGID={dodag_id:?}", + )?; + } + Repr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + write!( + f, + "DAO-ACK \ + IID={rpl_instance_id:?} \ + Seq={sequence} \ + Status={status} \ + DODAGID={dodag_id:?}", + )?; + } + }; + + Ok(()) + } +} + +impl<'p> Repr<'p> { + pub fn set_options(&mut self, options: &'p [u8]) { + let opts = match self { + Repr::DodagInformationSolicitation { options } => options, + Repr::DodagInformationObject { options, .. } => options, + Repr::DestinationAdvertisementObject { options, .. } => options, + Repr::DestinationAdvertisementObjectAck { .. } => unreachable!(), + }; + + *opts = options; + } + + pub fn parse + ?Sized>(packet: &Packet<&'p T>) -> Result { + let options = packet.options()?; + match RplControlMessage::from(packet.msg_code()) { + RplControlMessage::DodagInformationSolicitation => { + Ok(Repr::DodagInformationSolicitation { options }) + } + RplControlMessage::DodagInformationObject => Ok(Repr::DodagInformationObject { + rpl_instance_id: packet.rpl_instance_id(), + version_number: packet.dio_version_number(), + rank: packet.dio_rank(), + grounded: packet.dio_grounded(), + mode_of_operation: packet.dio_mode_of_operation(), + dodag_preference: packet.dio_dodag_preference(), + dtsn: packet.dio_dest_adv_trigger_seq_number(), + dodag_id: packet.dio_dodag_id(), + options, + }), + RplControlMessage::DestinationAdvertisementObject => { + Ok(Repr::DestinationAdvertisementObject { + rpl_instance_id: packet.rpl_instance_id(), + expect_ack: packet.dao_ack_request(), + sequence: packet.dao_dodag_sequence(), + dodag_id: packet.dao_dodag_id(), + options, + }) + } + RplControlMessage::DestinationAdvertisementObjectAck => { + Ok(Repr::DestinationAdvertisementObjectAck { + rpl_instance_id: packet.rpl_instance_id(), + sequence: packet.dao_ack_sequence(), + status: packet.dao_ack_status(), + dodag_id: packet.dao_ack_dodag_id(), + }) + } + RplControlMessage::SecureDodagInformationSolicitation + | RplControlMessage::SecureDodagInformationObject + | RplControlMessage::SecureDesintationAdvertismentObject + | RplControlMessage::SecureDestinationAdvertisementObjectAck + | RplControlMessage::ConsistencyCheck => Err(Error), + RplControlMessage::Unknown(_) => Err(Error), + } + } + + pub fn buffer_len(&self) -> usize { + let mut len = 4 + match self { + Repr::DodagInformationSolicitation { .. } => 2, + Repr::DodagInformationObject { .. } => 24, + Repr::DestinationAdvertisementObject { dodag_id, .. } => { + if dodag_id.is_some() { + 20 + } else { + 4 + } + } + Repr::DestinationAdvertisementObjectAck { dodag_id, .. } => { + if dodag_id.is_some() { + 20 + } else { + 4 + } + } + }; + + let opts = match self { + Repr::DodagInformationSolicitation { options } => &options[..], + Repr::DodagInformationObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObjectAck { .. } => &[], + }; + + len += opts.len(); + + len + } + + pub fn emit + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&mut T>) { + packet.set_msg_type(crate::wire::icmpv6::Message::RplControl); + + match self { + Repr::DodagInformationSolicitation { .. } => { + packet.set_msg_code(RplControlMessage::DodagInformationSolicitation.into()); + packet.clear_dis_flags(); + packet.clear_dis_reserved(); + } + Repr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DodagInformationObject.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dio_version_number(*version_number); + packet.set_dio_rank(*rank); + packet.set_dio_grounded(*grounded); + packet.set_dio_mode_of_operation(*mode_of_operation); + packet.set_dio_dodag_preference(*dodag_preference); + packet.set_dio_dest_adv_trigger_seq_number(*dtsn); + packet.set_dio_dodag_id(*dodag_id); + } + Repr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DestinationAdvertisementObject.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dao_ack_request(*expect_ack); + packet.set_dao_dodag_sequence(*sequence); + packet.set_dao_dodag_id(*dodag_id); + } + Repr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + packet.set_msg_code(RplControlMessage::DestinationAdvertisementObjectAck.into()); + packet.set_rpl_instance_id((*rpl_instance_id).into()); + packet.set_dao_ack_sequence(*sequence); + packet.set_dao_ack_status(*status); + packet.set_dao_ack_dodag_id(*dodag_id); + } + } + + let options = match self { + Repr::DodagInformationSolicitation { options } => &options[..], + Repr::DodagInformationObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObject { options, .. } => &options[..], + Repr::DestinationAdvertisementObjectAck { .. } => &[], + }; + + packet.options_mut().copy_from_slice(options); + } +} + +pub mod options { + use byteorder::{ByteOrder, NetworkEndian}; + + use super::{Error, InstanceId, Result}; + use crate::wire::ipv6::Address; + + /// A read/write wrapper around a RPL Control Message Option. + #[derive(Debug, Clone)] + pub struct Packet> { + buffer: T, + } + + enum_with_unknown! { + pub enum OptionType(u8) { + Pad1 = 0x00, + PadN = 0x01, + DagMetricContainer = 0x02, + RouteInformation = 0x03, + DodagConfiguration = 0x04, + RplTarget = 0x05, + TransitInformation = 0x06, + SolicitedInformation = 0x07, + PrefixInformation = 0x08, + RplTargetDescriptor = 0x09, + } + } + + impl From<&Repr<'_>> for OptionType { + fn from(repr: &Repr) -> Self { + match repr { + Repr::Pad1 => Self::Pad1, + Repr::PadN(_) => Self::PadN, + Repr::DagMetricContainer => Self::DagMetricContainer, + Repr::RouteInformation { .. } => Self::RouteInformation, + Repr::DodagConfiguration { .. } => Self::DodagConfiguration, + Repr::RplTarget { .. } => Self::RplTarget, + Repr::TransitInformation { .. } => Self::TransitInformation, + Repr::SolicitedInformation { .. } => Self::SolicitedInformation, + Repr::PrefixInformation { .. } => Self::PrefixInformation, + Repr::RplTargetDescriptor { .. } => Self::RplTargetDescriptor, + } + } + } + + mod field { + use crate::wire::field::*; + + // Generic fields. + pub const TYPE: usize = 0; + pub const LENGTH: usize = 1; + + pub const PADN: Rest = 2..; + + // Route Information fields. + pub const ROUTE_INFO_PREFIX_LENGTH: usize = 2; + pub const ROUTE_INFO_RESERVED: usize = 3; + pub const ROUTE_INFO_PREFERENCE: usize = 3; + pub const ROUTE_INFO_LIFETIME: Field = 4..9; + + // DODAG Configuration fields. + pub const DODAG_CONF_FLAGS: usize = 2; + pub const DODAG_CONF_AUTHENTICATION_ENABLED: usize = 2; + pub const DODAG_CONF_PATH_CONTROL_SIZE: usize = 2; + pub const DODAG_CONF_DIO_INTERVAL_DOUBLINGS: usize = 3; + pub const DODAG_CONF_DIO_INTERVAL_MINIMUM: usize = 4; + pub const DODAG_CONF_DIO_REDUNDANCY_CONSTANT: usize = 5; + pub const DODAG_CONF_DIO_MAX_RANK_INCREASE: Field = 6..8; + pub const DODAG_CONF_MIN_HOP_RANK_INCREASE: Field = 8..10; + pub const DODAG_CONF_OBJECTIVE_CODE_POINT: Field = 10..12; + pub const DODAG_CONF_DEFAULT_LIFETIME: usize = 13; + pub const DODAG_CONF_LIFETIME_UNIT: Field = 14..16; + + // RPL Target fields. + pub const RPL_TARGET_FLAGS: usize = 2; + pub const RPL_TARGET_PREFIX_LENGTH: usize = 3; + + // Transit Information fields. + pub const TRANSIT_INFO_FLAGS: usize = 2; + pub const TRANSIT_INFO_EXTERNAL: usize = 2; + pub const TRANSIT_INFO_PATH_CONTROL: usize = 3; + pub const TRANSIT_INFO_PATH_SEQUENCE: usize = 4; + pub const TRANSIT_INFO_PATH_LIFETIME: usize = 5; + pub const TRANSIT_INFO_PARENT_ADDRESS: Field = 6..6 + 16; + + // Solicited Information fields. + pub const SOLICITED_INFO_RPL_INSTANCE_ID: usize = 2; + pub const SOLICITED_INFO_FLAGS: usize = 3; + pub const SOLICITED_INFO_VERSION_PREDICATE: usize = 3; + pub const SOLICITED_INFO_INSTANCE_ID_PREDICATE: usize = 3; + pub const SOLICITED_INFO_DODAG_ID_PREDICATE: usize = 3; + pub const SOLICITED_INFO_DODAG_ID: Field = 4..20; + pub const SOLICITED_INFO_VERSION_NUMBER: usize = 20; + + // Prefix Information fields. + pub const PREFIX_INFO_PREFIX_LENGTH: usize = 2; + pub const PREFIX_INFO_RESERVED1: usize = 3; + pub const PREFIX_INFO_ON_LINK: usize = 3; + pub const PREFIX_INFO_AUTONOMOUS_CONF: usize = 3; + pub const PREFIX_INFO_ROUTER_ADDRESS_FLAG: usize = 3; + pub const PREFIX_INFO_VALID_LIFETIME: Field = 4..8; + pub const PREFIX_INFO_PREFERRED_LIFETIME: Field = 8..12; + pub const PREFIX_INFO_RESERVED2: Field = 12..16; + pub const PREFIX_INFO_PREFIX: Field = 16..16 + 16; + + // RPL Target Descriptor fields. + pub const TARGET_DESCRIPTOR: Field = 2..6; + } + + /// Getters for the RPL Control Message Options. + impl> Packet { + /// Imbue a raw octet buffer with RPL Control Message Option structure. + #[inline] + pub fn new_unchecked(buffer: T) -> Self { + Packet { buffer } + } + + #[inline] + pub fn new_checked(buffer: T) -> Result { + if buffer.as_ref().is_empty() { + return Err(Error); + } + + Ok(Packet { buffer }) + } + + /// Return the type field. + #[inline] + pub fn option_type(&self) -> OptionType { + OptionType::from(self.buffer.as_ref()[field::TYPE]) + } + + /// Return the length field. + #[inline] + pub fn option_length(&self) -> u8 { + get!(self.buffer, field: field::LENGTH) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return a pointer to the next option. + #[inline] + pub fn next_option(&self) -> Option<&'p [u8]> { + if !self.buffer.as_ref().is_empty() { + match self.option_type() { + OptionType::Pad1 => Some(&self.buffer.as_ref()[1..]), + OptionType::Unknown(_) => unreachable!(), + _ => { + let len = self.option_length(); + Some(&self.buffer.as_ref()[2 + len as usize..]) + } + } + } else { + None + } + } + } + + impl + AsMut<[u8]>> Packet { + /// Set the Option Type field. + #[inline] + pub fn set_option_type(&mut self, option_type: OptionType) { + self.buffer.as_mut()[field::TYPE] = option_type.into(); + } + + /// Set the Option Length field. + #[inline] + pub fn set_option_length(&mut self, length: u8) { + self.buffer.as_mut()[field::LENGTH] = length; + } + } + + impl + AsMut<[u8]>> Packet { + #[inline] + pub fn clear_padn(&mut self, size: u8) { + for b in &mut self.buffer.as_mut()[field::PADN][..size as usize] { + *b = 0; + } + } + } + + /// Getters for the DAG Metric Container Option Message. + + /// Getters for the Route Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x03 | Option Length | Prefix Length |Resvd|Prf|Resvd| + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Route Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// . Prefix (Variable Length) . + /// . . + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the Prefix Length field. + #[inline] + pub fn prefix_length(&self) -> u8 { + get!(self.buffer, field: field::ROUTE_INFO_PREFIX_LENGTH) + } + + /// Return the Route Preference field. + #[inline] + pub fn route_preference(&self) -> u8 { + (self.buffer.as_ref()[field::ROUTE_INFO_PREFERENCE] & 0b0001_1000) >> 3 + } + + /// Return the Route Lifetime field. + #[inline] + pub fn route_lifetime(&self) -> u32 { + get!(self.buffer, u32, field: field::ROUTE_INFO_LIFETIME) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Prefix field. + #[inline] + pub fn prefix(&self) -> &'p [u8] { + let option_len = self.option_length(); + &self.buffer.as_ref()[field::ROUTE_INFO_LIFETIME.end..] + [..option_len as usize - field::ROUTE_INFO_LIFETIME.end] + } + } + + /// Setters for the Route Information Option Message. + impl + AsMut<[u8]>> Packet { + /// Set the Prefix Length field. + #[inline] + pub fn set_route_info_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::ROUTE_INFO_PREFIX_LENGTH) + } + + /// Set the Route Preference field. + #[inline] + pub fn set_route_info_route_preference(&mut self, _value: u8) { + todo!(); + } + + /// Set the Route Lifetime field. + #[inline] + pub fn set_route_info_route_lifetime(&mut self, value: u32) { + set!(self.buffer, value, u32, field: field::ROUTE_INFO_LIFETIME) + } + + /// Set the prefix field. + #[inline] + pub fn set_route_info_prefix(&mut self, _prefix: &[u8]) { + todo!(); + } + + /// Clear the reserved field. + #[inline] + pub fn clear_route_info_reserved(&mut self) { + self.buffer.as_mut()[field::ROUTE_INFO_RESERVED] = 0; + } + } + + /// Getters for the DODAG Configuration Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x04 |Opt Length = 14| Flags |A| PCS | DIOIntDoubl. | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | DIOIntMin. | DIORedun. | MaxRankIncrease | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | MinHopRankIncrease | OCP | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Reserved | Def. Lifetime | Lifetime Unit | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the Authentication Enabled field. + #[inline] + pub fn authentication_enabled(&self) -> bool { + get!( + self.buffer, + bool, + field: field::DODAG_CONF_AUTHENTICATION_ENABLED, + shift: 3, + mask: 0b1 + ) + } + + /// Return the Path Control Size field. + #[inline] + pub fn path_control_size(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_PATH_CONTROL_SIZE, mask: 0b111) + } + + /// Return the DIO Interval Doublings field. + #[inline] + pub fn dio_interval_doublings(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DIO_INTERVAL_DOUBLINGS) + } + + /// Return the DIO Interval Minimum field. + #[inline] + pub fn dio_interval_minimum(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DIO_INTERVAL_MINIMUM) + } + + /// Return the DIO Redundancy Constant field. + #[inline] + pub fn dio_redundancy_constant(&self) -> u8 { + get!( + self.buffer, + field: field::DODAG_CONF_DIO_REDUNDANCY_CONSTANT + ) + } + + /// Return the Max Rank Increase field. + #[inline] + pub fn max_rank_increase(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_DIO_MAX_RANK_INCREASE + ) + } + + /// Return the Minimum Hop Rank Increase field. + #[inline] + pub fn minimum_hop_rank_increase(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_MIN_HOP_RANK_INCREASE + ) + } + + /// Return the Objective Code Point field. + #[inline] + pub fn objective_code_point(&self) -> u16 { + get!( + self.buffer, + u16, + field: field::DODAG_CONF_OBJECTIVE_CODE_POINT + ) + } + + /// Return the Default Lifetime field. + #[inline] + pub fn default_lifetime(&self) -> u8 { + get!(self.buffer, field: field::DODAG_CONF_DEFAULT_LIFETIME) + } + + /// Return the Lifetime Unit field. + #[inline] + pub fn lifetime_unit(&self) -> u16 { + get!(self.buffer, u16, field: field::DODAG_CONF_LIFETIME_UNIT) + } + } + + /// Getters for the DODAG Configuration Option Message. + impl + AsMut<[u8]>> Packet { + /// Clear the Flags field. + #[inline] + pub fn clear_dodag_conf_flags(&mut self) { + self.buffer.as_mut()[field::DODAG_CONF_FLAGS] = 0; + } + + /// Set the Authentication Enabled field. + #[inline] + pub fn set_dodag_conf_authentication_enabled(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::DODAG_CONF_AUTHENTICATION_ENABLED, + shift: 3, + mask: 0b1 + ) + } + + /// Set the Path Control Size field. + #[inline] + pub fn set_dodag_conf_path_control_size(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_PATH_CONTROL_SIZE, + mask: 0b111 + ) + } + + /// Set the DIO Interval Doublings field. + #[inline] + pub fn set_dodag_conf_dio_interval_doublings(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_INTERVAL_DOUBLINGS + ) + } + + /// Set the DIO Interval Minimum field. + #[inline] + pub fn set_dodag_conf_dio_interval_minimum(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_INTERVAL_MINIMUM + ) + } + + /// Set the DIO Redundancy Constant field. + #[inline] + pub fn set_dodag_conf_dio_redundancy_constant(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DIO_REDUNDANCY_CONSTANT + ) + } + + /// Set the Max Rank Increase field. + #[inline] + pub fn set_dodag_conf_max_rank_increase(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_DIO_MAX_RANK_INCREASE + ) + } + + /// Set the Minimum Hop Rank Increase field. + #[inline] + pub fn set_dodag_conf_minimum_hop_rank_increase(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_MIN_HOP_RANK_INCREASE + ) + } + + /// Set the Objective Code Point field. + #[inline] + pub fn set_dodag_conf_objective_code_point(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_OBJECTIVE_CODE_POINT + ) + } + + /// Set the Default Lifetime field. + #[inline] + pub fn set_dodag_conf_default_lifetime(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::DODAG_CONF_DEFAULT_LIFETIME + ) + } + + /// Set the Lifetime Unit field. + #[inline] + pub fn set_dodag_conf_lifetime_unit(&mut self, value: u16) { + set!( + self.buffer, + value, + u16, + field: field::DODAG_CONF_LIFETIME_UNIT + ) + } + } + + /// Getters for the RPL Target Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x05 | Option Length | Flags | Prefix Length | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | Target Prefix (Variable Length) | + /// . . + /// . . + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the Target Prefix Length field. + pub fn target_prefix_length(&self) -> u8 { + get!(self.buffer, field: field::RPL_TARGET_PREFIX_LENGTH) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Target Prefix field. + #[inline] + pub fn target_prefix(&self) -> &'p [u8] { + let option_len = self.option_length(); + &self.buffer.as_ref()[field::RPL_TARGET_PREFIX_LENGTH + 1..] + [..option_len as usize - field::RPL_TARGET_PREFIX_LENGTH + 1] + } + } + + /// Setters for the RPL Target Option Message. + impl + AsMut<[u8]>> Packet { + /// Clear the Flags field. + #[inline] + pub fn clear_rpl_target_flags(&mut self) { + self.buffer.as_mut()[field::RPL_TARGET_FLAGS] = 0; + } + + /// Set the Target Prefix Length field. + #[inline] + pub fn set_rpl_target_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::RPL_TARGET_PREFIX_LENGTH) + } + + /// Set the Target Prefix field. + #[inline] + pub fn set_rpl_target_prefix(&mut self, prefix: &[u8]) { + self.buffer.as_mut()[field::RPL_TARGET_PREFIX_LENGTH + 1..][..prefix.len()] + .copy_from_slice(prefix); + } + } + + /// Getters for the Transit Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x06 | Option Length |E| Flags | Path Control | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Path Sequence | Path Lifetime | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + + /// | | + /// + + + /// | | + /// + Parent Address* + + /// | | + /// + +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the External flag. + #[inline] + pub fn is_external(&self) -> bool { + get!( + self.buffer, + bool, + field: field::TRANSIT_INFO_EXTERNAL, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Path Control field. + #[inline] + pub fn path_control(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_CONTROL) + } + + /// Return the Path Sequence field. + #[inline] + pub fn path_sequence(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_SEQUENCE) + } + + /// Return the Path Lifetime field. + #[inline] + pub fn path_lifetime(&self) -> u8 { + get!(self.buffer, field: field::TRANSIT_INFO_PATH_LIFETIME) + } + + /// Return the Parent Address field. + #[inline] + pub fn parent_address(&self) -> Option
{ + if self.option_length() > 5 { + Some(Address::from_bytes( + &self.buffer.as_ref()[field::TRANSIT_INFO_PARENT_ADDRESS], + )) + } else { + None + } + } + } + + /// Setters for the Transit Information Option Message. + impl + AsMut<[u8]>> Packet { + /// Clear the Flags field. + #[inline] + pub fn clear_transit_info_flags(&mut self) { + self.buffer.as_mut()[field::TRANSIT_INFO_FLAGS] = 0; + } + + /// Set the External flag. + #[inline] + pub fn set_transit_info_is_external(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::TRANSIT_INFO_EXTERNAL, + shift: 7, + mask: 0b1 + ) + } + + /// Set the Path Control field. + #[inline] + pub fn set_transit_info_path_control(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_CONTROL) + } + + /// Set the Path Sequence field. + #[inline] + pub fn set_transit_info_path_sequence(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_SEQUENCE) + } + + /// Set the Path Lifetime field. + #[inline] + pub fn set_transit_info_path_lifetime(&mut self, value: u8) { + set!(self.buffer, value, field: field::TRANSIT_INFO_PATH_LIFETIME) + } + + /// Set the Parent Address field. + #[inline] + pub fn set_transit_info_parent_address(&mut self, address: Address) { + self.buffer.as_mut()[field::TRANSIT_INFO_PARENT_ADDRESS] + .copy_from_slice(address.as_bytes()); + } + } + + /// Getters for the Solicited Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x07 |Opt Length = 19| RPLInstanceID |V|I|D| Flags | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | | + /// + DODAGID + + /// | | + /// + + + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |Version Number | + /// +-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the RPL Instance ID field. + #[inline] + pub fn rpl_instance_id(&self) -> u8 { + get!(self.buffer, field: field::SOLICITED_INFO_RPL_INSTANCE_ID) + } + + /// Return the Version Predicate flag. + #[inline] + pub fn version_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_VERSION_PREDICATE, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Instance ID Predicate flag. + #[inline] + pub fn instance_id_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_INSTANCE_ID_PREDICATE, + shift: 6, + mask: 0b1, + ) + } + + /// Return the DODAG Predicate ID flag. + #[inline] + pub fn dodag_id_predicate(&self) -> bool { + get!( + self.buffer, + bool, + field: field::SOLICITED_INFO_DODAG_ID_PREDICATE, + shift: 5, + mask: 0b1, + ) + } + + /// Return the DODAG ID field. + #[inline] + pub fn dodag_id(&self) -> Address { + get!( + self.buffer, + into: Address, + fun: from_bytes, + field: field::SOLICITED_INFO_DODAG_ID + ) + } + + /// Return the Version Number field. + #[inline] + pub fn version_number(&self) -> u8 { + get!(self.buffer, field: field::SOLICITED_INFO_VERSION_NUMBER) + } + } + + /// Setters for the Solicited Information Option Message. + impl + AsMut<[u8]>> Packet { + /// Clear the Flags field. + #[inline] + pub fn clear_solicited_info_flags(&mut self) { + self.buffer.as_mut()[field::SOLICITED_INFO_FLAGS] = 0; + } + + /// Set the RPL Instance ID field. + #[inline] + pub fn set_solicited_info_rpl_instance_id(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::SOLICITED_INFO_RPL_INSTANCE_ID + ) + } + + /// Set the Version Predicate flag. + #[inline] + pub fn set_solicited_info_version_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_VERSION_PREDICATE, + shift: 7, + mask: 0b1 + ) + } + + /// Set the Instance ID Predicate flag. + #[inline] + pub fn set_solicited_info_instance_id_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_INSTANCE_ID_PREDICATE, + shift: 6, + mask: 0b1 + ) + } + + /// Set the DODAG Predicate ID flag. + #[inline] + pub fn set_solicited_info_dodag_id_predicate(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::SOLICITED_INFO_DODAG_ID_PREDICATE, + shift: 5, + mask: 0b1 + ) + } + + /// Set the DODAG ID field. + #[inline] + pub fn set_solicited_info_dodag_id(&mut self, address: Address) { + set!( + self.buffer, + address: address, + field: field::SOLICITED_INFO_DODAG_ID + ) + } + + /// Set the Version Number field. + #[inline] + pub fn set_solicited_info_version_number(&mut self, value: u8) { + set!( + self.buffer, + value, + field: field::SOLICITED_INFO_VERSION_NUMBER + ) + } + } + + /// Getters for the Prefix Information Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x08 |Opt Length = 30| Prefix Length |L|A|R|Reserved1| + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Valid Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Preferred Lifetime | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Reserved2 | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | | + /// + + + /// | | + /// + Prefix + + /// | | + /// + + + /// | | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the Prefix Length field. + #[inline] + pub fn prefix_info_prefix_length(&self) -> u8 { + get!(self.buffer, field: field::PREFIX_INFO_PREFIX_LENGTH) + } + + /// Return the On-Link flag. + #[inline] + pub fn on_link(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_ON_LINK, + shift: 7, + mask: 0b1, + ) + } + + /// Return the Autonomous Address-Configuration flag. + #[inline] + pub fn autonomous_address_configuration(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_AUTONOMOUS_CONF, + shift: 6, + mask: 0b1, + ) + } + + /// Return the Router Address flag. + #[inline] + pub fn router_address(&self) -> bool { + get!( + self.buffer, + bool, + field: field::PREFIX_INFO_ROUTER_ADDRESS_FLAG, + shift: 5, + mask: 0b1, + ) + } + + /// Return the Valid Lifetime field. + #[inline] + pub fn valid_lifetime(&self) -> u32 { + get!(self.buffer, u32, field: field::PREFIX_INFO_VALID_LIFETIME) + } + + /// Return the Preferred Lifetime field. + #[inline] + pub fn preferred_lifetime(&self) -> u32 { + get!( + self.buffer, + u32, + field: field::PREFIX_INFO_PREFERRED_LIFETIME + ) + } + } + + impl<'p, T: AsRef<[u8]> + ?Sized> Packet<&'p T> { + /// Return the Prefix field. + #[inline] + pub fn destination_prefix(&self) -> &'p [u8] { + &self.buffer.as_ref()[field::PREFIX_INFO_PREFIX] + } + } + + /// Setters for the Prefix Information Option Message. + impl + AsMut<[u8]>> Packet { + /// Clear the reserved fields. + #[inline] + pub fn clear_prefix_info_reserved(&mut self) { + self.buffer.as_mut()[field::PREFIX_INFO_RESERVED1] = 0; + self.buffer.as_mut()[field::PREFIX_INFO_RESERVED2].copy_from_slice(&[0; 4]); + } + + /// Set the Prefix Length field. + #[inline] + pub fn set_prefix_info_prefix_length(&mut self, value: u8) { + set!(self.buffer, value, field: field::PREFIX_INFO_PREFIX_LENGTH) + } + + /// Set the On-Link flag. + #[inline] + pub fn set_prefix_info_on_link(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::PREFIX_INFO_ON_LINK, shift: 7, mask: 0b1) + } + + /// Set the Autonomous Address-Configuration flag. + #[inline] + pub fn set_prefix_info_autonomous_address_configuration(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::PREFIX_INFO_AUTONOMOUS_CONF, + shift: 6, + mask: 0b1 + ) + } + + /// Set the Router Address flag. + #[inline] + pub fn set_prefix_info_router_address(&mut self, value: bool) { + set!( + self.buffer, + value, + bool, + field: field::PREFIX_INFO_ROUTER_ADDRESS_FLAG, + shift: 5, + mask: 0b1 + ) + } + + /// Set the Valid Lifetime field. + #[inline] + pub fn set_prefix_info_valid_lifetime(&mut self, value: u32) { + set!( + self.buffer, + value, + u32, + field: field::PREFIX_INFO_VALID_LIFETIME + ) + } + + /// Set the Preferred Lifetime field. + #[inline] + pub fn set_prefix_info_preferred_lifetime(&mut self, value: u32) { + set!( + self.buffer, + value, + u32, + field: field::PREFIX_INFO_PREFERRED_LIFETIME + ) + } + + /// Set the Prefix field. + #[inline] + pub fn set_prefix_info_destination_prefix(&mut self, prefix: &[u8]) { + self.buffer.as_mut()[field::PREFIX_INFO_PREFIX].copy_from_slice(prefix); + } + } + + /// Getters for the RPL Target Descriptor Option Message. + /// + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Type = 0x09 |Opt Length = 4 | Descriptor + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// Descriptor (cont.) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + impl> Packet { + /// Return the Descriptor field. + #[inline] + pub fn descriptor(&self) -> u32 { + get!(self.buffer, u32, field: field::TARGET_DESCRIPTOR) + } + } + + /// Setters for the RPL Target Descriptor Option Message. + impl + AsMut<[u8]>> Packet { + /// Set the Descriptor field. + #[inline] + pub fn set_rpl_target_descriptor_descriptor(&mut self, value: u32) { + set!(self.buffer, value, u32, field: field::TARGET_DESCRIPTOR) + } + } + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub enum Repr<'p> { + Pad1, + PadN(u8), + DagMetricContainer, + RouteInformation { + prefix_length: u8, + preference: u8, + lifetime: u32, + prefix: &'p [u8], + }, + DodagConfiguration { + authentication_enabled: bool, + path_control_size: u8, + dio_interval_doublings: u8, + dio_interval_min: u8, + dio_redundancy_constant: u8, + max_rank_increase: u16, + minimum_hop_rank_increase: u16, + objective_code_point: u16, + default_lifetime: u8, + lifetime_unit: u16, + }, + RplTarget { + prefix_length: u8, + prefix: crate::wire::Ipv6Address, // FIXME: this is not the correct type, because the + // field can be an IPv6 address, a prefix or a + // multicast group. + }, + TransitInformation { + external: bool, + path_control: u8, + path_sequence: u8, + path_lifetime: u8, + parent_address: Option
, + }, + SolicitedInformation { + rpl_instance_id: InstanceId, + version_predicate: bool, + instance_id_predicate: bool, + dodag_id_predicate: bool, + dodag_id: Address, + version_number: u8, + }, + PrefixInformation { + prefix_length: u8, + on_link: bool, + autonomous_address_configuration: bool, + router_address: bool, + valid_lifetime: u32, + preferred_lifetime: u32, + destination_prefix: &'p [u8], + }, + RplTargetDescriptor { + descriptor: u32, + }, + } + + impl core::fmt::Display for Repr<'_> { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::Pad1 => write!(f, "Pad1"), + Repr::PadN(n) => write!(f, "PadN({n})"), + Repr::DagMetricContainer => todo!(), + Repr::RouteInformation { + prefix_length, + preference, + lifetime, + prefix, + } => { + write!( + f, + "ROUTE INFO \ + PrefixLength={prefix_length} \ + Preference={preference} \ + Lifetime={lifetime} \ + Prefix={prefix:0x?}" + ) + } + Repr::DodagConfiguration { + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + .. + } => { + write!( + f, + "DODAG CONF \ + IntD={dio_interval_doublings} \ + IntMin={dio_interval_min} \ + RedCst={dio_redundancy_constant} \ + MaxRankIncr={max_rank_increase} \ + MinHopRankIncr={minimum_hop_rank_increase} \ + OCP={objective_code_point} \ + DefaultLifetime={default_lifetime} \ + LifeUnit={lifetime_unit}" + ) + } + Repr::RplTarget { + prefix_length, + prefix, + } => { + write!( + f, + "RPL Target \ + PrefixLength={prefix_length} \ + Prefix={prefix:0x?}" + ) + } + Repr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + write!( + f, + "Transit Info \ + External={external} \ + PathCtrl={path_control} \ + PathSqnc={path_sequence} \ + PathLifetime={path_lifetime} \ + Parent={parent_address:0x?}" + ) + } + Repr::SolicitedInformation { + rpl_instance_id, + version_predicate, + instance_id_predicate, + dodag_id_predicate, + dodag_id, + version_number, + } => { + write!( + f, + "Solicited Info \ + I={instance_id_predicate} \ + IID={rpl_instance_id:0x?} \ + D={dodag_id_predicate} \ + DODAGID={dodag_id} \ + V={version_predicate} \ + Version={version_number}" + ) + } + Repr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + router_address, + valid_lifetime, + preferred_lifetime, + destination_prefix, + } => { + write!( + f, + "Prefix Info \ + PrefixLength={prefix_length} \ + L={on_link} A={autonomous_address_configuration} R={router_address} \ + Valid={valid_lifetime} \ + Prefered={preferred_lifetime} \ + Prefix={destination_prefix:0x?}" + ) + } + Repr::RplTargetDescriptor { .. } => write!(f, "Target Descriptor"), + } + } + } + + impl<'p> Repr<'p> { + pub fn parse + ?Sized>(packet: &Packet<&'p T>) -> Result { + match packet.option_type() { + OptionType::Pad1 => Ok(Repr::Pad1), + OptionType::PadN => Ok(Repr::PadN(packet.option_length())), + OptionType::DagMetricContainer => todo!(), + OptionType::RouteInformation => Ok(Repr::RouteInformation { + prefix_length: packet.prefix_length(), + preference: packet.route_preference(), + lifetime: packet.route_lifetime(), + prefix: packet.prefix(), + }), + OptionType::DodagConfiguration => Ok(Repr::DodagConfiguration { + authentication_enabled: packet.authentication_enabled(), + path_control_size: packet.path_control_size(), + dio_interval_doublings: packet.dio_interval_doublings(), + dio_interval_min: packet.dio_interval_minimum(), + dio_redundancy_constant: packet.dio_redundancy_constant(), + max_rank_increase: packet.max_rank_increase(), + minimum_hop_rank_increase: packet.minimum_hop_rank_increase(), + objective_code_point: packet.objective_code_point(), + default_lifetime: packet.default_lifetime(), + lifetime_unit: packet.lifetime_unit(), + }), + OptionType::RplTarget => Ok(Repr::RplTarget { + prefix_length: packet.target_prefix_length(), + prefix: crate::wire::Ipv6Address::from_bytes(packet.target_prefix()), + }), + OptionType::TransitInformation => Ok(Repr::TransitInformation { + external: packet.is_external(), + path_control: packet.path_control(), + path_sequence: packet.path_sequence(), + path_lifetime: packet.path_lifetime(), + parent_address: packet.parent_address(), + }), + OptionType::SolicitedInformation => Ok(Repr::SolicitedInformation { + rpl_instance_id: InstanceId::from(packet.rpl_instance_id()), + version_predicate: packet.version_predicate(), + instance_id_predicate: packet.instance_id_predicate(), + dodag_id_predicate: packet.dodag_id_predicate(), + dodag_id: packet.dodag_id(), + version_number: packet.version_number(), + }), + OptionType::PrefixInformation => Ok(Repr::PrefixInformation { + prefix_length: packet.prefix_info_prefix_length(), + on_link: packet.on_link(), + autonomous_address_configuration: packet.autonomous_address_configuration(), + router_address: packet.router_address(), + valid_lifetime: packet.valid_lifetime(), + preferred_lifetime: packet.preferred_lifetime(), + destination_prefix: packet.destination_prefix(), + }), + OptionType::RplTargetDescriptor => Ok(Repr::RplTargetDescriptor { + descriptor: packet.descriptor(), + }), + OptionType::Unknown(_) => Err(Error), + } + } + + pub fn buffer_len(&self) -> usize { + match self { + Repr::Pad1 => 1, + Repr::PadN(size) => 2 + *size as usize, + Repr::DagMetricContainer => todo!(), + Repr::RouteInformation { prefix, .. } => 2 + 6 + prefix.len(), + Repr::DodagConfiguration { .. } => 2 + 14, + Repr::RplTarget { prefix, .. } => 2 + 2 + prefix.0.len(), + Repr::TransitInformation { parent_address, .. } => { + 2 + 4 + if parent_address.is_some() { 16 } else { 0 } + } + Repr::SolicitedInformation { .. } => 2 + 2 + 16 + 1, + Repr::PrefixInformation { .. } => 32, + Repr::RplTargetDescriptor { .. } => 2 + 4, + } + } + + pub fn emit + AsMut<[u8]> + ?Sized>(&self, packet: &mut Packet<&'p mut T>) { + let mut option_length = self.buffer_len() as u8; + + packet.set_option_type(self.into()); + + if !matches!(self, Repr::Pad1) { + option_length -= 2; + packet.set_option_length(option_length); + } + + match self { + Repr::Pad1 => {} + Repr::PadN(size) => { + packet.clear_padn(*size); + } + Repr::DagMetricContainer => { + unimplemented!(); + } + Repr::RouteInformation { + prefix_length, + preference, + lifetime, + prefix, + } => { + packet.clear_route_info_reserved(); + packet.set_route_info_prefix_length(*prefix_length); + packet.set_route_info_route_preference(*preference); + packet.set_route_info_route_lifetime(*lifetime); + packet.set_route_info_prefix(prefix); + } + Repr::DodagConfiguration { + authentication_enabled, + path_control_size, + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + } => { + packet.clear_dodag_conf_flags(); + packet.set_dodag_conf_authentication_enabled(*authentication_enabled); + packet.set_dodag_conf_path_control_size(*path_control_size); + packet.set_dodag_conf_dio_interval_doublings(*dio_interval_doublings); + packet.set_dodag_conf_dio_interval_minimum(*dio_interval_min); + packet.set_dodag_conf_dio_redundancy_constant(*dio_redundancy_constant); + packet.set_dodag_conf_max_rank_increase(*max_rank_increase); + packet.set_dodag_conf_minimum_hop_rank_increase(*minimum_hop_rank_increase); + packet.set_dodag_conf_objective_code_point(*objective_code_point); + packet.set_dodag_conf_default_lifetime(*default_lifetime); + packet.set_dodag_conf_lifetime_unit(*lifetime_unit); + } + Repr::RplTarget { + prefix_length, + prefix, + } => { + packet.clear_rpl_target_flags(); + packet.set_rpl_target_prefix_length(*prefix_length); + packet.set_rpl_target_prefix(prefix.as_bytes()); + } + Repr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + packet.clear_transit_info_flags(); + packet.set_transit_info_is_external(*external); + packet.set_transit_info_path_control(*path_control); + packet.set_transit_info_path_sequence(*path_sequence); + packet.set_transit_info_path_lifetime(*path_lifetime); + + if let Some(address) = parent_address { + packet.set_transit_info_parent_address(*address); + } + } + Repr::SolicitedInformation { + rpl_instance_id, + version_predicate, + instance_id_predicate, + dodag_id_predicate, + dodag_id, + version_number, + } => { + packet.clear_solicited_info_flags(); + packet.set_solicited_info_rpl_instance_id((*rpl_instance_id).into()); + packet.set_solicited_info_version_predicate(*version_predicate); + packet.set_solicited_info_instance_id_predicate(*instance_id_predicate); + packet.set_solicited_info_dodag_id_predicate(*dodag_id_predicate); + packet.set_solicited_info_version_number(*version_number); + packet.set_solicited_info_dodag_id(*dodag_id); + } + Repr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + router_address, + valid_lifetime, + preferred_lifetime, + destination_prefix, + } => { + packet.clear_prefix_info_reserved(); + packet.set_prefix_info_prefix_length(*prefix_length); + packet.set_prefix_info_on_link(*on_link); + packet.set_prefix_info_autonomous_address_configuration( + *autonomous_address_configuration, + ); + packet.set_prefix_info_router_address(*router_address); + packet.set_prefix_info_valid_lifetime(*valid_lifetime); + packet.set_prefix_info_preferred_lifetime(*preferred_lifetime); + packet.set_prefix_info_destination_prefix(destination_prefix); + } + Repr::RplTargetDescriptor { descriptor } => { + packet.set_rpl_target_descriptor_descriptor(*descriptor); + } + } + } + } +} + +pub mod data { + use super::{InstanceId, Result}; + use byteorder::{ByteOrder, NetworkEndian}; + + mod field { + use crate::wire::field::*; + + pub const FLAGS: usize = 0; + pub const INSTANCE_ID: usize = 1; + pub const SENDER_RANK: Field = 2..4; + } + + /// A read/write wrapper around a RPL Packet Information send with + /// an IPv6 Hop-by-Hop option, defined in RFC6553. + /// ```txt + /// 0 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | Option Type | Opt Data Len | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |O|R|F|0|0|0|0|0| RPLInstanceID | SenderRank | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// | (sub-TLVs) | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub struct Packet> { + buffer: T, + } + + impl> Packet { + #[inline] + pub fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + #[inline] + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + #[inline] + pub fn check_len(&self) -> Result<()> { + if self.buffer.as_ref().len() == 4 { + Ok(()) + } else { + Err(crate::wire::Error) + } + } + + #[inline] + pub fn is_down(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 7, mask: 0b1) + } + + #[inline] + pub fn has_rank_error(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 6, mask: 0b1) + } + + #[inline] + pub fn has_forwarding_error(&self) -> bool { + get!(self.buffer, bool, field: field::FLAGS, shift: 5, mask: 0b1) + } + + #[inline] + pub fn rpl_instance_id(&self) -> InstanceId { + get!(self.buffer, into: InstanceId, field: field::INSTANCE_ID) + } + + #[inline] + pub fn sender_rank(&self) -> u16 { + get!(self.buffer, u16, field: field::SENDER_RANK) + } + } + + impl + AsMut<[u8]>> Packet { + #[inline] + pub fn set_is_down(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 7, mask: 0b1) + } + + #[inline] + pub fn set_has_rank_error(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 6, mask: 0b1) + } + + #[inline] + pub fn set_has_forwarding_error(&mut self, value: bool) { + set!(self.buffer, value, bool, field: field::FLAGS, shift: 5, mask: 0b1) + } + + #[inline] + pub fn set_rpl_instance_id(&mut self, value: u8) { + set!(self.buffer, value, field: field::INSTANCE_ID) + } + + #[inline] + pub fn set_sender_rank(&mut self, value: u16) { + set!(self.buffer, value, u16, field: field::SENDER_RANK) + } + } + + /// A high-level representation of an IPv6 Extension Header Option. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct HopByHopOption { + pub down: bool, + pub rank_error: bool, + pub forwarding_error: bool, + pub instance_id: InstanceId, + pub sender_rank: u16, + } + + impl HopByHopOption { + /// Parse an IPv6 Extension Header Option and return a high-level representation. + pub fn parse(opt: &Packet<&T>) -> Self + where + T: AsRef<[u8]> + ?Sized, + { + Self { + down: opt.is_down(), + rank_error: opt.has_rank_error(), + forwarding_error: opt.has_forwarding_error(), + instance_id: opt.rpl_instance_id(), + sender_rank: opt.sender_rank(), + } + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub const fn buffer_len(&self) -> usize { + 4 + } + + /// Emit a high-level representation into an IPv6 Extension Header Option. + pub fn emit + AsMut<[u8]> + ?Sized>(&self, opt: &mut Packet<&mut T>) { + opt.set_is_down(self.down); + opt.set_has_rank_error(self.rank_error); + opt.set_has_forwarding_error(self.forwarding_error); + opt.set_rpl_instance_id(self.instance_id.into()); + opt.set_sender_rank(self.sender_rank); + } + } + + impl core::fmt::Display for HopByHopOption { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "down={} rank_error={} forw_error={} IID={:?} sender_rank={}", + self.down, + self.rank_error, + self.forwarding_error, + self.instance_id, + self.sender_rank + ) + } + } +} + +#[cfg(test)] +mod tests { + use super::options::{Packet as OptionPacket, Repr as OptionRepr}; + use super::Repr as RplRepr; + use super::*; + use crate::phy::ChecksumCapabilities; + use crate::wire::{icmpv6::*, *}; + + #[test] + fn dis_packet() { + let data = [0x7a, 0x3b, 0x3a, 0x1a, 0x9b, 0x00, 0x00, 0x00, 0x00, 0x00]; + + let ll_src_address = + Ieee802154Address::Extended([0x9e, 0xd3, 0xa2, 0x9c, 0x57, 0x1a, 0x4f, 0xe4]); + let ll_dst_address = Ieee802154Address::Short([0xff, 0xff]); + + let packet = SixlowpanIphcPacket::new_checked(&data).unwrap(); + let repr = + SixlowpanIphcRepr::parse(&packet, Some(ll_src_address), Some(ll_dst_address), &[]) + .unwrap(); + + let icmp_repr = match repr.next_header { + SixlowpanNextHeader::Uncompressed(IpProtocol::Icmpv6) => { + let icmp_packet = Icmpv6Packet::new_checked(packet.payload()).unwrap(); + match Icmpv6Repr::parse( + &IpAddress::Ipv6(repr.src_addr), + &IpAddress::Ipv6(repr.dst_addr), + &icmp_packet, + &ChecksumCapabilities::ignored(), + ) { + Ok(icmp @ Icmpv6Repr::Rpl(RplRepr::DodagInformationSolicitation { .. })) => { + icmp + } + _ => unreachable!(), + } + } + _ => unreachable!(), + }; + + // We also try to emit the packet: + let mut buffer = vec![0u8; repr.buffer_len() + icmp_repr.buffer_len()]; + repr.emit(&mut SixlowpanIphcPacket::new_unchecked( + &mut buffer[..repr.buffer_len()], + )); + icmp_repr.emit( + &repr.src_addr.into(), + &repr.dst_addr.into(), + &mut Icmpv6Packet::new_unchecked( + &mut buffer[repr.buffer_len()..][..icmp_repr.buffer_len()], + ), + &ChecksumCapabilities::ignored(), + ); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DIO packets. + #[test] + fn dio_packet() { + let data = [ + 0x9b, 0x01, 0x00, 0x00, 0x00, 0xf0, 0x00, 0x80, 0x08, 0xf0, 0x00, 0x00, 0xfd, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x04, 0x0e, 0x00, 0x08, 0x0c, 0x00, 0x04, 0x00, 0x00, 0x80, 0x00, 0x01, 0x00, 0x1e, + 0x00, 0x3c, 0x08, 0x1e, 0x40, 0x40, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, + 0x00, 0x00, 0x00, 0x00, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + ]; + + let addr = Address::from_bytes(&[ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x01, + ]); + + let dest_prefix = [ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x00, + ]; + + let packet = Packet::new_checked(&data[..]).unwrap(); + assert_eq!(packet.msg_type(), Message::RplControl); + assert_eq!( + RplControlMessage::from(packet.msg_code()), + RplControlMessage::DodagInformationObject + ); + + let mut dio_repr = RplRepr::parse(&packet).unwrap(); + match dio_repr { + RplRepr::DodagInformationObject { + rpl_instance_id, + version_number, + rank, + grounded, + mode_of_operation, + dodag_preference, + dtsn, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert_eq!(version_number, 240); + assert_eq!(rank, 128); + assert!(!grounded); + assert_eq!(mode_of_operation, ModeOfOperation::NonStoringMode); + assert_eq!(dodag_preference, 0); + assert_eq!(dtsn, 240); + assert_eq!(dodag_id, addr); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(packet.options().unwrap()); + let dodag_conf_option = OptionRepr::parse(&option).unwrap(); + match dodag_conf_option { + OptionRepr::DodagConfiguration { + authentication_enabled, + path_control_size, + dio_interval_doublings, + dio_interval_min, + dio_redundancy_constant, + max_rank_increase, + minimum_hop_rank_increase, + objective_code_point, + default_lifetime, + lifetime_unit, + } => { + assert!(!authentication_enabled); + assert_eq!(path_control_size, 0); + assert_eq!(dio_interval_doublings, 8); + assert_eq!(dio_interval_min, 12); + assert_eq!(dio_redundancy_constant, 0); + assert_eq!(max_rank_increase, 1024); + assert_eq!(minimum_hop_rank_increase, 128); + assert_eq!(objective_code_point, 1); + assert_eq!(default_lifetime, 30); + assert_eq!(lifetime_unit, 60); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(option.next_option().unwrap()); + let prefix_info_option = OptionRepr::parse(&option).unwrap(); + match prefix_info_option { + OptionRepr::PrefixInformation { + prefix_length, + on_link, + autonomous_address_configuration, + valid_lifetime, + preferred_lifetime, + destination_prefix, + .. + } => { + assert_eq!(prefix_length, 64); + assert!(!on_link); + assert!(autonomous_address_configuration); + assert_eq!(valid_lifetime, u32::MAX); + assert_eq!(preferred_lifetime, u32::MAX); + assert_eq!(destination_prefix, &dest_prefix[..]); + } + _ => unreachable!(), + } + + let mut options_buffer = + vec![0u8; dodag_conf_option.buffer_len() + prefix_info_option.buffer_len()]; + + dodag_conf_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[..dodag_conf_option.buffer_len()], + )); + prefix_info_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[dodag_conf_option.buffer_len()..] + [..prefix_info_option.buffer_len()], + )); + + dio_repr.set_options(&options_buffer[..]); + + let mut buffer = vec![0u8; dio_repr.buffer_len()]; + dio_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DAO packets. + #[test] + fn dao_packet() { + let data = [ + 0x9b, 0x02, 0x00, 0x00, 0x00, 0x80, 0x00, 0xf1, 0x05, 0x12, 0x00, 0x80, 0xfd, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, + 0x06, 0x14, 0x00, 0x00, 0x00, 0x1e, 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, 0x00, 0x01, + ]; + + let target_prefix = [ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x02, 0x00, 0x02, 0x00, 0x02, + 0x00, 0x02, + ]; + + let parent_addr = Address::from_bytes(&[ + 0xfd, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0x01, 0x00, 0x01, 0x00, 0x01, + 0x00, 0x01, + ]); + + let packet = Packet::new_checked(&data[..]).unwrap(); + let mut dao_repr = RplRepr::parse(&packet).unwrap(); + match dao_repr { + RplRepr::DestinationAdvertisementObject { + rpl_instance_id, + expect_ack, + sequence, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert!(expect_ack); + assert_eq!(sequence, 241); + assert_eq!(dodag_id, None); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(packet.options().unwrap()); + + let rpl_target_option = OptionRepr::parse(&option).unwrap(); + match rpl_target_option { + OptionRepr::RplTarget { + prefix_length, + prefix, + } => { + assert_eq!(prefix_length, 128); + assert_eq!(prefix.as_bytes(), &target_prefix[..]); + } + _ => unreachable!(), + } + + let option = OptionPacket::new_unchecked(option.next_option().unwrap()); + let transit_info_option = OptionRepr::parse(&option).unwrap(); + match transit_info_option { + OptionRepr::TransitInformation { + external, + path_control, + path_sequence, + path_lifetime, + parent_address, + } => { + assert!(!external); + assert_eq!(path_control, 0); + assert_eq!(path_sequence, 0); + assert_eq!(path_lifetime, 30); + assert_eq!(parent_address, Some(parent_addr)); + } + _ => unreachable!(), + } + + let mut options_buffer = + vec![0u8; rpl_target_option.buffer_len() + transit_info_option.buffer_len()]; + + rpl_target_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[..rpl_target_option.buffer_len()], + )); + transit_info_option.emit(&mut OptionPacket::new_unchecked( + &mut options_buffer[rpl_target_option.buffer_len()..] + [..transit_info_option.buffer_len()], + )); + + dao_repr.set_options(&options_buffer[..]); + + let mut buffer = vec![0u8; dao_repr.buffer_len()]; + dao_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } + + /// Parsing of DAO-ACK packets. + #[test] + fn dao_ack_packet() { + let data = [0x9b, 0x03, 0x00, 0x00, 0x00, 0x00, 0xf1, 0x00]; + + let packet = Packet::new_checked(&data[..]).unwrap(); + let dao_ack_repr = RplRepr::parse(&packet).unwrap(); + match dao_ack_repr { + RplRepr::DestinationAdvertisementObjectAck { + rpl_instance_id, + sequence, + status, + dodag_id, + .. + } => { + assert_eq!(rpl_instance_id, InstanceId::from(0)); + assert_eq!(sequence, 241); + assert_eq!(status, 0); + assert_eq!(dodag_id, None); + } + _ => unreachable!(), + } + + let mut buffer = vec![0u8; dao_ack_repr.buffer_len()]; + dao_ack_repr.emit(&mut Packet::new_unchecked(&mut buffer[..])); + + assert_eq!(&data[..], &buffer[..]); + } +} diff --git a/src/wire/sixlowpan.rs b/src/wire/sixlowpan.rs new file mode 100644 index 000000000..222bb3c10 --- /dev/null +++ b/src/wire/sixlowpan.rs @@ -0,0 +1,2463 @@ +//! Implementation of [RFC 6282] which specifies a compression format for IPv6 datagrams over +//! IEEE802.154-based networks. +//! +//! [RFC 6282]: https://datatracker.ietf.org/doc/html/rfc6282 + +use super::{Error, Result}; +use crate::wire::ieee802154::Address as LlAddress; +use crate::wire::ipv6; +use crate::wire::IpProtocol; + +const ADDRESS_CONTEXT_LENGTH: usize = 8; + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub struct AddressContext(pub [u8; ADDRESS_CONTEXT_LENGTH]); + +/// The representation of an unresolved address. 6LoWPAN compression of IPv6 addresses can be with +/// and without context information. The decompression with context information is not yet +/// implemented. +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum UnresolvedAddress<'a> { + WithoutContext(AddressMode<'a>), + WithContext((usize, AddressMode<'a>)), + Reserved, +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum AddressMode<'a> { + /// The full address is carried in-line. + FullInline(&'a [u8]), + /// The first 64-bits of the address are elided. The value of those bits + /// is the link-local prefix padded with zeros. The remaining 64 bits are + /// carried in-line. + InLine64bits(&'a [u8]), + /// The first 112 bits of the address are elided. The value of the first + /// 64 bits is the link-local prefix padded with zeros. The following 64 bits + /// are 0000:00ff:fe00:XXXX, where XXXX are the 16 bits carried in-line. + InLine16bits(&'a [u8]), + /// The address is fully elided. The first 64 bits of the address are + /// the link-local prefix padded with zeros. The remaining 64 bits are + /// computed from the encapsulating header (e.g., 802.15.4 or IPv6 source address) + /// as specified in Section 3.2.2. + FullyElided, + /// The address takes the form ffXX::00XX:XXXX:XXXX + Multicast48bits(&'a [u8]), + /// The address takes the form ffXX::00XX:XXXX. + Multicast32bits(&'a [u8]), + /// The address takes the form ff02::00XX. + Multicast8bits(&'a [u8]), + /// The unspecified address. + Unspecified, + NotSupported, +} + +const LINK_LOCAL_PREFIX: [u8; 2] = [0xfe, 0x80]; +const EUI64_MIDDLE_VALUE: [u8; 2] = [0xff, 0xfe]; + +impl<'a> UnresolvedAddress<'a> { + pub fn resolve( + self, + ll_address: Option, + addr_context: &[AddressContext], + ) -> Result { + let mut bytes = [0; 16]; + + let copy_context = |index: usize, bytes: &mut [u8]| -> Result<()> { + if index >= addr_context.len() { + return Err(Error); + } + + let context = addr_context[index]; + bytes[..ADDRESS_CONTEXT_LENGTH].copy_from_slice(&context.0); + + Ok(()) + }; + + match self { + UnresolvedAddress::WithoutContext(mode) => match mode { + AddressMode::FullInline(addr) => Ok(ipv6::Address::from_bytes(addr)), + AddressMode::InLine64bits(inline) => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + bytes[8..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::InLine16bits(inline) => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::FullyElided => { + bytes[0..2].copy_from_slice(&LINK_LOCAL_PREFIX[..]); + match ll_address { + Some(LlAddress::Short(ll)) => { + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(&ll); + } + Some(addr @ LlAddress::Extended(_)) => match addr.as_eui_64() { + Some(addr) => bytes[8..].copy_from_slice(&addr), + None => return Err(Error), + }, + Some(LlAddress::Absent) => return Err(Error), + None => return Err(Error), + } + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast48bits(inline) => { + bytes[0] = 0xff; + bytes[1] = inline[0]; + bytes[11..].copy_from_slice(&inline[1..][..5]); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast32bits(inline) => { + bytes[0] = 0xff; + bytes[1] = inline[0]; + bytes[13..].copy_from_slice(&inline[1..][..3]); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + AddressMode::Multicast8bits(inline) => { + bytes[0] = 0xff; + bytes[1] = 0x02; + bytes[15] = inline[0]; + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + _ => Err(Error), + }, + UnresolvedAddress::WithContext(mode) => match mode { + (_, AddressMode::Unspecified) => Ok(ipv6::Address::UNSPECIFIED), + (index, AddressMode::InLine64bits(inline)) => { + copy_context(index, &mut bytes[..])?; + bytes[16 - inline.len()..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + (index, AddressMode::InLine16bits(inline)) => { + copy_context(index, &mut bytes[..])?; + bytes[16 - inline.len()..].copy_from_slice(inline); + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + (index, AddressMode::FullyElided) => { + match ll_address { + Some(LlAddress::Short(ll)) => { + bytes[11..13].copy_from_slice(&EUI64_MIDDLE_VALUE[..]); + bytes[14..].copy_from_slice(&ll); + } + Some(addr @ LlAddress::Extended(_)) => match addr.as_eui_64() { + Some(addr) => bytes[8..].copy_from_slice(&addr), + None => return Err(Error), + }, + Some(LlAddress::Absent) => return Err(Error), + None => return Err(Error), + } + + copy_context(index, &mut bytes[..])?; + + Ok(ipv6::Address::from_bytes(&bytes[..])) + } + _ => Err(Error), + }, + UnresolvedAddress::Reserved => Err(Error), + } + } +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] +pub enum SixlowpanPacket { + FragmentHeader, + IphcHeader, +} + +const DISPATCH_FIRST_FRAGMENT_HEADER: u8 = 0b11000; +const DISPATCH_FRAGMENT_HEADER: u8 = 0b11100; +const DISPATCH_IPHC_HEADER: u8 = 0b011; +const DISPATCH_UDP_HEADER: u8 = 0b11110; +const DISPATCH_EXT_HEADER: u8 = 0b1110; + +impl SixlowpanPacket { + /// Returns the type of the 6LoWPAN header. + /// This can either be a fragment header or an IPHC header. + /// + /// # Errors + /// Returns `[Error::Unrecognized]` when neither the Fragment Header dispatch or the IPHC + /// dispatch is recognized. + pub fn dispatch(buffer: impl AsRef<[u8]>) -> Result { + let raw = buffer.as_ref(); + + if raw.is_empty() { + return Err(Error); + } + + if raw[0] >> 3 == DISPATCH_FIRST_FRAGMENT_HEADER || raw[0] >> 3 == DISPATCH_FRAGMENT_HEADER + { + Ok(Self::FragmentHeader) + } else if raw[0] >> 5 == DISPATCH_IPHC_HEADER { + Ok(Self::IphcHeader) + } else { + Err(Error) + } + } +} + +pub mod frag { + //! Implementation of the fragment headers from [RFC 4944 § 5.3]. + //! + //! [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 + + use super::{DISPATCH_FIRST_FRAGMENT_HEADER, DISPATCH_FRAGMENT_HEADER}; + use crate::wire::{Error, Result}; + use crate::wire::{Ieee802154Address, Ieee802154Repr}; + use byteorder::{ByteOrder, NetworkEndian}; + + /// Key used for identifying all the link fragments that belong to the same packet. + #[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Key { + pub(crate) ll_src_addr: Ieee802154Address, + pub(crate) ll_dst_addr: Ieee802154Address, + pub(crate) datagram_size: u16, + pub(crate) datagram_tag: u16, + } + + /// A read/write wrapper around a 6LoWPAN Fragment header. + /// [RFC 4944 § 5.3] specifies the format of the header. + /// + /// A First Fragment header has the following format: + /// ```txt + /// 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |1 1 0 0 0| datagram_size | datagram_tag | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// ``` + /// + /// Subsequent fragment headers have the following format: + /// ```txt + /// 1 2 3 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |1 1 1 0 0| datagram_size | datagram_tag | + /// +-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+-+ + /// |datagram_offset| + /// +-+-+-+-+-+-+-+-+ + /// ``` + /// + /// [RFC 4944 § 5.3]: https://datatracker.ietf.org/doc/html/rfc4944#section-5.3 + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Packet> { + buffer: T, + } + + pub const FIRST_FRAGMENT_HEADER_SIZE: usize = 4; + pub const NEXT_FRAGMENT_HEADER_SIZE: usize = 5; + + mod field { + use crate::wire::field::*; + + pub const DISPATCH: usize = 0; + pub const DATAGRAM_SIZE: Field = 0..2; + pub const DATAGRAM_TAG: Field = 2..4; + pub const DATAGRAM_OFFSET: usize = 4; + + pub const FIRST_FRAGMENT_REST: Rest = super::FIRST_FRAGMENT_HEADER_SIZE..; + pub const NEXT_FRAGMENT_REST: Rest = super::NEXT_FRAGMENT_HEADER_SIZE..; + } + + impl> Packet { + /// Input a raw octet buffer with a 6LoWPAN Fragment header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + let dispatch = packet.dispatch(); + + if dispatch != DISPATCH_FIRST_FRAGMENT_HEADER && dispatch != DISPATCH_FRAGMENT_HEADER { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + if buffer.is_empty() { + return Err(Error); + } + + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER if buffer.len() >= FIRST_FRAGMENT_HEADER_SIZE => { + Ok(()) + } + DISPATCH_FIRST_FRAGMENT_HEADER if buffer.len() < FIRST_FRAGMENT_HEADER_SIZE => { + Err(Error) + } + DISPATCH_FRAGMENT_HEADER if buffer.len() >= NEXT_FRAGMENT_HEADER_SIZE => Ok(()), + DISPATCH_FRAGMENT_HEADER if buffer.len() < NEXT_FRAGMENT_HEADER_SIZE => Err(Error), + _ => Err(Error), + } + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the dispatch field. + pub fn dispatch(&self) -> u8 { + let raw = self.buffer.as_ref(); + raw[field::DISPATCH] >> 3 + } + + /// Return the total datagram size. + pub fn datagram_size(&self) -> u16 { + let raw = self.buffer.as_ref(); + NetworkEndian::read_u16(&raw[field::DATAGRAM_SIZE]) & 0b111_1111_1111 + } + + /// Return the datagram tag. + pub fn datagram_tag(&self) -> u16 { + let raw = self.buffer.as_ref(); + NetworkEndian::read_u16(&raw[field::DATAGRAM_TAG]) + } + + /// Return the datagram offset. + pub fn datagram_offset(&self) -> u8 { + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => 0, + DISPATCH_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + raw[field::DATAGRAM_OFFSET] + } + _ => unreachable!(), + } + } + + /// Returns `true` when this header is from the first fragment of a link. + pub fn is_first_fragment(&self) -> bool { + self.dispatch() == DISPATCH_FIRST_FRAGMENT_HEADER + } + + /// Returns the key for identifying the packet it belongs to. + pub fn get_key(&self, ieee802154_repr: &Ieee802154Repr) -> Key { + Key { + ll_src_addr: ieee802154_repr.src_addr.unwrap(), + ll_dst_addr: ieee802154_repr.dst_addr.unwrap(), + datagram_size: self.datagram_size(), + datagram_tag: self.datagram_tag(), + } + } + } + + impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return the payload. + pub fn payload(&self) -> &'a [u8] { + match self.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + &raw[field::FIRST_FRAGMENT_REST] + } + DISPATCH_FRAGMENT_HEADER => { + let raw = self.buffer.as_ref(); + &raw[field::NEXT_FRAGMENT_REST] + } + _ => unreachable!(), + } + } + } + + impl + AsMut<[u8]>> Packet { + fn set_dispatch_field(&mut self, value: u8) { + let raw = self.buffer.as_mut(); + raw[field::DISPATCH] = (raw[field::DISPATCH] & !(0b11111 << 3)) | (value << 3); + } + + fn set_datagram_size(&mut self, size: u16) { + let raw = self.buffer.as_mut(); + let mut v = NetworkEndian::read_u16(&raw[field::DATAGRAM_SIZE]); + v = (v & !0b111_1111_1111) | size; + + NetworkEndian::write_u16(&mut raw[field::DATAGRAM_SIZE], v); + } + + fn set_datagram_tag(&mut self, tag: u16) { + let raw = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut raw[field::DATAGRAM_TAG], tag); + } + + fn set_datagram_offset(&mut self, offset: u8) { + let raw = self.buffer.as_mut(); + raw[field::DATAGRAM_OFFSET] = offset; + } + } + + /// A high-level representation of a 6LoWPAN Fragment header. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub enum Repr { + FirstFragment { size: u16, tag: u16 }, + Fragment { size: u16, tag: u16, offset: u8 }, + } + + impl core::fmt::Display for Repr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + Repr::FirstFragment { size, tag } => { + write!(f, "FirstFrag size={size} tag={tag}") + } + Repr::Fragment { size, tag, offset } => { + write!(f, "NthFrag size={size} tag={tag} offset={offset}") + } + } + } + } + + #[cfg(feature = "defmt")] + impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + match self { + Repr::FirstFragment { size, tag } => { + defmt::write!(fmt, "FirstFrag size={} tag={}", size, tag); + } + Repr::Fragment { size, tag, offset } => { + defmt::write!(fmt, "NthFrag size={} tag={} offset={}", size, tag, offset); + } + } + } + } + + impl Repr { + /// Parse a 6LoWPAN Fragment header. + pub fn parse>(packet: &Packet) -> Result { + let size = packet.datagram_size(); + let tag = packet.datagram_tag(); + + match packet.dispatch() { + DISPATCH_FIRST_FRAGMENT_HEADER => Ok(Self::FirstFragment { size, tag }), + DISPATCH_FRAGMENT_HEADER => Ok(Self::Fragment { + size, + tag, + offset: packet.datagram_offset(), + }), + _ => Err(Error), + } + } + + /// Returns the length of the Fragment header. + pub const fn buffer_len(&self) -> usize { + match self { + Self::FirstFragment { .. } => field::FIRST_FRAGMENT_REST.start, + Self::Fragment { .. } => field::NEXT_FRAGMENT_REST.start, + } + } + + /// Emit a high-level representation into a 6LoWPAN Fragment header. + pub fn emit + AsMut<[u8]>>(&self, packet: &mut Packet) { + match self { + Self::FirstFragment { size, tag } => { + packet.set_dispatch_field(DISPATCH_FIRST_FRAGMENT_HEADER); + packet.set_datagram_size(*size); + packet.set_datagram_tag(*tag); + } + Self::Fragment { size, tag, offset } => { + packet.set_dispatch_field(DISPATCH_FRAGMENT_HEADER); + packet.set_datagram_size(*size); + packet.set_datagram_tag(*tag); + packet.set_datagram_offset(*offset); + } + } + } + } +} + +#[derive(Debug, PartialEq, Eq, Clone, Copy)] +pub enum NextHeader { + Compressed, + Uncompressed(IpProtocol), +} + +impl core::fmt::Display for NextHeader { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + match self { + NextHeader::Compressed => write!(f, "compressed"), + NextHeader::Uncompressed(protocol) => write!(f, "{protocol}"), + } + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for NextHeader { + fn format(&self, fmt: defmt::Formatter) { + match self { + NextHeader::Compressed => defmt::write!(fmt, "compressed"), + NextHeader::Uncompressed(protocol) => defmt::write!(fmt, "{}", protocol), + } + } +} + +pub mod iphc { + //! Implementation of IP Header Compression from [RFC 6282 § 3.1]. + //! It defines the compression of IPv6 headers. + //! + //! [RFC 6282 § 3.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1 + + use super::{ + AddressContext, AddressMode, Error, NextHeader, Result, UnresolvedAddress, + DISPATCH_IPHC_HEADER, + }; + use crate::wire::{ieee802154::Address as LlAddress, ipv6, IpProtocol}; + use byteorder::{ByteOrder, NetworkEndian}; + + mod field { + use crate::wire::field::*; + + pub const IPHC_FIELD: Field = 0..2; + } + + macro_rules! get_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&self) -> u8 { + let data = self.buffer.as_ref(); + let raw = NetworkEndian::read_u16(&data[field::IPHC_FIELD]); + ((raw >> $shift) & $mask) as u8 + } + }; + } + + macro_rules! set_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&mut self, val: u8) { + let data = &mut self.buffer.as_mut()[field::IPHC_FIELD]; + let mut raw = NetworkEndian::read_u16(data); + + raw = (raw & !($mask << $shift)) | ((val as u16) << $shift); + NetworkEndian::write_u16(data, raw); + } + }; + } + + /// A read/write wrapper around a 6LoWPAN IPHC header. + /// [RFC 6282 § 3.1] specifies the format of the header. + /// + /// The header always start with the following base format (from [RFC 6282 § 3.1.1]): + /// ```txt + /// 0 1 + /// 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 + /// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + /// | 0 | 1 | 1 | TF |NH | HLIM |CID|SAC| SAM | M |DAC| DAM | + /// +---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+---+ + /// ``` + /// With: + /// - TF: Traffic Class and Flow Label + /// - NH: Next Header + /// - HLIM: Hop Limit + /// - CID: Context Identifier Extension + /// - SAC: Source Address Compression + /// - SAM: Source Address Mode + /// - M: Multicast Compression + /// - DAC: Destination Address Compression + /// - DAM: Destination Address Mode + /// + /// Depending on the flags in the base format, the following fields are added to the header: + /// - Traffic Class and Flow Label + /// - Next Header + /// - Hop Limit + /// - IPv6 source address + /// - IPv6 destinatino address + /// + /// [RFC 6282 § 3.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1 + /// [RFC 6282 § 3.1.1]: https://datatracker.ietf.org/doc/html/rfc6282#section-3.1.1 + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct Packet> { + buffer: T, + } + + impl> Packet { + /// Input a raw octet buffer with a 6LoWPAN IPHC header structure. + pub const fn new_unchecked(buffer: T) -> Self { + Packet { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + if buffer.len() < 2 { + return Err(Error); + } + + let mut offset = self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size(); + offset += self.src_address_size(); + offset += self.dst_address_size(); + + if offset as usize > buffer.len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + /// Return the Next Header field. + pub fn next_header(&self) -> NextHeader { + let nh = self.nh_field(); + + if nh == 1 { + // The next header field is compressed. + // It is also encoded using LOWPAN_NHC. + NextHeader::Compressed + } else { + // The full 8 bits for Next Header are carried in-line. + let start = (self.ip_fields_start() + self.traffic_class_size()) as usize; + + let data = self.buffer.as_ref(); + let nh = data[start..start + 1][0]; + NextHeader::Uncompressed(IpProtocol::from(nh)) + } + } + + /// Return the Hop Limit. + pub fn hop_limit(&self) -> u8 { + match self.hlim_field() { + 0b00 => { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size()) as usize; + + let data = self.buffer.as_ref(); + data[start..start + 1][0] + } + 0b01 => 1, + 0b10 => 64, + 0b11 => 255, + _ => unreachable!(), + } + } + + /// Return the Source Context Identifier. + pub fn src_context_id(&self) -> Option { + if self.cid_field() == 1 { + let data = self.buffer.as_ref(); + Some(data[2] >> 4) + } else { + None + } + } + + /// Return the Destination Context Identifier. + pub fn dst_context_id(&self) -> Option { + if self.cid_field() == 1 { + let data = self.buffer.as_ref(); + Some(data[2] & 0x0f) + } else { + None + } + } + + /// Return the ECN field (when it is inlined). + pub fn ecn_field(&self) -> Option { + match self.tf_field() { + 0b00 | 0b01 | 0b10 => { + let start = self.ip_fields_start() as usize; + Some(self.buffer.as_ref()[start..][0] & 0b1100_0000) + } + 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the DSCP field (when it is inlined). + pub fn dscp_field(&self) -> Option { + match self.tf_field() { + 0b00 | 0b10 => { + let start = self.ip_fields_start() as usize; + Some(self.buffer.as_ref()[start..][0] & 0b111111) + } + 0b01 | 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the flow label field (when it is inlined). + pub fn flow_label_field(&self) -> Option { + match self.tf_field() { + 0b00 => { + let start = self.ip_fields_start() as usize; + Some(NetworkEndian::read_u16( + &self.buffer.as_ref()[start..][2..4], + )) + } + 0b01 => { + let start = self.ip_fields_start() as usize; + Some(NetworkEndian::read_u16( + &self.buffer.as_ref()[start..][1..3], + )) + } + 0b10 | 0b11 => None, + _ => unreachable!(), + } + } + + /// Return the Source Address. + pub fn src_addr(&self) -> Result { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size()) as usize; + + let data = self.buffer.as_ref(); + match (self.sac_field(), self.sam_field()) { + (0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine64bits(&data[start..][..8]), + )), + (0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine16bits(&data[start..][..2]), + )), + (0, 0b11) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)), + (1, 0b00) => Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::Unspecified, + ))), + (1, 0b01) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine64bits(&data[start..][..8]), + ))) + } else { + Err(Error) + } + } + (1, 0b10) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine16bits(&data[start..][..2]), + ))) + } else { + Err(Error) + } + } + (1, 0b11) => { + if let Some(id) = self.src_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::FullyElided, + ))) + } else { + Err(Error) + } + } + _ => Err(Error), + } + } + + /// Return the Destination Address. + pub fn dst_addr(&self) -> Result { + let start = (self.ip_fields_start() + + self.traffic_class_size() + + self.next_header_size() + + self.hop_limit_size() + + self.src_address_size()) as usize; + + let data = self.buffer.as_ref(); + match (self.m_field(), self.dac_field(), self.dam_field()) { + (0, 0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (0, 0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine64bits(&data[start..][..8]), + )), + (0, 0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::InLine16bits(&data[start..][..2]), + )), + (0, 0, 0b11) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)), + (0, 1, 0b00) => Ok(UnresolvedAddress::Reserved), + (0, 1, 0b01) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine64bits(&data[start..][..8]), + ))) + } else { + Err(Error) + } + } + (0, 1, 0b10) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::InLine16bits(&data[start..][..2]), + ))) + } else { + Err(Error) + } + } + (0, 1, 0b11) => { + if let Some(id) = self.dst_context_id() { + Ok(UnresolvedAddress::WithContext(( + id as usize, + AddressMode::FullyElided, + ))) + } else { + Err(Error) + } + } + (1, 0, 0b00) => Ok(UnresolvedAddress::WithoutContext(AddressMode::FullInline( + &data[start..][..16], + ))), + (1, 0, 0b01) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast48bits(&data[start..][..6]), + )), + (1, 0, 0b10) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast32bits(&data[start..][..4]), + )), + (1, 0, 0b11) => Ok(UnresolvedAddress::WithoutContext( + AddressMode::Multicast8bits(&data[start..][..1]), + )), + (1, 1, 0b00) => Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::NotSupported, + ))), + (1, 1, 0b01 | 0b10 | 0b11) => Ok(UnresolvedAddress::Reserved), + _ => Err(Error), + } + } + + get_field!(dispatch_field, 0b111, 13); + get_field!(tf_field, 0b11, 11); + get_field!(nh_field, 0b1, 10); + get_field!(hlim_field, 0b11, 8); + get_field!(cid_field, 0b1, 7); + get_field!(sac_field, 0b1, 6); + get_field!(sam_field, 0b11, 4); + get_field!(m_field, 0b1, 3); + get_field!(dac_field, 0b1, 2); + get_field!(dam_field, 0b11, 0); + + /// Return the start for the IP fields. + fn ip_fields_start(&self) -> u8 { + 2 + self.cid_size() + } + + /// Get the size in octets of the traffic class field. + fn traffic_class_size(&self) -> u8 { + match self.tf_field() { + 0b00 => 4, + 0b01 => 3, + 0b10 => 1, + 0b11 => 0, + _ => unreachable!(), + } + } + + /// Get the size in octets of the next header field. + fn next_header_size(&self) -> u8 { + (self.nh_field() != 1) as u8 + } + + /// Get the size in octets of the hop limit field. + fn hop_limit_size(&self) -> u8 { + (self.hlim_field() == 0b00) as u8 + } + + /// Get the size in octets of the CID field. + fn cid_size(&self) -> u8 { + (self.cid_field() == 1) as u8 + } + + /// Get the size in octets of the source address. + fn src_address_size(&self) -> u8 { + match (self.sac_field(), self.sam_field()) { + (0, 0b00) => 16, // The full address is carried in-line. + (0, 0b01) => 8, // The first 64 bits are elided. + (0, 0b10) => 2, // The first 112 bits are elided. + (0, 0b11) => 0, // The address is fully elided. + (1, 0b00) => 0, // The UNSPECIFIED address. + (1, 0b01) => 8, // Address derived using context information. + (1, 0b10) => 2, // Address derived using context information. + (1, 0b11) => 0, // Address derived using context information. + _ => unreachable!(), + } + } + + /// Get the size in octets of the address address. + fn dst_address_size(&self) -> u8 { + match (self.m_field(), self.dac_field(), self.dam_field()) { + (0, 0, 0b00) => 16, // The full address is carried in-line. + (0, 0, 0b01) => 8, // The first 64 bits are elided. + (0, 0, 0b10) => 2, // The first 112 bits are elided. + (0, 0, 0b11) => 0, // The address is fully elided. + (0, 1, 0b00) => 0, // Reserved. + (0, 1, 0b01) => 8, // Address derived using context information. + (0, 1, 0b10) => 2, // Address derived using context information. + (0, 1, 0b11) => 0, // Address derived using context information. + (1, 0, 0b00) => 16, // The full address is carried in-line. + (1, 0, 0b01) => 6, // The address takes the form ffXX::00XX:XXXX:XXXX. + (1, 0, 0b10) => 4, // The address takes the form ffXX::00XX:XXXX. + (1, 0, 0b11) => 1, // The address takes the form ff02::00XX. + (1, 1, 0b00) => 6, // Match Unicast-Prefix-based IPv6. + (1, 1, 0b01) => 0, // Reserved. + (1, 1, 0b10) => 0, // Reserved. + (1, 1, 0b11) => 0, // Reserved. + _ => unreachable!(), + } + } + + /// Return the length of the header. + pub fn header_len(&self) -> usize { + let mut len = self.ip_fields_start(); + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + len as usize + } + } + + impl<'a, T: AsRef<[u8]> + ?Sized> Packet<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let mut len = self.ip_fields_start(); + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + let len = len as usize; + + let data = self.buffer.as_ref(); + &data[len..] + } + } + + impl + AsMut<[u8]>> Packet { + /// Set the dispatch field to `0b011`. + fn set_dispatch_field(&mut self) { + let data = &mut self.buffer.as_mut()[field::IPHC_FIELD]; + let mut raw = NetworkEndian::read_u16(data); + + raw = (raw & !(0b111 << 13)) | (0b11 << 13); + NetworkEndian::write_u16(data, raw); + } + + set_field!(set_tf_field, 0b11, 11); + set_field!(set_nh_field, 0b1, 10); + set_field!(set_hlim_field, 0b11, 8); + set_field!(set_cid_field, 0b1, 7); + set_field!(set_sac_field, 0b1, 6); + set_field!(set_sam_field, 0b11, 4); + set_field!(set_m_field, 0b1, 3); + set_field!(set_dac_field, 0b1, 2); + set_field!(set_dam_field, 0b11, 0); + + fn set_field(&mut self, idx: usize, value: &[u8]) { + let raw = self.buffer.as_mut(); + raw[idx..idx + value.len()].copy_from_slice(value); + } + + /// Set the Next Header. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_next_header(&mut self, nh: NextHeader, mut idx: usize) -> usize { + match nh { + NextHeader::Uncompressed(nh) => { + self.set_nh_field(0); + self.set_field(idx, &[nh.into()]); + idx += 1; + } + NextHeader::Compressed => self.set_nh_field(1), + } + + idx + } + + /// Set the Hop Limit. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_hop_limit(&mut self, hl: u8, mut idx: usize) -> usize { + match hl { + 255 => self.set_hlim_field(0b11), + 64 => self.set_hlim_field(0b10), + 1 => self.set_hlim_field(0b01), + _ => { + self.set_hlim_field(0b00); + self.set_field(idx, &[hl]); + idx += 1; + } + } + + idx + } + + /// Set the Source Address based on the IPv6 address and the Link-Local address. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_src_address( + &mut self, + src_addr: ipv6::Address, + ll_src_addr: Option, + mut idx: usize, + ) -> usize { + self.set_cid_field(0); + self.set_sac_field(0); + let src = src_addr.as_bytes(); + if src_addr == ipv6::Address::UNSPECIFIED { + self.set_sac_field(1); + self.set_sam_field(0b00); + } else if src_addr.is_link_local() { + // We have a link local address. + // The remainder of the address can be elided when the context contains + // a 802.15.4 short address or a 802.15.4 extended address which can be + // converted to a eui64 address. + let is_eui_64 = ll_src_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == src[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if src[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [src[14], src[15]]; + + if ll_src_addr == Some(LlAddress::Short(ll)) { + // We have the context from the 802.15.4 frame. + // The context contains the short address. + // We can elide the source address. + self.set_sam_field(0b11); + } else { + // We don't have the context from the 802.15.4 frame. + // We cannot elide the source address, however we can elide 112 bits. + self.set_sam_field(0b10); + + self.set_field(idx, &src[14..]); + idx += 2; + } + } else if is_eui_64 { + // We have the context from the 802.15.4 frame. + // The context contains the extended address. + // We can elide the source address. + self.set_sam_field(0b11); + } else { + // We cannot elide the source address, however we can elide 64 bits. + self.set_sam_field(0b01); + + self.set_field(idx, &src[8..]); + idx += 8; + } + } else { + // We cannot elide anything. + self.set_sam_field(0b00); + self.set_field(idx, src); + idx += 16; + } + + idx + } + + /// Set the Destination Address based on the IPv6 address and the Link-Local address. + /// + /// **NOTE**: `idx` is the offset at which the Next Header needs to be written to. + fn set_dst_address( + &mut self, + dst_addr: ipv6::Address, + ll_dst_addr: Option, + mut idx: usize, + ) -> usize { + self.set_dac_field(0); + self.set_dam_field(0); + self.set_m_field(0); + let dst = dst_addr.as_bytes(); + if dst_addr.is_multicast() { + self.set_m_field(1); + + if dst[1] == 0x02 && dst[2..15] == [0; 13] { + self.set_dam_field(0b11); + + self.set_field(idx, &[dst[15]]); + idx += 1; + } else if dst[2..13] == [0; 11] { + self.set_dam_field(0b10); + + self.set_field(idx, &[dst[1]]); + idx += 1; + self.set_field(idx, &dst[13..]); + idx += 3; + } else if dst[2..11] == [0; 9] { + self.set_dam_field(0b01); + + self.set_field(idx, &[dst[1]]); + idx += 1; + self.set_field(idx, &dst[11..]); + idx += 5; + } else { + self.set_dam_field(0b11); + + self.set_field(idx, dst); + idx += 16; + } + } else if dst_addr.is_link_local() { + let is_eui_64 = ll_dst_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == dst[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if dst[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [dst[14], dst[15]]; + + if ll_dst_addr == Some(LlAddress::Short(ll)) { + self.set_dam_field(0b11); + } else { + self.set_dam_field(0b10); + + self.set_field(idx, &dst[14..]); + idx += 2; + } + } else if is_eui_64 { + self.set_dam_field(0b11); + } else { + self.set_dam_field(0b01); + + self.set_field(idx, &dst[8..]); + idx += 8; + } + } else { + self.set_dam_field(0b00); + + self.set_field(idx, dst); + idx += 16; + } + + idx + } + + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let mut len = self.ip_fields_start(); + + len += self.traffic_class_size(); + len += self.next_header_size(); + len += self.hop_limit_size(); + len += self.src_address_size(); + len += self.dst_address_size(); + + let len = len as usize; + + let data = self.buffer.as_mut(); + &mut data[len..] + } + } + + /// A high-level representation of a 6LoWPAN IPHC header. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + pub struct Repr { + pub src_addr: ipv6::Address, + pub ll_src_addr: Option, + pub dst_addr: ipv6::Address, + pub ll_dst_addr: Option, + pub next_header: NextHeader, + pub hop_limit: u8, + // TODO(thvdveld): refactor the following fields into something else + pub ecn: Option, + pub dscp: Option, + pub flow_label: Option, + } + + impl core::fmt::Display for Repr { + fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result { + write!( + f, + "IPHC src={} dst={} nxt-hdr={} hop-limit={}", + self.src_addr, self.dst_addr, self.next_header, self.hop_limit + ) + } + } + + #[cfg(feature = "defmt")] + impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!( + fmt, + "IPHC src={} dst={} nxt-hdr={} hop-limit={}", + self.src_addr, + self.dst_addr, + self.next_header, + self.hop_limit + ); + } + } + + impl Repr { + /// Parse a 6LoWPAN IPHC header and return a high-level representation. + /// + /// The `ll_src_addr` and `ll_dst_addr` are the link-local addresses used for resolving the + /// IPv6 packets. + pub fn parse + ?Sized>( + packet: &Packet<&T>, + ll_src_addr: Option, + ll_dst_addr: Option, + addr_context: &[AddressContext], + ) -> Result { + // Ensure basic accessors will work. + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_IPHC_HEADER { + // This is not an LOWPAN_IPHC packet. + return Err(Error); + } + + let src_addr = packet.src_addr()?.resolve(ll_src_addr, addr_context)?; + let dst_addr = packet.dst_addr()?.resolve(ll_dst_addr, addr_context)?; + + Ok(Self { + src_addr, + ll_src_addr, + dst_addr, + ll_dst_addr, + next_header: packet.next_header(), + hop_limit: packet.hop_limit(), + ecn: packet.ecn_field(), + dscp: packet.dscp_field(), + flow_label: packet.flow_label_field(), + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + let mut len = 0; + len += 2; // The minimal header length + + len += match self.next_header { + NextHeader::Compressed => 0, // The next header is compressed (we don't need to inline what the next header is) + NextHeader::Uncompressed(_) => 1, // The next header field is inlined + }; + + // Hop Limit size + len += match self.hop_limit { + 255 | 64 | 1 => 0, // We can inline the hop limit + _ => 1, + }; + + // Add the length of the source address + len += if self.src_addr == ipv6::Address::UNSPECIFIED { + 0 + } else if self.src_addr.is_link_local() { + let src = self.src_addr.as_bytes(); + let ll = [src[14], src[15]]; + + let is_eui_64 = self + .ll_src_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == src[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if src[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + if self.ll_src_addr == Some(LlAddress::Short(ll)) { + 0 + } else { + 2 + } + } else if is_eui_64 { + 0 + } else { + 8 + } + } else { + 16 + }; + + // Add the size of the destination header + let dst = self.dst_addr.as_bytes(); + len += if self.dst_addr.is_multicast() { + if dst[1] == 0x02 && dst[2..15] == [0; 13] { + 1 + } else if dst[2..13] == [0; 11] { + 4 + } else if dst[2..11] == [0; 9] { + 6 + } else { + 16 + } + } else if self.dst_addr.is_link_local() { + let is_eui_64 = self + .ll_dst_addr + .map(|addr| { + addr.as_eui_64() + .map(|addr| addr[..] == dst[8..]) + .unwrap_or(false) + }) + .unwrap_or(false); + + if dst[8..14] == [0, 0, 0, 0xff, 0xfe, 0] { + let ll = [dst[14], dst[15]]; + + if self.ll_dst_addr == Some(LlAddress::Short(ll)) { + 0 + } else { + 2 + } + } else if is_eui_64 { + 0 + } else { + 8 + } + } else { + 16 + }; + + len += match (self.ecn, self.dscp, self.flow_label) { + (Some(_), Some(_), Some(_)) => 4, + (Some(_), None, Some(_)) => 3, + (Some(_), Some(_), None) => 1, + (None, None, None) => 0, + _ => unreachable!(), + }; + + len + } + + /// Emit a high-level representation into a 6LoWPAN IPHC header. + pub fn emit + AsMut<[u8]>>(&self, packet: &mut Packet) { + let idx = 2; + + packet.set_dispatch_field(); + + // FIXME(thvdveld): we don't set anything from the traffic flow. + packet.set_tf_field(0b11); + + let idx = packet.set_next_header(self.next_header, idx); + let idx = packet.set_hop_limit(self.hop_limit, idx); + let idx = packet.set_src_address(self.src_addr, self.ll_src_addr, idx); + packet.set_dst_address(self.dst_addr, self.ll_dst_addr, idx); + } + } + + #[cfg(test)] + mod test { + use super::*; + + #[test] + fn iphc_fields() { + let bytes = [ + 0x7a, 0x33, // IPHC + 0x3a, // Next header + ]; + + let packet = Packet::new_unchecked(bytes); + + assert_eq!(packet.dispatch_field(), 0b011); + assert_eq!(packet.tf_field(), 0b11); + assert_eq!(packet.nh_field(), 0b0); + assert_eq!(packet.hlim_field(), 0b10); + assert_eq!(packet.cid_field(), 0b0); + assert_eq!(packet.sac_field(), 0b0); + assert_eq!(packet.sam_field(), 0b11); + assert_eq!(packet.m_field(), 0b0); + assert_eq!(packet.dac_field(), 0b0); + assert_eq!(packet.dam_field(), 0b11); + + assert_eq!( + packet.next_header(), + NextHeader::Uncompressed(IpProtocol::Icmpv6) + ); + + assert_eq!(packet.src_address_size(), 0); + assert_eq!(packet.dst_address_size(), 0); + assert_eq!(packet.hop_limit(), 64); + + assert_eq!( + packet.src_addr(), + Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)) + ); + assert_eq!( + packet.dst_addr(), + Ok(UnresolvedAddress::WithoutContext(AddressMode::FullyElided)) + ); + + let bytes = [ + 0x7e, 0xf7, // IPHC, + 0x00, // CID + ]; + + let packet = Packet::new_unchecked(bytes); + + assert_eq!(packet.dispatch_field(), 0b011); + assert_eq!(packet.tf_field(), 0b11); + assert_eq!(packet.nh_field(), 0b1); + assert_eq!(packet.hlim_field(), 0b10); + assert_eq!(packet.cid_field(), 0b1); + assert_eq!(packet.sac_field(), 0b1); + assert_eq!(packet.sam_field(), 0b11); + assert_eq!(packet.m_field(), 0b0); + assert_eq!(packet.dac_field(), 0b1); + assert_eq!(packet.dam_field(), 0b11); + + assert_eq!(packet.next_header(), NextHeader::Compressed); + + assert_eq!(packet.src_address_size(), 0); + assert_eq!(packet.dst_address_size(), 0); + assert_eq!(packet.hop_limit(), 64); + + assert_eq!( + packet.src_addr(), + Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::FullyElided + ))) + ); + assert_eq!( + packet.dst_addr(), + Ok(UnresolvedAddress::WithContext(( + 0, + AddressMode::FullyElided + ))) + ); + } + } +} + +pub mod nhc { + //! Implementation of Next Header Compression from [RFC 6282 § 4]. + //! + //! [RFC 6282 § 4]: https://datatracker.ietf.org/doc/html/rfc6282#section-4 + use super::{Error, NextHeader, Result, DISPATCH_EXT_HEADER, DISPATCH_UDP_HEADER}; + use crate::{ + phy::ChecksumCapabilities, + wire::{ + ip::{checksum, Address as IpAddress}, + ipv6, + udp::Repr as UdpRepr, + IpProtocol, + }, + }; + use byteorder::{ByteOrder, NetworkEndian}; + use ipv6::Address; + + macro_rules! get_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&self) -> u8 { + let data = self.buffer.as_ref(); + let raw = &data[0]; + ((raw >> $shift) & $mask) as u8 + } + }; + } + + macro_rules! set_field { + ($name:ident, $mask:expr, $shift:expr) => { + fn $name(&mut self, val: u8) { + let data = self.buffer.as_mut(); + let mut raw = data[0]; + raw = (raw & !($mask << $shift)) | (val << $shift); + data[0] = raw; + } + }; + } + + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + /// A read/write wrapper around a 6LoWPAN_NHC Header. + /// [RFC 6282 § 4.2] specifies the format of the header. + /// + /// The header has the following format: + /// ```txt + /// 0 1 2 3 4 5 6 7 + /// +---+---+---+---+---+---+---+---+ + /// | 1 | 1 | 1 | 0 | EID |NH | + /// +---+---+---+---+---+---+---+---+ + /// ``` + /// + /// With: + /// - EID: the extension header ID + /// - NH: Next Header + /// + /// [RFC 6282 § 4.2]: https://datatracker.ietf.org/doc/html/rfc6282#section-4.2 + pub enum NhcPacket { + ExtHeader, + UdpHeader, + } + + impl NhcPacket { + /// Returns the type of the Next Header header. + /// This can either be an Extenstion header or an 6LoWPAN Udp header. + /// + /// # Errors + /// Returns `[Error::Unrecognized]` when neither the Extension Header dispatch or the Udp + /// dispatch is recognized. + pub fn dispatch(buffer: impl AsRef<[u8]>) -> Result { + let raw = buffer.as_ref(); + if raw.is_empty() { + return Err(Error); + } + + if raw[0] >> 4 == DISPATCH_EXT_HEADER { + // We have a compressed IPv6 Extension Header. + Ok(Self::ExtHeader) + } else if raw[0] >> 3 == DISPATCH_UDP_HEADER { + // We have a compressed UDP header. + Ok(Self::UdpHeader) + } else { + Err(Error) + } + } + } + + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub enum ExtHeaderId { + HopByHopHeader, + RoutingHeader, + FragmentHeader, + DestinationOptionsHeader, + MobilityHeader, + Header, + Reserved, + } + + impl From for IpProtocol { + fn from(val: ExtHeaderId) -> Self { + match val { + ExtHeaderId::HopByHopHeader => Self::HopByHop, + ExtHeaderId::RoutingHeader => Self::Ipv6Route, + ExtHeaderId::FragmentHeader => Self::Ipv6Frag, + ExtHeaderId::DestinationOptionsHeader => Self::Ipv6Opts, + ExtHeaderId::MobilityHeader => Self::Unknown(0), + ExtHeaderId::Header => Self::Unknown(0), + ExtHeaderId::Reserved => Self::Unknown(0), + } + } + } + + /// A read/write wrapper around a 6LoWPAN NHC Extension header. + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct ExtHeaderPacket> { + buffer: T, + } + + impl> ExtHeaderPacket { + /// Input a raw octet buffer with a 6LoWPAN NHC Extension Header structure. + pub const fn new_unchecked(buffer: T) -> Self { + ExtHeaderPacket { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + + if packet.eid_field() > 7 { + return Err(Error); + } + + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + + if buffer.is_empty() { + return Err(Error); + } + + let mut len = 1; + len += self.next_header_size(); + + if len <= buffer.len() { + Ok(()) + } else { + Err(Error) + } + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + get_field!(dispatch_field, 0b1111, 4); + get_field!(eid_field, 0b111, 1); + get_field!(nh_field, 0b1, 0); + + /// Return the Extension Header ID. + pub fn extension_header_id(&self) -> ExtHeaderId { + match self.eid_field() { + 0 => ExtHeaderId::HopByHopHeader, + 1 => ExtHeaderId::RoutingHeader, + 2 => ExtHeaderId::FragmentHeader, + 3 => ExtHeaderId::DestinationOptionsHeader, + 4 => ExtHeaderId::MobilityHeader, + 5 | 6 => ExtHeaderId::Reserved, + 7 => ExtHeaderId::Header, + _ => unreachable!(), + } + } + + /// Parse the next header field. + pub fn next_header(&self) -> NextHeader { + if self.nh_field() == 1 { + NextHeader::Compressed + } else { + // The full 8 bits for Next Header are carried in-line. + let start = 1; + + let data = self.buffer.as_ref(); + let nh = data[start]; + NextHeader::Uncompressed(IpProtocol::from(nh)) + } + } + + /// Return the size of the Next Header field. + fn next_header_size(&self) -> usize { + // If nh is set, then the Next Header is compressed using LOWPAN_NHC + match self.nh_field() { + 0 => 1, + 1 => 0, + _ => unreachable!(), + } + } + } + + impl<'a, T: AsRef<[u8]> + ?Sized> ExtHeaderPacket<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let start = 2 + self.next_header_size(); + &self.buffer.as_ref()[start..] + } + } + + impl + AsMut<[u8]>> ExtHeaderPacket { + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let start = 2 + self.next_header_size(); + &mut self.buffer.as_mut()[start..] + } + + /// Set the dispatch field to `0b1110`. + fn set_dispatch_field(&mut self) { + let data = self.buffer.as_mut(); + data[0] = (data[0] & !(0b1111 << 4)) | (DISPATCH_EXT_HEADER << 4); + } + + set_field!(set_eid_field, 0b111, 1); + set_field!(set_nh_field, 0b1, 0); + + /// Set the Extension Header ID field. + fn set_extension_header_id(&mut self, ext_header_id: ExtHeaderId) { + let id = match ext_header_id { + ExtHeaderId::HopByHopHeader => 0, + ExtHeaderId::RoutingHeader => 1, + ExtHeaderId::FragmentHeader => 2, + ExtHeaderId::DestinationOptionsHeader => 3, + ExtHeaderId::MobilityHeader => 4, + ExtHeaderId::Reserved => 5, + ExtHeaderId::Header => 7, + }; + + self.set_eid_field(id); + } + + /// Set the Next Header. + fn set_next_header(&mut self, next_header: NextHeader) { + match next_header { + NextHeader::Compressed => self.set_nh_field(0b1), + NextHeader::Uncompressed(nh) => { + self.set_nh_field(0b0); + + let start = 1; + let data = self.buffer.as_mut(); + data[start] = nh.into(); + } + } + } + + /// Set the length. + fn set_length(&mut self, length: u8) { + let start = 1 + self.next_header_size(); + + let data = self.buffer.as_mut(); + data[start] = length; + } + } + + /// A high-level representation of an 6LoWPAN NHC Extension header. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct ExtHeaderRepr { + ext_header_id: ExtHeaderId, + next_header: NextHeader, + length: u8, + } + + impl ExtHeaderRepr { + /// Parse a 6LoWPAN NHC Extension Header packet and return a high-level representation. + pub fn parse + ?Sized>(packet: &ExtHeaderPacket<&T>) -> Result { + // Ensure basic accessors will work. + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_EXT_HEADER { + return Err(Error); + } + + Ok(Self { + ext_header_id: packet.extension_header_id(), + next_header: packet.next_header(), + length: packet.payload().len() as u8, + }) + } + + /// Return the length of a header that will be emitted from this high-level representation. + pub fn buffer_len(&self) -> usize { + let mut len = 1; // The minimal header size + + if self.next_header != NextHeader::Compressed { + len += 1; + } + + len += 1; // The length + + len + } + + /// Emit a high-level representaiton into a 6LoWPAN NHC Extension Header packet. + pub fn emit + AsMut<[u8]>>(&self, packet: &mut ExtHeaderPacket) { + packet.set_dispatch_field(); + packet.set_extension_header_id(self.ext_header_id); + packet.set_next_header(self.next_header); + packet.set_length(self.length); + } + } + + #[cfg(test)] + mod tests { + use super::*; + + use crate::wire::{Ipv6RoutingHeader, Ipv6RoutingRepr}; + + #[cfg(feature = "proto-rpl")] + use crate::wire::{ + Ipv6Option, Ipv6OptionRepr, Ipv6OptionsIterator, RplHopByHopRepr, RplInstanceId, + }; + + #[cfg(feature = "proto-rpl")] + const RPL_HOP_BY_HOP_PACKET: [u8; 9] = + [0xe0, 0x3a, 0x06, 0x63, 0x04, 0x00, 0x1e, 0x03, 0x00]; + + const ROUTING_SR_PACKET: [u8; 32] = [ + 0xe3, 0x1e, 0x03, 0x03, 0x99, 0x30, 0x00, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, + 0x05, 0x06, 0x00, 0x06, 0x00, 0x06, 0x00, 0x06, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, + 0x02, 0x00, 0x00, 0x00, + ]; + + #[test] + #[cfg(feature = "proto-rpl")] + fn test_rpl_hop_by_hop_option_deconstruct() { + let header = ExtHeaderPacket::new_checked(&RPL_HOP_BY_HOP_PACKET).unwrap(); + assert_eq!( + header.next_header(), + NextHeader::Uncompressed(IpProtocol::Icmpv6) + ); + assert_eq!(header.extension_header_id(), ExtHeaderId::HopByHopHeader); + + let options = header.payload(); + let mut options = Ipv6OptionsIterator::new(options); + let rpl_repr = options.next().unwrap(); + let rpl_repr = rpl_repr.unwrap(); + + match rpl_repr { + Ipv6OptionRepr::Rpl(rpl) => { + assert_eq!( + rpl, + RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: RplInstanceId::from(0x1e), + sender_rank: 0x0300, + } + ); + } + _ => unreachable!(), + } + } + + #[test] + #[cfg(feature = "proto-rpl")] + fn test_rpl_hop_by_hop_option_emit() { + let repr = Ipv6OptionRepr::Rpl(RplHopByHopRepr { + down: false, + rank_error: false, + forwarding_error: false, + instance_id: RplInstanceId::from(0x1e), + sender_rank: 0x0300, + }); + + let ext_hdr = ExtHeaderRepr { + ext_header_id: ExtHeaderId::HopByHopHeader, + next_header: NextHeader::Uncompressed(IpProtocol::Icmpv6), + length: repr.buffer_len() as u8, + }; + + let mut buffer = vec![0u8; ext_hdr.buffer_len() + repr.buffer_len()]; + ext_hdr.emit(&mut ExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + repr.emit(&mut Ipv6Option::new_unchecked( + &mut buffer[ext_hdr.buffer_len()..], + )); + + assert_eq!(&buffer[..], RPL_HOP_BY_HOP_PACKET); + } + + #[test] + fn test_source_routing_deconstruct() { + let header = ExtHeaderPacket::new_checked(&ROUTING_SR_PACKET).unwrap(); + assert_eq!(header.next_header(), NextHeader::Compressed); + assert_eq!(header.extension_header_id(), ExtHeaderId::RoutingHeader); + + let routing_hdr = Ipv6RoutingHeader::new_checked(header.payload()).unwrap(); + let repr = Ipv6RoutingRepr::parse(&routing_hdr).unwrap(); + assert_eq!( + repr, + Ipv6RoutingRepr::Rpl { + segments_left: 3, + cmpr_i: 9, + cmpr_e: 9, + pad: 3, + addresses: &[ + 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x06, 0x00, 0x06, 0x00, 0x06, + 0x00, 0x06, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00 + ], + } + ); + } + + #[test] + fn test_source_routing_emit() { + let routing_hdr = Ipv6RoutingRepr::Rpl { + segments_left: 3, + cmpr_i: 9, + cmpr_e: 9, + pad: 3, + addresses: &[ + 0x05, 0x00, 0x05, 0x00, 0x05, 0x00, 0x05, 0x06, 0x00, 0x06, 0x00, 0x06, 0x00, + 0x06, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x02, 0x00, 0x00, 0x00, + ], + }; + + let ext_hdr = ExtHeaderRepr { + ext_header_id: ExtHeaderId::RoutingHeader, + next_header: NextHeader::Compressed, + length: routing_hdr.buffer_len() as u8, + }; + + let mut buffer = vec![0u8; ext_hdr.buffer_len() + routing_hdr.buffer_len()]; + ext_hdr.emit(&mut ExtHeaderPacket::new_unchecked( + &mut buffer[..ext_hdr.buffer_len()], + )); + routing_hdr.emit(&mut Ipv6RoutingHeader::new_unchecked( + &mut buffer[ext_hdr.buffer_len()..], + )); + + assert_eq!(&buffer[..], ROUTING_SR_PACKET); + } + } + + /// A read/write wrapper around a 6LoWPAN_NHC UDP frame. + /// [RFC 6282 § 4.3] specifies the format of the header. + /// + /// The base header has the following formath: + /// ```txt + /// 0 1 2 3 4 5 6 7 + /// +---+---+---+---+---+---+---+---+ + /// | 1 | 1 | 1 | 1 | 0 | C | P | + /// +---+---+---+---+---+---+---+---+ + /// With: + /// - C: checksum, specifies if the checksum is elided. + /// - P: ports, specifies if the ports are elided. + /// ``` + /// + /// [RFC 6282 § 4.3]: https://datatracker.ietf.org/doc/html/rfc6282#section-4.3 + #[derive(Debug, Clone)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct UdpNhcPacket> { + buffer: T, + } + + impl> UdpNhcPacket { + /// Input a raw octet buffer with a LOWPAN_NHC frame structure for UDP. + pub const fn new_unchecked(buffer: T) -> Self { + Self { buffer } + } + + /// Shorthand for a combination of [new_unchecked] and [check_len]. + /// + /// [new_unchecked]: #method.new_unchecked + /// [check_len]: #method.check_len + pub fn new_checked(buffer: T) -> Result { + let packet = Self::new_unchecked(buffer); + packet.check_len()?; + Ok(packet) + } + + /// Ensure that no accessor method will panic if called. + /// Returns `Err(Error::Truncated)` if the buffer is too short. + pub fn check_len(&self) -> Result<()> { + let buffer = self.buffer.as_ref(); + + if buffer.is_empty() { + return Err(Error); + } + + let index = 1 + self.ports_size() + self.checksum_size(); + if index > buffer.len() { + return Err(Error); + } + + Ok(()) + } + + /// Consumes the frame, returning the underlying buffer. + pub fn into_inner(self) -> T { + self.buffer + } + + get_field!(dispatch_field, 0b11111, 3); + get_field!(checksum_field, 0b1, 2); + get_field!(ports_field, 0b11, 0); + + /// Returns the index of the start of the next header compressed fields. + const fn nhc_fields_start(&self) -> usize { + 1 + } + + /// Return the source port number. + pub fn src_port(&self) -> u16 { + match self.ports_field() { + 0b00 | 0b01 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[start..start + 2]) + } + 0b10 => { + // The first 8 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf000 + data[start] as u16 + } + 0b11 => { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf0b0 + (data[start] >> 4) as u16 + } + _ => unreachable!(), + } + } + + /// Return the destination port number. + pub fn dst_port(&self) -> u16 { + match self.ports_field() { + 0b00 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[idx + 2..idx + 4]) + } + 0b01 => { + // The first 8 bits are elided. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + 0xf000 + data[idx] as u16 + } + 0b10 => { + // The full 16 bits are carried in-line. + let data = self.buffer.as_ref(); + let idx = self.nhc_fields_start(); + + NetworkEndian::read_u16(&data[idx + 1..idx + 1 + 2]) + } + 0b11 => { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start(); + + 0xf0b0 + (data[start] & 0xff) as u16 + } + _ => unreachable!(), + } + } + + /// Return the checksum. + pub fn checksum(&self) -> Option { + if self.checksum_field() == 0b0 { + // The first 12 bits are elided. + let data = self.buffer.as_ref(); + let start = self.nhc_fields_start() + self.ports_size(); + Some(NetworkEndian::read_u16(&data[start..start + 2])) + } else { + // The checksum is elided and needs to be recomputed on the 6LoWPAN termination point. + None + } + } + + // Return the size of the checksum field. + pub(crate) fn checksum_size(&self) -> usize { + match self.checksum_field() { + 0b0 => 2, + 0b1 => 0, + _ => unreachable!(), + } + } + + /// Returns the total size of both port numbers. + pub(crate) fn ports_size(&self) -> usize { + match self.ports_field() { + 0b00 => 4, // 16 bits + 16 bits + 0b01 => 3, // 16 bits + 8 bits + 0b10 => 3, // 8 bits + 16 bits + 0b11 => 1, // 4 bits + 4 bits + _ => unreachable!(), + } + } + } + + impl<'a, T: AsRef<[u8]> + ?Sized> UdpNhcPacket<&'a T> { + /// Return a pointer to the payload. + pub fn payload(&self) -> &'a [u8] { + let start = 1 + self.ports_size() + self.checksum_size(); + &self.buffer.as_ref()[start..] + } + } + + impl + AsMut<[u8]>> UdpNhcPacket { + /// Return a mutable pointer to the payload. + pub fn payload_mut(&mut self) -> &mut [u8] { + let start = 1 + self.ports_size() + 2; // XXX(thvdveld): we assume we put the checksum inlined. + &mut self.buffer.as_mut()[start..] + } + + /// Set the dispatch field to `0b11110`. + fn set_dispatch_field(&mut self) { + let data = self.buffer.as_mut(); + data[0] = (data[0] & !(0b11111 << 3)) | (DISPATCH_UDP_HEADER << 3); + } + + set_field!(set_checksum_field, 0b1, 2); + set_field!(set_ports_field, 0b11, 0); + + fn set_ports(&mut self, src_port: u16, dst_port: u16) { + let mut idx = 1; + + match (src_port, dst_port) { + (0xf0b0..=0xf0bf, 0xf0b0..=0xf0bf) => { + // We can compress both the source and destination ports. + self.set_ports_field(0b11); + let data = self.buffer.as_mut(); + data[idx] = (((src_port - 0xf0b0) as u8) << 4) & ((dst_port - 0xf0b0) as u8); + } + (0xf000..=0xf0ff, _) => { + // We can compress the source port, but not the destination port. + self.set_ports_field(0b10); + let data = self.buffer.as_mut(); + data[idx] = (src_port - 0xf000) as u8; + idx += 1; + + NetworkEndian::write_u16(&mut data[idx..idx + 2], dst_port); + } + (_, 0xf000..=0xf0ff) => { + // We can compress the destination port, but not the source port. + self.set_ports_field(0b01); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], src_port); + idx += 2; + data[idx] = (dst_port - 0xf000) as u8; + } + (_, _) => { + // We cannot compress any port. + self.set_ports_field(0b00); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], src_port); + idx += 2; + NetworkEndian::write_u16(&mut data[idx..idx + 2], dst_port); + } + }; + } + + fn set_checksum(&mut self, checksum: u16) { + self.set_checksum_field(0b0); + let idx = 1 + self.ports_size(); + let data = self.buffer.as_mut(); + NetworkEndian::write_u16(&mut data[idx..idx + 2], checksum); + } + } + + /// A high-level representation of a 6LoWPAN NHC UDP header. + #[derive(Debug, PartialEq, Eq, Clone, Copy)] + #[cfg_attr(feature = "defmt", derive(defmt::Format))] + pub struct UdpNhcRepr(pub UdpRepr); + + impl<'a> UdpNhcRepr { + /// Parse a 6LoWPAN NHC UDP packet and return a high-level representation. + pub fn parse + ?Sized>( + packet: &UdpNhcPacket<&'a T>, + src_addr: &ipv6::Address, + dst_addr: &ipv6::Address, + checksum_caps: &ChecksumCapabilities, + ) -> Result { + packet.check_len()?; + + if packet.dispatch_field() != DISPATCH_UDP_HEADER { + return Err(Error); + } + + if checksum_caps.udp.rx() { + let payload_len = packet.payload().len(); + let chk_sum = !checksum::combine(&[ + checksum::pseudo_header( + &IpAddress::Ipv6(*src_addr), + &IpAddress::Ipv6(*dst_addr), + crate::wire::ip::Protocol::Udp, + payload_len as u32 + 8, + ), + packet.src_port(), + packet.dst_port(), + payload_len as u16 + 8, + checksum::data(packet.payload()), + ]); + + if let Some(checksum) = packet.checksum() { + if chk_sum != checksum { + return Err(Error); + } + } + } + + Ok(Self(UdpRepr { + src_port: packet.src_port(), + dst_port: packet.dst_port(), + })) + } + + /// Return the length of a packet that will be emitted from this high-level representation. + pub fn header_len(&self) -> usize { + let mut len = 1; // The minimal header size + + len += 2; // XXX We assume we will add the checksum at the end + + // Check if we can compress the source and destination ports + match (self.src_port, self.dst_port) { + (0xf0b0..=0xf0bf, 0xf0b0..=0xf0bf) => len + 1, + (0xf000..=0xf0ff, _) | (_, 0xf000..=0xf0ff) => len + 3, + (_, _) => len + 4, + } + } + + /// Emit a high-level representation into a LOWPAN_NHC UDP header. + pub fn emit + AsMut<[u8]>>( + &self, + packet: &mut UdpNhcPacket, + src_addr: &Address, + dst_addr: &Address, + payload_len: usize, + emit_payload: impl FnOnce(&mut [u8]), + ) { + packet.set_dispatch_field(); + packet.set_ports(self.src_port, self.dst_port); + emit_payload(packet.payload_mut()); + + let chk_sum = !checksum::combine(&[ + checksum::pseudo_header( + &IpAddress::Ipv6(*src_addr), + &IpAddress::Ipv6(*dst_addr), + crate::wire::ip::Protocol::Udp, + payload_len as u32 + 8, + ), + self.src_port, + self.dst_port, + payload_len as u16 + 8, + checksum::data(packet.payload_mut()), + ]); + + packet.set_checksum(chk_sum); + } + } + + impl core::ops::Deref for UdpNhcRepr { + type Target = UdpRepr; + + fn deref(&self) -> &Self::Target { + &self.0 + } + } + + impl core::ops::DerefMut for UdpNhcRepr { + fn deref_mut(&mut self) -> &mut Self::Target { + &mut self.0 + } + } + + #[cfg(test)] + mod test { + use super::*; + + #[test] + fn ext_header_nhc_fields() { + let bytes = [0xe3, 0x06, 0x03, 0x00, 0xff, 0x00, 0x00, 0x00]; + + let packet = ExtHeaderPacket::new_checked(&bytes[..]).unwrap(); + assert_eq!(packet.next_header_size(), 0); + assert_eq!(packet.dispatch_field(), DISPATCH_EXT_HEADER); + assert_eq!(packet.extension_header_id(), ExtHeaderId::RoutingHeader); + + assert_eq!(packet.payload(), [0x03, 0x00, 0xff, 0x00, 0x00, 0x00]); + } + + #[test] + fn ext_header_emit() { + let ext_header = ExtHeaderRepr { + ext_header_id: ExtHeaderId::RoutingHeader, + next_header: NextHeader::Compressed, + length: 6, + }; + + let len = ext_header.buffer_len(); + let mut buffer = [0u8; 127]; + let mut packet = ExtHeaderPacket::new_unchecked(&mut buffer[..len]); + ext_header.emit(&mut packet); + + assert_eq!(packet.dispatch_field(), DISPATCH_EXT_HEADER); + assert_eq!(packet.next_header(), NextHeader::Compressed); + assert_eq!(packet.extension_header_id(), ExtHeaderId::RoutingHeader); + } + + #[test] + fn udp_nhc_fields() { + let bytes = [0xf0, 0x16, 0x2e, 0x22, 0x3d, 0x28, 0xc4]; + + let packet = UdpNhcPacket::new_checked(&bytes[..]).unwrap(); + assert_eq!(packet.dispatch_field(), DISPATCH_UDP_HEADER); + assert_eq!(packet.checksum(), Some(0x28c4)); + assert_eq!(packet.src_port(), 5678); + assert_eq!(packet.dst_port(), 8765); + } + + #[test] + fn udp_emit() { + let udp = UdpNhcRepr(UdpRepr { + src_port: 0xf0b1, + dst_port: 0xf001, + }); + + let payload = b"Hello World!"; + + let src_addr = ipv6::Address::default(); + let dst_addr = ipv6::Address::default(); + + let len = udp.header_len() + payload.len(); + let mut buffer = [0u8; 127]; + let mut packet = UdpNhcPacket::new_unchecked(&mut buffer[..len]); + udp.emit(&mut packet, &src_addr, &dst_addr, payload.len(), |buf| { + buf.copy_from_slice(&payload[..]) + }); + + assert_eq!(packet.dispatch_field(), DISPATCH_UDP_HEADER); + assert_eq!(packet.src_port(), 0xf0b1); + assert_eq!(packet.dst_port(), 0xf001); + assert_eq!(packet.payload_mut(), b"Hello World!"); + } + } +} + +#[cfg(test)] +mod test { + use super::*; + + #[test] + fn sixlowpan_fragment_emit() { + let repr = frag::Repr::FirstFragment { + size: 0xff, + tag: 0xabcd, + }; + let buffer = [0u8; 4]; + let mut packet = frag::Packet::new_unchecked(buffer); + + assert_eq!(repr.buffer_len(), 4); + repr.emit(&mut packet); + + assert_eq!(packet.datagram_size(), 0xff); + assert_eq!(packet.datagram_tag(), 0xabcd); + assert_eq!(packet.into_inner(), [0xc0, 0xff, 0xab, 0xcd]); + + let repr = frag::Repr::Fragment { + size: 0xff, + tag: 0xabcd, + offset: 0xcc, + }; + let buffer = [0u8; 5]; + let mut packet = frag::Packet::new_unchecked(buffer); + + assert_eq!(repr.buffer_len(), 5); + repr.emit(&mut packet); + + assert_eq!(packet.datagram_size(), 0xff); + assert_eq!(packet.datagram_tag(), 0xabcd); + assert_eq!(packet.into_inner(), [0xe0, 0xff, 0xab, 0xcd, 0xcc]); + } + + #[test] + fn sixlowpan_three_fragments() { + use crate::wire::ieee802154::Frame as Ieee802154Frame; + use crate::wire::ieee802154::Repr as Ieee802154Repr; + use crate::wire::Ieee802154Address; + + let key = frag::Key { + ll_src_addr: Ieee802154Address::Extended([50, 147, 130, 47, 40, 8, 62, 217]), + ll_dst_addr: Ieee802154Address::Extended([26, 11, 66, 66, 66, 66, 66, 66]), + datagram_size: 307, + datagram_tag: 63, + }; + + let frame1: &[u8] = &[ + 0x41, 0xcc, 0x92, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xc1, 0x33, 0x00, 0x3f, 0x6e, 0x33, 0x02, + 0x35, 0x3d, 0xf0, 0xd2, 0x5f, 0x1b, 0x39, 0xb4, 0x6b, 0x4c, 0x6f, 0x72, 0x65, 0x6d, + 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, 0x6c, 0x6f, 0x72, 0x20, 0x73, + 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, 0x6f, 0x6e, 0x73, 0x65, + 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, 0x69, 0x73, 0x63, + 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, 0x2e, 0x20, 0x41, 0x6c, 0x69, 0x71, + 0x75, 0x61, 0x6d, 0x20, 0x64, 0x75, 0x69, 0x20, 0x6f, 0x64, 0x69, 0x6f, 0x2c, 0x20, + 0x69, 0x61, 0x63, 0x75, 0x6c, 0x69, 0x73, 0x20, 0x76, 0x65, 0x6c, 0x20, 0x72, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame1).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 0); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + + let frame2: &[u8] = &[ + 0x41, 0xcc, 0x93, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xe1, 0x33, 0x00, 0x3f, 0x11, 0x75, 0x74, + 0x72, 0x75, 0x6d, 0x20, 0x61, 0x74, 0x2c, 0x20, 0x74, 0x72, 0x69, 0x73, 0x74, 0x69, + 0x71, 0x75, 0x65, 0x20, 0x6e, 0x6f, 0x6e, 0x20, 0x6e, 0x75, 0x6e, 0x63, 0x20, 0x65, + 0x72, 0x61, 0x74, 0x20, 0x63, 0x75, 0x72, 0x61, 0x65, 0x2e, 0x20, 0x4c, 0x6f, 0x72, + 0x65, 0x6d, 0x20, 0x69, 0x70, 0x73, 0x75, 0x6d, 0x20, 0x64, 0x6f, 0x6c, 0x6f, 0x72, + 0x20, 0x73, 0x69, 0x74, 0x20, 0x61, 0x6d, 0x65, 0x74, 0x2c, 0x20, 0x63, 0x6f, 0x6e, + 0x73, 0x65, 0x63, 0x74, 0x65, 0x74, 0x75, 0x72, 0x20, 0x61, 0x64, 0x69, 0x70, 0x69, + 0x73, 0x63, 0x69, 0x6e, 0x67, 0x20, 0x65, 0x6c, 0x69, 0x74, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame2).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 136 / 8); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + + let frame3: &[u8] = &[ + 0x41, 0xcc, 0x94, 0xef, 0xbe, 0x42, 0x42, 0x42, 0x42, 0x42, 0x42, 0x0b, 0x1a, 0xd9, + 0x3e, 0x08, 0x28, 0x2f, 0x82, 0x93, 0x32, 0xe1, 0x33, 0x00, 0x3f, 0x1d, 0x2e, 0x20, + 0x41, 0x6c, 0x69, 0x71, 0x75, 0x61, 0x6d, 0x20, 0x64, 0x75, 0x69, 0x20, 0x6f, 0x64, + 0x69, 0x6f, 0x2c, 0x20, 0x69, 0x61, 0x63, 0x75, 0x6c, 0x69, 0x73, 0x20, 0x76, 0x65, + 0x6c, 0x20, 0x72, 0x75, 0x74, 0x72, 0x75, 0x6d, 0x20, 0x61, 0x74, 0x2c, 0x20, 0x74, + 0x72, 0x69, 0x73, 0x74, 0x69, 0x71, 0x75, 0x65, 0x20, 0x6e, 0x6f, 0x6e, 0x20, 0x6e, + 0x75, 0x6e, 0x63, 0x20, 0x65, 0x72, 0x61, 0x74, 0x20, 0x63, 0x75, 0x72, 0x61, 0x65, + 0x2e, 0x20, 0x0a, + ]; + + let ieee802154_frame = Ieee802154Frame::new_checked(frame3).unwrap(); + let ieee802154_repr = Ieee802154Repr::parse(&ieee802154_frame).unwrap(); + + let sixlowpan_frame = + SixlowpanPacket::dispatch(ieee802154_frame.payload().unwrap()).unwrap(); + + let frag = if let SixlowpanPacket::FragmentHeader = sixlowpan_frame { + frag::Packet::new_checked(ieee802154_frame.payload().unwrap()).unwrap() + } else { + unreachable!() + }; + + assert_eq!(frag.datagram_size(), 307); + assert_eq!(frag.datagram_tag(), 0x003f); + assert_eq!(frag.datagram_offset(), 232 / 8); + + assert_eq!(frag.get_key(&ieee802154_repr), key); + } +} diff --git a/src/wire/tcp.rs b/src/wire/tcp.rs index 8242b366d..bcc6fbc59 100644 --- a/src/wire/tcp.rs +++ b/src/wire/tcp.rs @@ -1,10 +1,10 @@ -use core::{i32, ops, cmp, fmt}; use byteorder::{ByteOrder, NetworkEndian}; +use core::{cmp, fmt, i32, ops}; -use {Error, Result}; -use phy::ChecksumCapabilities; -use super::{IpProtocol, IpAddress}; -use super::ip::checksum; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{IpAddress, IpProtocol}; /// A TCP sequence number. /// @@ -19,6 +19,13 @@ impl fmt::Display for SeqNumber { } } +#[cfg(feature = "defmt")] +impl defmt::Format for SeqNumber { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "{}", self.0 as u32); + } +} + impl ops::Add for SeqNumber { type Output = SeqNumber; @@ -66,26 +73,27 @@ impl cmp::PartialOrd for SeqNumber { } /// A read/write wrapper around a Transmission Control Protocol packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub struct Packet> { - buffer: T + buffer: T, } mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; pub const SRC_PORT: Field = 0..2; pub const DST_PORT: Field = 2..4; - pub const SEQ_NUM: Field = 4..8; - pub const ACK_NUM: Field = 8..12; - pub const FLAGS: Field = 12..14; + pub const SEQ_NUM: Field = 4..8; + pub const ACK_NUM: Field = 8..12; + pub const FLAGS: Field = 12..14; pub const WIN_SIZE: Field = 14..16; pub const CHECKSUM: Field = 16..18; - pub const URGENT: Field = 18..20; + pub const URGENT: Field = 18..20; - pub fn OPTIONS(length: u8) -> Field { + pub const fn OPTIONS(length: u8) -> Field { URGENT.end..(length as usize) } @@ -97,19 +105,21 @@ mod field { pub const FLG_URG: u16 = 0x020; pub const FLG_ECE: u16 = 0x040; pub const FLG_CWR: u16 = 0x080; - pub const FLG_NS: u16 = 0x100; + pub const FLG_NS: u16 = 0x100; pub const OPT_END: u8 = 0x00; pub const OPT_NOP: u8 = 0x01; pub const OPT_MSS: u8 = 0x02; - pub const OPT_WS: u8 = 0x03; + pub const OPT_WS: u8 = 0x03; pub const OPT_SACKPERM: u8 = 0x04; - pub const OPT_SACKRNG: u8 = 0x05; + pub const OPT_SACKRNG: u8 = 0x05; } +pub const HEADER_LEN: usize = field::URGENT.end; + impl> Packet { /// Imbue a raw octet buffer with TCP packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -124,8 +134,8 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. - /// Returns `Err(Error::Malformed)` if the header length field has a value smaller + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the header length field has a value smaller /// than the minimal header length. /// /// The result of this check is invalidated by calling [set_header_len]. @@ -134,13 +144,11 @@ impl> Packet { pub fn check_len(&self) -> Result<()> { let len = self.buffer.as_ref().len(); if len < field::URGENT.end { - Err(Error::Truncated) + Err(Error) } else { let header_len = self.header_len() as usize; - if len < header_len { - Err(Error::Truncated) - } else if header_len < field::URGENT.end { - Err(Error::Malformed) + if len < header_len || header_len < field::URGENT.end { + Err(Error) } else { Ok(()) } @@ -285,8 +293,12 @@ impl> Packet { pub fn segment_len(&self) -> usize { let data = self.buffer.as_ref(); let mut length = data.len() - self.header_len() as usize; - if self.syn() { length += 1 } - if self.fin() { length += 1 } + if self.syn() { + length += 1 + } + if self.fin() { + length += 1 + } length } @@ -294,13 +306,10 @@ impl> Packet { pub fn selective_ack_permitted(&self) -> Result { let data = self.buffer.as_ref(); let mut options = &data[field::OPTIONS(self.header_len())]; - while options.len() > 0 { + while !options.is_empty() { let (next_options, option) = TcpOption::parse(options)?; - match option { - TcpOption::SackPermitted => { - return Ok(true); - }, - _ => {}, + if option == TcpOption::SackPermitted { + return Ok(true); } options = next_options; } @@ -310,18 +319,13 @@ impl> Packet { /// Return the selective acknowledgement ranges, if any. If there are none in the packet, an /// array of ``None`` values will be returned. /// - pub fn selective_ack_ranges<'s>( - &'s self - ) -> Result<[Option<(u32, u32)>; 3]> { + pub fn selective_ack_ranges(&self) -> Result<[Option<(u32, u32)>; 3]> { let data = self.buffer.as_ref(); let mut options = &data[field::OPTIONS(self.header_len())]; - while options.len() > 0 { + while !options.is_empty() { let (next_options, option) = TcpOption::parse(options)?; - match option { - TcpOption::SackRange(slice) => { - return Ok(slice); - }, - _ => {}, + if let TcpOption::SackRange(slice) = option { + return Ok(slice); } options = next_options; } @@ -337,13 +341,14 @@ impl> Packet { /// # Fuzzing /// This function always returns `true` when fuzzing. pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { - if cfg!(fuzzing) { return true } + if cfg!(fuzzing) { + return true; + } let data = self.buffer.as_ref(); checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, - data.len() as u32), - checksum::data(data) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, data.len() as u32), + checksum::data(data), ]) == !0 } } @@ -409,7 +414,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_fin(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_FIN } else { raw & !field::FLG_FIN }; + let raw = if value { + raw | field::FLG_FIN + } else { + raw & !field::FLG_FIN + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -418,7 +427,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_syn(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_SYN } else { raw & !field::FLG_SYN }; + let raw = if value { + raw | field::FLG_SYN + } else { + raw & !field::FLG_SYN + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -427,7 +440,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_rst(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_RST } else { raw & !field::FLG_RST }; + let raw = if value { + raw | field::FLG_RST + } else { + raw & !field::FLG_RST + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -436,7 +453,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_psh(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_PSH } else { raw & !field::FLG_PSH }; + let raw = if value { + raw | field::FLG_PSH + } else { + raw & !field::FLG_PSH + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -445,7 +466,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_ack(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_ACK } else { raw & !field::FLG_ACK }; + let raw = if value { + raw | field::FLG_ACK + } else { + raw & !field::FLG_ACK + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -454,7 +479,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_urg(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_URG } else { raw & !field::FLG_URG }; + let raw = if value { + raw | field::FLG_URG + } else { + raw & !field::FLG_URG + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -463,7 +492,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_ece(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_ECE } else { raw & !field::FLG_ECE }; + let raw = if value { + raw | field::FLG_ECE + } else { + raw & !field::FLG_ECE + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -472,7 +505,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_cwr(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_CWR } else { raw & !field::FLG_CWR }; + let raw = if value { + raw | field::FLG_CWR + } else { + raw & !field::FLG_CWR + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -481,7 +518,11 @@ impl + AsMut<[u8]>> Packet { pub fn set_ns(&mut self, value: bool) { let data = self.buffer.as_mut(); let raw = NetworkEndian::read_u16(&data[field::FLAGS]); - let raw = if value { raw | field::FLG_NS } else { raw & !field::FLG_NS }; + let raw = if value { + raw | field::FLG_NS + } else { + raw & !field::FLG_NS + }; NetworkEndian::write_u16(&mut data[field::FLAGS], raw) } @@ -525,9 +566,8 @@ impl + AsMut<[u8]>> Packet { let checksum = { let data = self.buffer.as_ref(); !checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, - data.len() as u32), - checksum::data(data) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Tcp, data.len() as u32), + checksum::data(data), ]) }; self.set_checksum(checksum) @@ -558,6 +598,7 @@ impl> AsRef<[u8]> for Packet { /// A representation of a single TCP option. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum TcpOption<'a> { EndOfList, NoOperation, @@ -565,13 +606,13 @@ pub enum TcpOption<'a> { WindowScale(u8), SackPermitted, SackRange([Option<(u32, u32)>; 3]), - Unknown { kind: u8, data: &'a [u8] } + Unknown { kind: u8, data: &'a [u8] }, } impl<'a> TcpOption<'a> { pub fn parse(buffer: &'a [u8]) -> Result<(&'a [u8], TcpOption<'a>)> { let (length, option); - match *buffer.get(0).ok_or(Error::Truncated)? { + match *buffer.first().ok_or(Error)? { field::OPT_END => { length = 1; option = TcpOption::EndOfList; @@ -581,27 +622,21 @@ impl<'a> TcpOption<'a> { option = TcpOption::NoOperation; } kind => { - length = *buffer.get(1).ok_or(Error::Truncated)? as usize; - let data = buffer.get(2..length).ok_or(Error::Truncated)?; + length = *buffer.get(1).ok_or(Error)? as usize; + let data = buffer.get(2..length).ok_or(Error)?; match (kind, length) { - (field::OPT_END, _) | - (field::OPT_NOP, _) => - unreachable!(), - (field::OPT_MSS, 4) => - option = TcpOption::MaxSegmentSize(NetworkEndian::read_u16(data)), - (field::OPT_MSS, _) => - return Err(Error::Malformed), - (field::OPT_WS, 3) => - option = TcpOption::WindowScale(data[0]), - (field::OPT_WS, _) => - return Err(Error::Malformed), - (field::OPT_SACKPERM, 2) => - option = TcpOption::SackPermitted, - (field::OPT_SACKPERM, _) => - return Err(Error::Malformed), + (field::OPT_END, _) | (field::OPT_NOP, _) => unreachable!(), + (field::OPT_MSS, 4) => { + option = TcpOption::MaxSegmentSize(NetworkEndian::read_u16(data)) + } + (field::OPT_MSS, _) => return Err(Error), + (field::OPT_WS, 3) => option = TcpOption::WindowScale(data[0]), + (field::OPT_WS, _) => return Err(Error), + (field::OPT_SACKPERM, 2) => option = TcpOption::SackPermitted, + (field::OPT_SACKPERM, _) => return Err(Error), (field::OPT_SACKRNG, n) => { - if n < 10 || (n-2) % 8 != 0 { - return Err(Error::Malformed) + if n < 10 || (n - 2) % 8 != 0 { + return Err(Error); } if n > 26 { // It's possible for a remote to send 4 SACK blocks, but extremely rare. @@ -625,19 +660,16 @@ impl<'a> TcpOption<'a> { *nmut = if left < data.len() { let mid = left + 4; let right = mid + 4; - let range_left = NetworkEndian::read_u32( - &data[left..mid]); - let range_right = NetworkEndian::read_u32( - &data[mid..right]); + let range_left = NetworkEndian::read_u32(&data[left..mid]); + let range_right = NetworkEndian::read_u32(&data[mid..right]); Some((range_left, range_right)) } else { None }; }); option = TcpOption::SackRange(sack_ranges); - }, - (_, _) => - option = TcpOption::Unknown { kind: kind, data: data } + } + (_, _) => option = TcpOption::Unknown { kind, data }, } } } @@ -645,38 +677,36 @@ impl<'a> TcpOption<'a> { } pub fn buffer_len(&self) -> usize { - match self { - &TcpOption::EndOfList => 1, - &TcpOption::NoOperation => 1, - &TcpOption::MaxSegmentSize(_) => 4, - &TcpOption::WindowScale(_) => 3, - &TcpOption::SackPermitted => 2, - &TcpOption::SackRange(s) => s.iter().filter(|s| s.is_some()).count() * 8 + 2, - &TcpOption::Unknown { data, .. } => 2 + data.len() + match *self { + TcpOption::EndOfList => 1, + TcpOption::NoOperation => 1, + TcpOption::MaxSegmentSize(_) => 4, + TcpOption::WindowScale(_) => 3, + TcpOption::SackPermitted => 2, + TcpOption::SackRange(s) => s.iter().filter(|s| s.is_some()).count() * 8 + 2, + TcpOption::Unknown { data, .. } => 2 + data.len(), } } pub fn emit<'b>(&self, buffer: &'b mut [u8]) -> &'b mut [u8] { let length; - match self { - &TcpOption::EndOfList => { - length = 1; + match *self { + TcpOption::EndOfList => { + length = 1; // There may be padding space which also should be initialized. for p in buffer.iter_mut() { *p = field::OPT_END; } } - &TcpOption::NoOperation => { - length = 1; + TcpOption::NoOperation => { + length = 1; buffer[0] = field::OPT_NOP; } _ => { - length = self.buffer_len(); + length = self.buffer_len(); buffer[1] = length as u8; match self { - &TcpOption::EndOfList | - &TcpOption::NoOperation => - unreachable!(), + &TcpOption::EndOfList | &TcpOption::NoOperation => unreachable!(), &TcpOption::MaxSegmentSize(value) => { buffer[0] = field::OPT_MSS; NetworkEndian::write_u16(&mut buffer[2..], value) @@ -690,14 +720,21 @@ impl<'a> TcpOption<'a> { } &TcpOption::SackRange(slice) => { buffer[0] = field::OPT_SACKRNG; - slice.iter().filter(|s| s.is_some()).enumerate().for_each(|(i, s)| { - let (first, second) = *s.as_ref().unwrap(); - let pos = i * 8 + 2; - NetworkEndian::write_u32(&mut buffer[pos..], first); - NetworkEndian::write_u32(&mut buffer[pos+4..], second); - }); + slice + .iter() + .filter(|s| s.is_some()) + .enumerate() + .for_each(|(i, s)| { + let (first, second) = *s.as_ref().unwrap(); + let pos = i * 8 + 2; + NetworkEndian::write_u32(&mut buffer[pos..], first); + NetworkEndian::write_u32(&mut buffer[pos + 4..], second); + }); } - &TcpOption::Unknown { kind, data: provided } => { + &TcpOption::Unknown { + kind, + data: provided, + } => { buffer[0] = kind; buffer[2..].copy_from_slice(provided) } @@ -710,28 +747,30 @@ impl<'a> TcpOption<'a> { /// The possible control flags of a Transmission Control Protocol packet. #[derive(Debug, PartialEq, Eq, Clone, Copy)] +#[cfg_attr(feature = "defmt", derive(defmt::Format))] pub enum Control { None, Psh, Syn, Fin, - Rst + Rst, } +#[allow(clippy::len_without_is_empty)] impl Control { /// Return the length of a control flag, in terms of sequence space. - pub fn len(self) -> usize { + pub const fn len(self) -> usize { match self { - Control::Syn | Control::Fin => 1, - _ => 0 + Control::Syn | Control::Fin => 1, + _ => 0, } } /// Turn the PSH flag into no flag, and keep the rest as-is. - pub fn quash_psh(self) -> Control { + pub const fn quash_psh(self) -> Control { match self { Control::Psh => Control::None, - _ => self + _ => self, } } } @@ -739,46 +778,54 @@ impl Control { /// A high-level representation of a Transmission Control Protocol packet. #[derive(Debug, PartialEq, Eq, Clone, Copy)] pub struct Repr<'a> { - pub src_port: u16, - pub dst_port: u16, - pub control: Control, - pub seq_number: SeqNumber, - pub ack_number: Option, - pub window_len: u16, + pub src_port: u16, + pub dst_port: u16, + pub control: Control, + pub seq_number: SeqNumber, + pub ack_number: Option, + pub window_len: u16, pub window_scale: Option, pub max_seg_size: Option, pub sack_permitted: bool, - pub sack_ranges: [Option<(u32, u32)>; 3], - pub payload: &'a [u8] + pub sack_ranges: [Option<(u32, u32)>; 3], + pub payload: &'a [u8], } impl<'a> Repr<'a> { /// Parse a Transmission Control Protocol packet and return a high-level representation. - pub fn parse(packet: &Packet<&'a T>, src_addr: &IpAddress, dst_addr: &IpAddress, - checksum_caps: &ChecksumCapabilities) -> Result> - where T: AsRef<[u8]> + ?Sized { + pub fn parse( + packet: &Packet<&'a T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) -> Result> + where + T: AsRef<[u8]> + ?Sized, + { // Source and destination ports must be present. - if packet.src_port() == 0 { return Err(Error::Malformed) } - if packet.dst_port() == 0 { return Err(Error::Malformed) } + if packet.src_port() == 0 { + return Err(Error); + } + if packet.dst_port() == 0 { + return Err(Error); + } // Valid checksum is expected. if checksum_caps.tcp.rx() && !packet.verify_checksum(src_addr, dst_addr) { - return Err(Error::Checksum) + return Err(Error); } - let control = - match (packet.syn(), packet.fin(), packet.rst(), packet.psh()) { - (false, false, false, false) => Control::None, - (false, false, false, true) => Control::Psh, - (true, false, false, _) => Control::Syn, - (false, true, false, _) => Control::Fin, - (false, false, true , _) => Control::Rst, - _ => return Err(Error::Malformed) - }; - let ack_number = - match packet.ack() { - true => Some(packet.ack_number()), - false => None - }; + let control = match (packet.syn(), packet.fin(), packet.rst(), packet.psh()) { + (false, false, false, false) => Control::None, + (false, false, false, true) => Control::Psh, + (true, false, false, _) => Control::Syn, + (false, true, false, _) => Control::Fin, + (false, false, true, _) => Control::Rst, + _ => return Err(Error), + }; + let ack_number = match packet.ack() { + true => Some(packet.ack_number()), + false => None, + }; // The PSH flag is ignored. // The URG flag and the urgent field is ignored. This behavior is standards-compliant, // however, most deployed systems (e.g. Linux) are *not* standards-compliant, and would @@ -789,46 +836,49 @@ impl<'a> Repr<'a> { let mut options = packet.options(); let mut sack_permitted = false; let mut sack_ranges = [None, None, None]; - while options.len() > 0 { + while !options.is_empty() { let (next_options, option) = TcpOption::parse(options)?; match option { TcpOption::EndOfList => break, TcpOption::NoOperation => (), - TcpOption::MaxSegmentSize(value) => - max_seg_size = Some(value), + TcpOption::MaxSegmentSize(value) => max_seg_size = Some(value), TcpOption::WindowScale(value) => { // RFC 1323: Thus, the shift count must be limited to 14 (which allows windows - // of 2**30 = 1 Gbyte). If a Window Scale option is received with a shift.cnt + // of 2**30 = 1 Gigabyte). If a Window Scale option is received with a shift.cnt // value exceeding 14, the TCP should log the error but use 14 instead of the // specified value. window_scale = if value > 14 { - net_debug!("{}:{}:{}:{}: parsed window scaling factor >14, setting to 14", src_addr, packet.src_port(), dst_addr, packet.dst_port()); + net_debug!( + "{}:{}:{}:{}: parsed window scaling factor >14, setting to 14", + src_addr, + packet.src_port(), + dst_addr, + packet.dst_port() + ); Some(14) } else { Some(value) }; - }, - TcpOption::SackPermitted => - sack_permitted = true, - TcpOption::SackRange(slice) => - sack_ranges = slice, + } + TcpOption::SackPermitted => sack_permitted = true, + TcpOption::SackRange(slice) => sack_ranges = slice, _ => (), } options = next_options; } Ok(Repr { - src_port: packet.src_port(), - dst_port: packet.dst_port(), - control: control, - seq_number: packet.seq_number(), - ack_number: ack_number, - window_len: packet.window_len(), + src_port: packet.src_port(), + dst_port: packet.dst_port(), + control: control, + seq_number: packet.seq_number(), + ack_number: ack_number, + window_len: packet.window_len(), window_scale: window_scale, max_seg_size: max_seg_size, sack_permitted: sack_permitted, - sack_ranges: sack_ranges, - payload: packet.payload() + sack_ranges: sack_ranges, + payload: packet.payload(), }) } @@ -847,9 +897,11 @@ impl<'a> Repr<'a> { if self.sack_permitted { length += 2; } - let sack_range_len: usize = self.sack_ranges.iter().map( - |o| o.map(|_| 8).unwrap_or(0) - ).sum(); + let sack_range_len: usize = self + .sack_ranges + .iter() + .map(|o| o.map(|_| 8).unwrap_or(0)) + .sum(); if sack_range_len > 0 { length += sack_range_len + 2; } @@ -859,23 +911,21 @@ impl<'a> Repr<'a> { length } - /// Return the length of the header for the TCP protocol. - /// - /// Per RFC 6691, this should be used for MSS calculations. It may be smaller than the buffer - /// space required to accomodate this packet's data. - pub fn mss_header_len(&self) -> usize { - field::URGENT.end - } - /// Return the length of a packet that will be emitted from this high-level representation. pub fn buffer_len(&self) -> usize { self.header_len() + self.payload.len() } /// Emit a high-level representation into a Transmission Control Protocol packet. - pub fn emit(&self, packet: &mut Packet<&mut T>, src_addr: &IpAddress, dst_addr: &IpAddress, - checksum_caps: &ChecksumCapabilities) - where T: AsRef<[u8]> + AsMut<[u8]> + ?Sized { + pub fn emit( + &self, + packet: &mut Packet<&mut T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]> + ?Sized, + { packet.set_src_port(self.src_port); packet.set_dst_port(self.dst_port); packet.set_seq_number(self.seq_number); @@ -885,27 +935,31 @@ impl<'a> Repr<'a> { packet.clear_flags(); match self.control { Control::None => (), - Control::Psh => packet.set_psh(true), - Control::Syn => packet.set_syn(true), - Control::Fin => packet.set_fin(true), - Control::Rst => packet.set_rst(true) + Control::Psh => packet.set_psh(true), + Control::Syn => packet.set_syn(true), + Control::Fin => packet.set_fin(true), + Control::Rst => packet.set_rst(true), } packet.set_ack(self.ack_number.is_some()); { let mut options = packet.options_mut(); if let Some(value) = self.max_seg_size { - let tmp = options; options = TcpOption::MaxSegmentSize(value).emit(tmp); + let tmp = options; + options = TcpOption::MaxSegmentSize(value).emit(tmp); } if let Some(value) = self.window_scale { - let tmp = options; options = TcpOption::WindowScale(value).emit(tmp); + let tmp = options; + options = TcpOption::WindowScale(value).emit(tmp); } if self.sack_permitted { - let tmp = options; options = TcpOption::SackPermitted.emit(tmp); + let tmp = options; + options = TcpOption::SackPermitted.emit(tmp); } else if self.ack_number.is_some() && self.sack_ranges.iter().any(|s| s.is_some()) { - let tmp = options; options = TcpOption::SackRange(self.sack_ranges).emit(tmp); + let tmp = options; + options = TcpOption::SackRange(self.sack_ranges).emit(tmp); } - if options.len() > 0 { + if !options.is_empty() { TcpOption::EndOfList.emit(options); } } @@ -922,16 +976,16 @@ impl<'a> Repr<'a> { } /// Return the length of the segment, in terms of sequence space. - pub fn segment_len(&self) -> usize { + pub const fn segment_len(&self) -> usize { self.payload.len() + self.control.len() } /// Return whether the segment has no flags set (except PSH) and no data. - pub fn is_empty(&self) -> bool { + pub const fn is_empty(&self) -> bool { match self.control { - _ if self.payload.len() != 0 => false, - Control::Syn | Control::Fin | Control::Rst => false, - Control::None | Control::Psh => true + _ if !self.payload.is_empty() => false, + Control::Syn | Control::Fin | Control::Rst => false, + Control::None | Control::Psh => true, } } } @@ -939,15 +993,28 @@ impl<'a> Repr<'a> { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // Cannot use Repr::parse because we don't have the IP addresses. - write!(f, "TCP src={} dst={}", - self.src_port(), self.dst_port())?; - if self.syn() { write!(f, " syn")? } - if self.fin() { write!(f, " fin")? } - if self.rst() { write!(f, " rst")? } - if self.psh() { write!(f, " psh")? } - if self.ece() { write!(f, " ece")? } - if self.cwr() { write!(f, " cwr")? } - if self.ns() { write!(f, " ns" )? } + write!(f, "TCP src={} dst={}", self.src_port(), self.dst_port())?; + if self.syn() { + write!(f, " syn")? + } + if self.fin() { + write!(f, " fin")? + } + if self.rst() { + write!(f, " rst")? + } + if self.psh() { + write!(f, " psh")? + } + if self.ece() { + write!(f, " ece")? + } + if self.cwr() { + write!(f, " cwr")? + } + if self.ns() { + write!(f, " ns")? + } write!(f, " seq={}", self.seq_number())?; if self.ack() { write!(f, " ack={}", self.ack_number())?; @@ -959,25 +1026,19 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { write!(f, " len={}", self.payload().len())?; let mut options = self.options(); - while options.len() > 0 { - let (next_options, option) = - match TcpOption::parse(options) { - Ok(res) => res, - Err(err) => return write!(f, " ({})", err) - }; + while !options.is_empty() { + let (next_options, option) = match TcpOption::parse(options) { + Ok(res) => res, + Err(err) => return write!(f, " ({err})"), + }; match option { TcpOption::EndOfList => break, TcpOption::NoOperation => (), - TcpOption::MaxSegmentSize(value) => - write!(f, " mss={}", value)?, - TcpOption::WindowScale(value) => - write!(f, " ws={}", value)?, - TcpOption::SackPermitted => - write!(f, " sACK")?, - TcpOption::SackRange(slice) => - write!(f, " sACKr{:?}", slice)?, // debug print conveniently includes the []s - TcpOption::Unknown { kind, .. } => - write!(f, " opt({})", kind)?, + TcpOption::MaxSegmentSize(value) => write!(f, " mss={value}")?, + TcpOption::WindowScale(value) => write!(f, " ws={value}")?, + TcpOption::SackPermitted => write!(f, " sACK")?, + TcpOption::SackRange(slice) => write!(f, " sACKr{slice:?}")?, // debug print conveniently includes the []s + TcpOption::Unknown { kind, .. } => write!(f, " opt({kind})")?, } options = next_options; } @@ -987,45 +1048,70 @@ impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { impl<'a> fmt::Display for Repr<'a> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "TCP src={} dst={}", - self.src_port, self.dst_port)?; + write!(f, "TCP src={} dst={}", self.src_port, self.dst_port)?; match self.control { Control::Syn => write!(f, " syn")?, Control::Fin => write!(f, " fin")?, Control::Rst => write!(f, " rst")?, Control::Psh => write!(f, " psh")?, - Control::None => () + Control::None => (), } write!(f, " seq={}", self.seq_number)?; if let Some(ack_number) = self.ack_number { - write!(f, " ack={}", ack_number)?; + write!(f, " ack={ack_number}")?; } write!(f, " win={}", self.window_len)?; write!(f, " len={}", self.payload.len())?; if let Some(max_seg_size) = self.max_seg_size { - write!(f, " mss={}", max_seg_size)?; + write!(f, " mss={max_seg_size}")?; } Ok(()) } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +#[cfg(feature = "defmt")] +impl<'a> defmt::Format for Repr<'a> { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "TCP src={} dst={}", self.src_port, self.dst_port); + match self.control { + Control::Syn => defmt::write!(fmt, " syn"), + Control::Fin => defmt::write!(fmt, " fin"), + Control::Rst => defmt::write!(fmt, " rst"), + Control::Psh => defmt::write!(fmt, " psh"), + Control::None => (), + } + defmt::write!(fmt, " seq={}", self.seq_number); + if let Some(ack_number) = self.ack_number { + defmt::write!(fmt, " ack={}", ack_number); + } + defmt::write!(fmt, " win={}", self.window_len); + defmt::write!(fmt, " len={}", self.payload.len()); + if let Some(max_seg_size) = self.max_seg_size { + defmt::write!(fmt, " mss={}", max_seg_size); + } + } +} + +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { match Packet::new_checked(buffer) { - Err(err) => write!(f, "{}({})", indent, err), - Ok(packet) => write!(f, "{}{}", indent, packet) + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), } } } #[cfg(test)] mod test { - #[cfg(feature = "proto-ipv4")] - use wire::Ipv4Address; use super::*; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Ipv4Address; #[cfg(feature = "proto-ipv4")] const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]); @@ -1033,22 +1119,16 @@ mod test { const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]); #[cfg(feature = "proto-ipv4")] - static PACKET_BYTES: [u8; 28] = - [0xbf, 0x00, 0x00, 0x50, - 0x01, 0x23, 0x45, 0x67, - 0x89, 0xab, 0xcd, 0xef, - 0x60, 0x35, 0x01, 0x23, - 0x01, 0xb6, 0x02, 0x01, - 0x03, 0x03, 0x0c, 0x01, - 0xaa, 0x00, 0x00, 0xff]; + static PACKET_BYTES: [u8; 28] = [ + 0xbf, 0x00, 0x00, 0x50, 0x01, 0x23, 0x45, 0x67, 0x89, 0xab, 0xcd, 0xef, 0x60, 0x35, 0x01, + 0x23, 0x01, 0xb6, 0x02, 0x01, 0x03, 0x03, 0x0c, 0x01, 0xaa, 0x00, 0x00, 0xff, + ]; #[cfg(feature = "proto-ipv4")] - static OPTION_BYTES: [u8; 4] = - [0x03, 0x03, 0x0c, 0x01]; + static OPTION_BYTES: [u8; 4] = [0x03, 0x03, 0x0c, 0x01]; #[cfg(feature = "proto-ipv4")] - static PAYLOAD_BYTES: [u8; 4] = - [0xaa, 0x00, 0x00, 0xff]; + static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; #[test] #[cfg(feature = "proto-ipv4")] @@ -1059,18 +1139,18 @@ mod test { assert_eq!(packet.seq_number(), SeqNumber(0x01234567)); assert_eq!(packet.ack_number(), SeqNumber(0x89abcdefu32 as i32)); assert_eq!(packet.header_len(), 24); - assert_eq!(packet.fin(), true); - assert_eq!(packet.syn(), false); - assert_eq!(packet.rst(), true); - assert_eq!(packet.psh(), false); - assert_eq!(packet.ack(), true); - assert_eq!(packet.urg(), true); + assert!(packet.fin()); + assert!(!packet.syn()); + assert!(packet.rst()); + assert!(!packet.psh()); + assert!(packet.ack()); + assert!(packet.urg()); assert_eq!(packet.window_len(), 0x0123); assert_eq!(packet.urgent_at(), 0x0201); assert_eq!(packet.checksum(), 0x01b6); assert_eq!(packet.options(), &OPTION_BYTES[..]); assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); - assert_eq!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()), true); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); } #[test] @@ -1096,14 +1176,14 @@ mod test { packet.options_mut().copy_from_slice(&OPTION_BYTES[..]); packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into()); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } #[test] #[cfg(feature = "proto-ipv4")] fn test_truncated() { let packet = Packet::new_unchecked(&PACKET_BYTES[..23]); - assert_eq!(packet.check_len(), Err(Error::Truncated)); + assert_eq!(packet.check_len(), Err(Error)); } #[test] @@ -1111,32 +1191,29 @@ mod test { let mut bytes = vec![0; 20]; let mut packet = Packet::new_unchecked(&mut bytes); packet.set_header_len(10); - assert_eq!(packet.check_len(), Err(Error::Malformed)); + assert_eq!(packet.check_len(), Err(Error)); } #[cfg(feature = "proto-ipv4")] - static SYN_PACKET_BYTES: [u8; 24] = - [0xbf, 0x00, 0x00, 0x50, - 0x01, 0x23, 0x45, 0x67, - 0x00, 0x00, 0x00, 0x00, - 0x50, 0x02, 0x01, 0x23, - 0x7a, 0x8d, 0x00, 0x00, - 0xaa, 0x00, 0x00, 0xff]; + static SYN_PACKET_BYTES: [u8; 24] = [ + 0xbf, 0x00, 0x00, 0x50, 0x01, 0x23, 0x45, 0x67, 0x00, 0x00, 0x00, 0x00, 0x50, 0x02, 0x01, + 0x23, 0x7a, 0x8d, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff, + ]; #[cfg(feature = "proto-ipv4")] fn packet_repr() -> Repr<'static> { Repr { - src_port: 48896, - dst_port: 80, - seq_number: SeqNumber(0x01234567), - ack_number: None, - window_len: 0x0123, + src_port: 48896, + dst_port: 80, + seq_number: SeqNumber(0x01234567), + ack_number: None, + window_len: 0x0123, window_scale: None, - control: Control::Syn, + control: Control::Syn, max_seg_size: None, sack_permitted: false, - sack_ranges: [None, None, None], - payload: &PAYLOAD_BYTES + sack_ranges: [None, None, None], + payload: &PAYLOAD_BYTES, } } @@ -1144,7 +1221,13 @@ mod test { #[cfg(feature = "proto-ipv4")] fn test_parse() { let packet = Packet::new_unchecked(&SYN_PACKET_BYTES[..]); - let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default()).unwrap(); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); assert_eq!(repr, packet_repr()); } @@ -1154,8 +1237,13 @@ mod test { let repr = packet_repr(); let mut bytes = vec![0xa5; repr.buffer_len()]; let mut packet = Packet::new_unchecked(&mut bytes); - repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(), &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &SYN_PACKET_BYTES[..]); + repr.emit( + &mut packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &SYN_PACKET_BYTES[..]); } #[test] @@ -1167,57 +1255,59 @@ mod test { } macro_rules! assert_option_parses { - ($opt:expr, $data:expr) => ({ + ($opt:expr, $data:expr) => {{ assert_eq!(TcpOption::parse($data), Ok((&[][..], $opt))); let buffer = &mut [0; 40][..$opt.buffer_len()]; assert_eq!($opt.emit(buffer), &mut []); assert_eq!(&*buffer, $data); - }) + }}; } #[test] fn test_tcp_options() { - assert_option_parses!(TcpOption::EndOfList, - &[0x00]); - assert_option_parses!(TcpOption::NoOperation, - &[0x01]); - assert_option_parses!(TcpOption::MaxSegmentSize(1500), - &[0x02, 0x04, 0x05, 0xdc]); - assert_option_parses!(TcpOption::WindowScale(12), - &[0x03, 0x03, 0x0c]); - assert_option_parses!(TcpOption::SackPermitted, - &[0x4, 0x02]); - assert_option_parses!(TcpOption::SackRange([Some((500, 1500)), None, None]), - &[0x05, 0x0a, - 0x00, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x05, 0xdc]); - assert_option_parses!(TcpOption::SackRange([Some((875, 1225)), Some((1500, 2500)), None]), - &[0x05, 0x12, - 0x00, 0x00, 0x03, 0x6b, 0x00, 0x00, 0x04, 0xc9, - 0x00, 0x00, 0x05, 0xdc, 0x00, 0x00, 0x09, 0xc4]); - assert_option_parses!(TcpOption::SackRange([Some((875000, 1225000)), - Some((1500000, 2500000)), - Some((876543210, 876654320))]), - &[0x05, 0x1a, - 0x00, 0x0d, 0x59, 0xf8, 0x00, 0x12, 0xb1, 0x28, - 0x00, 0x16, 0xe3, 0x60, 0x00, 0x26, 0x25, 0xa0, - 0x34, 0x3e, 0xfc, 0xea, 0x34, 0x40, 0xae, 0xf0]); - assert_option_parses!(TcpOption::Unknown { kind: 12, data: &[1, 2, 3][..] }, - &[0x0c, 0x05, 0x01, 0x02, 0x03]) + assert_option_parses!(TcpOption::EndOfList, &[0x00]); + assert_option_parses!(TcpOption::NoOperation, &[0x01]); + assert_option_parses!(TcpOption::MaxSegmentSize(1500), &[0x02, 0x04, 0x05, 0xdc]); + assert_option_parses!(TcpOption::WindowScale(12), &[0x03, 0x03, 0x0c]); + assert_option_parses!(TcpOption::SackPermitted, &[0x4, 0x02]); + assert_option_parses!( + TcpOption::SackRange([Some((500, 1500)), None, None]), + &[0x05, 0x0a, 0x00, 0x00, 0x01, 0xf4, 0x00, 0x00, 0x05, 0xdc] + ); + assert_option_parses!( + TcpOption::SackRange([Some((875, 1225)), Some((1500, 2500)), None]), + &[ + 0x05, 0x12, 0x00, 0x00, 0x03, 0x6b, 0x00, 0x00, 0x04, 0xc9, 0x00, 0x00, 0x05, 0xdc, + 0x00, 0x00, 0x09, 0xc4 + ] + ); + assert_option_parses!( + TcpOption::SackRange([ + Some((875000, 1225000)), + Some((1500000, 2500000)), + Some((876543210, 876654320)) + ]), + &[ + 0x05, 0x1a, 0x00, 0x0d, 0x59, 0xf8, 0x00, 0x12, 0xb1, 0x28, 0x00, 0x16, 0xe3, 0x60, + 0x00, 0x26, 0x25, 0xa0, 0x34, 0x3e, 0xfc, 0xea, 0x34, 0x40, 0xae, 0xf0 + ] + ); + assert_option_parses!( + TcpOption::Unknown { + kind: 12, + data: &[1, 2, 3][..] + }, + &[0x0c, 0x05, 0x01, 0x02, 0x03] + ) } #[test] fn test_malformed_tcp_options() { - assert_eq!(TcpOption::parse(&[]), - Err(Error::Truncated)); - assert_eq!(TcpOption::parse(&[0xc]), - Err(Error::Truncated)); - assert_eq!(TcpOption::parse(&[0xc, 0x05, 0x01, 0x02]), - Err(Error::Truncated)); - assert_eq!(TcpOption::parse(&[0xc, 0x01]), - Err(Error::Truncated)); - assert_eq!(TcpOption::parse(&[0x2, 0x02]), - Err(Error::Malformed)); - assert_eq!(TcpOption::parse(&[0x3, 0x02]), - Err(Error::Malformed)); + assert_eq!(TcpOption::parse(&[]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc, 0x05, 0x01, 0x02]), Err(Error)); + assert_eq!(TcpOption::parse(&[0xc, 0x01]), Err(Error)); + assert_eq!(TcpOption::parse(&[0x2, 0x02]), Err(Error)); + assert_eq!(TcpOption::parse(&[0x3, 0x02]), Err(Error)); } } diff --git a/src/wire/udp.rs b/src/wire/udp.rs index b309f6309..77f9f84b3 100644 --- a/src/wire/udp.rs +++ b/src/wire/udp.rs @@ -1,35 +1,38 @@ -use core::fmt; use byteorder::{ByteOrder, NetworkEndian}; +use core::fmt; -use {Error, Result}; -use phy::ChecksumCapabilities; -use super::{IpProtocol, IpAddress}; -use super::ip::checksum; +use super::{Error, Result}; +use crate::phy::ChecksumCapabilities; +use crate::wire::ip::checksum; +use crate::wire::{IpAddress, IpProtocol}; /// A read/write wrapper around an User Datagram Protocol packet buffer. -#[derive(Debug, PartialEq, Clone)] +#[derive(Debug, PartialEq, Eq, Clone)] pub struct Packet> { - buffer: T + buffer: T, } mod field { #![allow(non_snake_case)] - use wire::field::*; + use crate::wire::field::*; pub const SRC_PORT: Field = 0..2; pub const DST_PORT: Field = 2..4; - pub const LENGTH: Field = 4..6; + pub const LENGTH: Field = 4..6; pub const CHECKSUM: Field = 6..8; - pub fn PAYLOAD(length: u16) -> Field { + pub const fn PAYLOAD(length: u16) -> Field { CHECKSUM.end..(length as usize) } } +pub const HEADER_LEN: usize = field::CHECKSUM.end; + +#[allow(clippy::len_without_is_empty)] impl> Packet { /// Imbue a raw octet buffer with UDP packet structure. - pub fn new_unchecked(buffer: T) -> Packet { + pub const fn new_unchecked(buffer: T) -> Packet { Packet { buffer } } @@ -44,8 +47,8 @@ impl> Packet { } /// Ensure that no accessor method will panic if called. - /// Returns `Err(Error::Truncated)` if the buffer is too short. - /// Returns `Err(Error::Malformed)` if the length field has a value smaller + /// Returns `Err(Error)` if the buffer is too short. + /// Returns `Err(Error)` if the length field has a value smaller /// than the header length. /// /// The result of this check is invalidated by calling [set_len]. @@ -53,14 +56,12 @@ impl> Packet { /// [set_len]: #method.set_len pub fn check_len(&self) -> Result<()> { let buffer_len = self.buffer.as_ref().len(); - if buffer_len < field::CHECKSUM.end { - Err(Error::Truncated) + if buffer_len < HEADER_LEN { + Err(Error) } else { let field_len = self.len() as usize; - if buffer_len < field_len { - Err(Error::Truncated) - } else if field_len < field::CHECKSUM.end { - Err(Error::Malformed) + if buffer_len < field_len || field_len < HEADER_LEN { + Err(Error) } else { Ok(()) } @@ -109,13 +110,22 @@ impl> Packet { /// # Fuzzing /// This function always returns `true` when fuzzing. pub fn verify_checksum(&self, src_addr: &IpAddress, dst_addr: &IpAddress) -> bool { - if cfg!(fuzzing) { return true } + if cfg!(fuzzing) { + return true; + } + + // From the RFC: + // > An all zero transmitted checksum value means that the transmitter + // > generated no checksum (for debugging or for higher level protocols + // > that don't care). + if self.checksum() == 0 { + return true; + } let data = self.buffer.as_ref(); checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, - self.len() as u32), - checksum::data(&data[..self.len() as usize]) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32), + checksum::data(&data[..self.len() as usize]), ]) == !0 } } @@ -169,9 +179,8 @@ impl + AsMut<[u8]>> Packet { let checksum = { let data = self.buffer.as_ref(); !checksum::combine(&[ - checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, - self.len() as u32), - checksum::data(&data[..self.len() as usize]) + checksum::pseudo_header(src_addr, dst_addr, IpProtocol::Udp, self.len() as u32), + checksum::data(&data[..self.len() as usize]), ]) }; // UDP checksum value of 0 means no checksum; if the checksum really is zero, @@ -198,54 +207,78 @@ impl> AsRef<[u8]> for Packet { /// A high-level representation of an User Datagram Protocol packet. #[derive(Debug, PartialEq, Eq, Clone, Copy)] -pub struct Repr<'a> { +pub struct Repr { pub src_port: u16, pub dst_port: u16, - pub payload: &'a [u8] } -impl<'a> Repr<'a> { +impl Repr { /// Parse an User Datagram Protocol packet and return a high-level representation. - pub fn parse(packet: &Packet<&'a T>, src_addr: &IpAddress, dst_addr: &IpAddress, - checksum_caps: &ChecksumCapabilities) -> Result> - where T: AsRef<[u8]> + ?Sized { + pub fn parse( + packet: &Packet<&T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + checksum_caps: &ChecksumCapabilities, + ) -> Result + where + T: AsRef<[u8]> + ?Sized, + { // Destination port cannot be omitted (but source port can be). - if packet.dst_port() == 0 { return Err(Error::Malformed) } + if packet.dst_port() == 0 { + return Err(Error); + } // Valid checksum is expected... if checksum_caps.udp.rx() && !packet.verify_checksum(src_addr, dst_addr) { match (src_addr, dst_addr) { // ... except on UDP-over-IPv4, where it can be omitted. #[cfg(feature = "proto-ipv4")] - (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) - if packet.checksum() == 0 => (), - _ => { - return Err(Error::Checksum) - } + (&IpAddress::Ipv4(_), &IpAddress::Ipv4(_)) if packet.checksum() == 0 => (), + _ => return Err(Error), } } Ok(Repr { src_port: packet.src_port(), dst_port: packet.dst_port(), - payload: packet.payload() }) } - /// Return the length of a packet that will be emitted from this high-level representation. - pub fn buffer_len(&self) -> usize { - field::CHECKSUM.end + self.payload.len() + /// Return the length of the packet header that will be emitted from this high-level representation. + pub const fn header_len(&self) -> usize { + HEADER_LEN + } + + /// Emit a high-level representation into an User Datagram Protocol packet. + /// + /// This never calculates the checksum, and is intended for internal-use only, + /// not for packets that are going to be actually sent over the network. For + /// example, when decompressing 6lowpan. + pub(crate) fn emit_header(&self, packet: &mut Packet<&mut T>, payload_len: usize) + where + T: AsRef<[u8]> + AsMut<[u8]>, + { + packet.set_src_port(self.src_port); + packet.set_dst_port(self.dst_port); + packet.set_len((HEADER_LEN + payload_len) as u16); + packet.set_checksum(0); } /// Emit a high-level representation into an User Datagram Protocol packet. - pub fn emit(&self, packet: &mut Packet<&mut T>, - src_addr: &IpAddress, - dst_addr: &IpAddress, - checksum_caps: &ChecksumCapabilities) - where T: AsRef<[u8]> + AsMut<[u8]> { + pub fn emit( + &self, + packet: &mut Packet<&mut T>, + src_addr: &IpAddress, + dst_addr: &IpAddress, + payload_len: usize, + emit_payload: impl FnOnce(&mut [u8]), + checksum_caps: &ChecksumCapabilities, + ) where + T: AsRef<[u8]> + AsMut<[u8]>, + { packet.set_src_port(self.src_port); packet.set_dst_port(self.dst_port); - packet.set_len((field::CHECKSUM.end + self.payload.len()) as u16); - packet.payload_mut().copy_from_slice(self.payload); + packet.set_len((HEADER_LEN + payload_len) as u16); + emit_payload(packet.payload_mut()); if checksum_caps.udp.tx() { packet.fill_checksum(src_addr, dst_addr) @@ -260,35 +293,63 @@ impl<'a> Repr<'a> { impl<'a, T: AsRef<[u8]> + ?Sized> fmt::Display for Packet<&'a T> { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { // Cannot use Repr::parse because we don't have the IP addresses. - write!(f, "UDP src={} dst={} len={}", - self.src_port(), self.dst_port(), self.payload().len()) + write!( + f, + "UDP src={} dst={} len={}", + self.src_port(), + self.dst_port(), + self.payload().len() + ) } } -impl<'a> fmt::Display for Repr<'a> { +#[cfg(feature = "defmt")] +impl<'a, T: AsRef<[u8]> + ?Sized> defmt::Format for Packet<&'a T> { + fn format(&self, fmt: defmt::Formatter) { + // Cannot use Repr::parse because we don't have the IP addresses. + defmt::write!( + fmt, + "UDP src={} dst={} len={}", + self.src_port(), + self.dst_port(), + self.payload().len() + ); + } +} + +impl fmt::Display for Repr { fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result { - write!(f, "UDP src={} dst={} len={}", - self.src_port, self.dst_port, self.payload.len()) + write!(f, "UDP src={} dst={}", self.src_port, self.dst_port) + } +} + +#[cfg(feature = "defmt")] +impl defmt::Format for Repr { + fn format(&self, fmt: defmt::Formatter) { + defmt::write!(fmt, "UDP src={} dst={}", self.src_port, self.dst_port); } } -use super::pretty_print::{PrettyPrint, PrettyIndent}; +use crate::wire::pretty_print::{PrettyIndent, PrettyPrint}; impl> PrettyPrint for Packet { - fn pretty_print(buffer: &dyn AsRef<[u8]>, f: &mut fmt::Formatter, - indent: &mut PrettyIndent) -> fmt::Result { + fn pretty_print( + buffer: &dyn AsRef<[u8]>, + f: &mut fmt::Formatter, + indent: &mut PrettyIndent, + ) -> fmt::Result { match Packet::new_checked(buffer) { - Err(err) => write!(f, "{}({})", indent, err), - Ok(packet) => write!(f, "{}{}", indent, packet) + Err(err) => write!(f, "{indent}({err})"), + Ok(packet) => write!(f, "{indent}{packet}"), } } } #[cfg(test)] mod test { - #[cfg(feature = "proto-ipv4")] - use wire::Ipv4Address; use super::*; + #[cfg(feature = "proto-ipv4")] + use crate::wire::Ipv4Address; #[cfg(feature = "proto-ipv4")] const SRC_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 1]); @@ -296,20 +357,17 @@ mod test { const DST_ADDR: Ipv4Address = Ipv4Address([192, 168, 1, 2]); #[cfg(feature = "proto-ipv4")] - static PACKET_BYTES: [u8; 12] = - [0xbf, 0x00, 0x00, 0x35, - 0x00, 0x0c, 0x12, 0x4d, - 0xaa, 0x00, 0x00, 0xff]; + static PACKET_BYTES: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x12, 0x4d, 0xaa, 0x00, 0x00, 0xff, + ]; #[cfg(feature = "proto-ipv4")] - static NO_CHECKSUM_PACKET: [u8; 12] = - [0xbf, 0x00, 0x00, 0x35, - 0x00, 0x0c, 0x00, 0x00, - 0xaa, 0x00, 0x00, 0xff]; + static NO_CHECKSUM_PACKET: [u8; 12] = [ + 0xbf, 0x00, 0x00, 0x35, 0x00, 0x0c, 0x00, 0x00, 0xaa, 0x00, 0x00, 0xff, + ]; #[cfg(feature = "proto-ipv4")] - static PAYLOAD_BYTES: [u8; 4] = - [0xaa, 0x00, 0x00, 0xff]; + static PAYLOAD_BYTES: [u8; 4] = [0xaa, 0x00, 0x00, 0xff]; #[test] #[cfg(feature = "proto-ipv4")] @@ -320,7 +378,7 @@ mod test { assert_eq!(packet.len(), 12); assert_eq!(packet.checksum(), 0x124d); assert_eq!(packet.payload(), &PAYLOAD_BYTES[..]); - assert_eq!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into()), true); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); } #[test] @@ -334,7 +392,7 @@ mod test { packet.set_checksum(0xffff); packet.payload_mut().copy_from_slice(&PAYLOAD_BYTES[..]); packet.fill_checksum(&SRC_ADDR.into(), &DST_ADDR.into()); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } #[test] @@ -342,7 +400,7 @@ mod test { let mut bytes = vec![0; 12]; let mut packet = Packet::new_unchecked(&mut bytes); packet.set_len(4); - assert_eq!(packet.check_len(), Err(Error::Malformed)); + assert_eq!(packet.check_len(), Err(Error)); } #[test] @@ -357,12 +415,23 @@ mod test { assert_eq!(packet.checksum(), 0xffff); } + #[test] + #[cfg(feature = "proto-ipv4")] + fn test_no_checksum() { + let mut bytes = vec![0; 8]; + let mut packet = Packet::new_unchecked(&mut bytes); + packet.set_src_port(1); + packet.set_dst_port(31881); + packet.set_len(8); + packet.set_checksum(0); + assert!(packet.verify_checksum(&SRC_ADDR.into(), &DST_ADDR.into())); + } + #[cfg(feature = "proto-ipv4")] - fn packet_repr() -> Repr<'static> { + fn packet_repr() -> Repr { Repr { src_port: 48896, dst_port: 53, - payload: &PAYLOAD_BYTES } } @@ -370,8 +439,13 @@ mod test { #[cfg(feature = "proto-ipv4")] fn test_parse() { let packet = Packet::new_unchecked(&PACKET_BYTES[..]); - let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(), - &ChecksumCapabilities::default()).unwrap(); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); assert_eq!(repr, packet_repr()); } @@ -379,19 +453,30 @@ mod test { #[cfg(feature = "proto-ipv4")] fn test_emit() { let repr = packet_repr(); - let mut bytes = vec![0xa5; repr.buffer_len()]; + let mut bytes = vec![0xa5; repr.header_len() + PAYLOAD_BYTES.len()]; let mut packet = Packet::new_unchecked(&mut bytes); - repr.emit(&mut packet, &SRC_ADDR.into(), &DST_ADDR.into(), - &ChecksumCapabilities::default()); - assert_eq!(&packet.into_inner()[..], &PACKET_BYTES[..]); + repr.emit( + &mut packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + PAYLOAD_BYTES.len(), + |payload| payload.copy_from_slice(&PAYLOAD_BYTES), + &ChecksumCapabilities::default(), + ); + assert_eq!(&*packet.into_inner(), &PACKET_BYTES[..]); } #[test] #[cfg(feature = "proto-ipv4")] fn test_checksum_omitted() { let packet = Packet::new_unchecked(&NO_CHECKSUM_PACKET[..]); - let repr = Repr::parse(&packet, &SRC_ADDR.into(), &DST_ADDR.into(), - &ChecksumCapabilities::default()).unwrap(); + let repr = Repr::parse( + &packet, + &SRC_ADDR.into(), + &DST_ADDR.into(), + &ChecksumCapabilities::default(), + ) + .unwrap(); assert_eq!(repr, packet_repr()); } } diff --git a/utils/packet2pcap.rs b/utils/packet2pcap.rs index b82118678..7d06c6f16 100644 --- a/utils/packet2pcap.rs +++ b/utils/packet2pcap.rs @@ -1,34 +1,30 @@ -extern crate smoltcp; -extern crate getopts; - -use std::cell::RefCell; -use std::io::{self, Read, Write}; -use std::path::Path; -use std::fs::File; -use std::env; -use std::process::exit; +use getopts::Options; use smoltcp::phy::{PcapLinkType, PcapSink}; use smoltcp::time::Instant; -use getopts::Options; +use std::env; +use std::fs::File; +use std::io::{self, Read}; +use std::path::Path; +use std::process::exit; -fn convert(packet_filename: &Path, pcap_filename: &Path, link_type: PcapLinkType) - -> io::Result<()> { +fn convert( + packet_filename: &Path, + pcap_filename: &Path, + link_type: PcapLinkType, +) -> io::Result<()> { let mut packet_file = File::open(packet_filename)?; let mut packet = Vec::new(); packet_file.read_to_end(&mut packet)?; - let pcap = RefCell::new(Vec::new()); - PcapSink::global_header(&pcap, link_type); - PcapSink::packet(&pcap, Instant::from_millis(0), &packet[..]); - let mut pcap_file = File::create(pcap_filename)?; - pcap_file.write_all(&pcap.borrow()[..])?; + PcapSink::global_header(&mut pcap_file, link_type); + PcapSink::packet(&mut pcap_file, Instant::from_millis(0), &packet[..]); Ok(()) } fn print_usage(program: &str, opts: Options) { - let brief = format!("Usage: {} [options] INPUT OUTPUT", program); + let brief = format!("Usage: {program} [options] INPUT OUTPUT"); print!("{}", opts.usage(&brief)); } @@ -38,34 +34,40 @@ fn main() { let mut opts = Options::new(); opts.optflag("h", "help", "print this help menu"); - opts.optopt("t", "link-type", "set link type (one of: ethernet ip)", "TYPE"); + opts.optopt( + "t", + "link-type", + "set link type (one of: ethernet ip)", + "TYPE", + ); let matches = match opts.parse(&args[1..]) { Ok(m) => m, Err(e) => { - eprintln!("{}", e); - return + eprintln!("{e}"); + return; } }; - let link_type = - match matches.opt_str("t").as_ref().map(|s| &s[..]) { - Some("ethernet") => Some(PcapLinkType::Ethernet), - Some("ip") => Some(PcapLinkType::Ip), - _ => None - }; + let link_type = match matches.opt_str("t").as_ref().map(|s| &s[..]) { + Some("ethernet") => Some(PcapLinkType::Ethernet), + Some("ip") => Some(PcapLinkType::Ip), + _ => None, + }; if matches.opt_present("h") || matches.free.len() != 2 || link_type.is_none() { print_usage(&program, opts); - return + return; } - match convert(Path::new(&matches.free[0]), - Path::new(&matches.free[1]), - link_type.unwrap()) { + match convert( + Path::new(&matches.free[0]), + Path::new(&matches.free[1]), + link_type.unwrap(), + ) { Ok(()) => (), Err(e) => { - eprintln!("Cannot convert packet to pcap: {}", e); + eprintln!("Cannot convert packet to pcap: {e}"); exit(1); } }