|
| 1 | +import pytest |
| 2 | +import copy |
| 3 | +import os |
| 4 | +from enum import Enum, auto |
| 5 | + |
| 6 | +from configuration import available_ports |
| 7 | +from common import ProviderOptions, Protocols, random_str |
| 8 | +from fixtures import managed_process # lgtm [py/unused-import] |
| 9 | +from providers import Provider, S2N |
| 10 | +from utils import invalid_test_parameters, get_parameter_name, to_bytes |
| 11 | + |
| 12 | +SERVER_STATE_FILE = 'server_state' |
| 13 | +CLIENT_STATE_FILE = 'client_state' |
| 14 | + |
| 15 | +SERVER_DATA = f"Some random data from the server:" + random_str(10) |
| 16 | +CLIENT_DATA = f"Some random data from the client:" + random_str(10) |
| 17 | + |
| 18 | + |
| 19 | +class MainlineRole(Enum): |
| 20 | + Serialize = auto() |
| 21 | + Deserialize = auto() |
| 22 | + |
| 23 | + |
| 24 | +class Mode(Enum): |
| 25 | + Server = auto() |
| 26 | + Client = auto() |
| 27 | + |
| 28 | + |
| 29 | +""" |
| 30 | +This test file checks that a serialized connection can be deserialized by an older version of |
| 31 | +s2n-tls and vice versa. This ensures that any future changes we make to the handshake are backwards-compatible |
| 32 | +with an older version of s2n-tls. |
| 33 | +
|
| 34 | +This feature requires an uninterrupted TCP connection with the peer in-between serialization and |
| 35 | +deserialization. Our integration test setup can't easily provide that while also using two different |
| 36 | +s2n-tls versions. To get around that we do a hack and serialize/deserialize both peers in the TLS connection. |
| 37 | +This prevents one peer from receiving a TCP FIN message and shutting the connection down early. |
| 38 | +""" |
| 39 | + |
| 40 | + |
| 41 | +@pytest.mark.uncollect_if(func=invalid_test_parameters) |
| 42 | +@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name) |
| 43 | +@pytest.mark.parametrize("mainline_role", [MainlineRole.Serialize, MainlineRole.Deserialize], ids=get_parameter_name) |
| 44 | +@pytest.mark.parametrize("version_change", [Mode.Server, Mode.Client], ids=get_parameter_name) |
| 45 | +def test_server_serialization_backwards_compat(managed_process, tmp_path, protocol, mainline_role, version_change): |
| 46 | + server_state_file = str(tmp_path / SERVER_STATE_FILE) |
| 47 | + client_state_file = str(tmp_path / CLIENT_STATE_FILE) |
| 48 | + assert not os.path.exists(server_state_file) |
| 49 | + assert not os.path.exists(client_state_file) |
| 50 | + |
| 51 | + options = ProviderOptions( |
| 52 | + port=next(available_ports), |
| 53 | + protocol=protocol, |
| 54 | + insecure=True, |
| 55 | + ) |
| 56 | + |
| 57 | + client_options = copy.copy(options) |
| 58 | + client_options.mode = Provider.ClientMode |
| 59 | + client_options.extra_flags = ['--serialize-out', client_state_file] |
| 60 | + |
| 61 | + server_options = copy.copy(options) |
| 62 | + server_options.mode = Provider.ServerMode |
| 63 | + server_options.extra_flags = ['--serialize-out', server_state_file] |
| 64 | + |
| 65 | + if mainline_role is MainlineRole.Serialize: |
| 66 | + if version_change == Mode.Server: |
| 67 | + server_options.use_mainline_version = True |
| 68 | + else: |
| 69 | + client_options.use_mainline_version = True |
| 70 | + |
| 71 | + server = managed_process( |
| 72 | + S2N, server_options, send_marker=S2N.get_send_marker()) |
| 73 | + client = managed_process(S2N, client_options, send_marker=S2N.get_send_marker()) |
| 74 | + |
| 75 | + for results in client.get_results(): |
| 76 | + results.assert_success() |
| 77 | + assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout |
| 78 | + |
| 79 | + for results in server.get_results(): |
| 80 | + results.assert_success() |
| 81 | + assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout |
| 82 | + |
| 83 | + assert os.path.exists(server_state_file) |
| 84 | + assert os.path.exists(client_state_file) |
| 85 | + |
| 86 | + client_options.extra_flags = ['--deserialize-in', client_state_file] |
| 87 | + server_options.extra_flags = ['--deserialize-in', server_state_file] |
| 88 | + if mainline_role is MainlineRole.Deserialize: |
| 89 | + if version_change == Mode.Server: |
| 90 | + server_options.use_mainline_version = True |
| 91 | + else: |
| 92 | + client_options.use_mainline_version = True |
| 93 | + |
| 94 | + server_options.data_to_send = SERVER_DATA.encode() |
| 95 | + client_options.data_to_send = CLIENT_DATA.encode() |
| 96 | + |
| 97 | + server = managed_process(S2N, server_options, send_marker=CLIENT_DATA) |
| 98 | + client = managed_process(S2N, client_options, send_marker="Connected to localhost", close_marker=SERVER_DATA) |
| 99 | + |
| 100 | + for results in server.get_results(): |
| 101 | + results.assert_success() |
| 102 | + # No protocol version printout since deserialization means skipping the handshake |
| 103 | + assert to_bytes("Actual protocol version:") not in results.stdout |
| 104 | + assert CLIENT_DATA.encode() in results.stdout |
| 105 | + |
| 106 | + for results in client.get_results(): |
| 107 | + results.assert_success() |
| 108 | + assert to_bytes("Actual protocol version:") not in results.stdout |
| 109 | + assert SERVER_DATA.encode() in results.stdout |
0 commit comments