Skip to content

Commit 61263b0

Browse files
authored
test: backwards compatibility test for the serialization feature (#4548)
1 parent 68829bf commit 61263b0

File tree

3 files changed

+118
-10
lines changed

3 files changed

+118
-10
lines changed

tests/integrationv2/common.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,8 @@
33
import subprocess
44
import threading
55
import itertools
6-
6+
import random
7+
import string
78

89
from constants import TEST_CERT_DIRECTORY
910
from global_flags import get_flag, S2N_PROVIDER_VERSION
@@ -29,6 +30,10 @@ def data_bytes(n_bytes):
2930
return bytes(byte_array)
3031

3132

33+
def random_str(n):
34+
return "".join(random.choice(string.ascii_uppercase + string.digits) for _ in range(n))
35+
36+
3237
def pq_enabled():
3338
"""
3439
Returns true or false to indicate whether PQ crypto is enabled in s2n

tests/integrationv2/test_key_update.py

Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,14 @@
11
import copy
2-
import random
3-
import string
42
import pytest
53

64
from configuration import available_ports, TLS13_CIPHERS
7-
from common import ProviderOptions, Protocols
5+
from common import ProviderOptions, Protocols, random_str
86
from fixtures import managed_process # lgtm [py/unused-import]
97
from providers import Provider, S2N, OpenSSL
108
from utils import invalid_test_parameters, get_parameter_name
119

12-
SERVER_DATA = f"Some random data from the server:" + "".join(
13-
random.choice(string.ascii_uppercase + string.digits) for _ in range(10)
14-
)
15-
CLIENT_DATA = f"Some random data from the client:" + "".join(
16-
random.choice(string.ascii_uppercase + string.digits) for _ in range(10)
17-
)
10+
SERVER_DATA = f"Some random data from the server:" + random_str(10)
11+
CLIENT_DATA = f"Some random data from the client:" + random_str(10)
1812

1913

2014
def test_nothing():
Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
1+
import pytest
2+
import copy
3+
import os
4+
from enum import Enum, auto
5+
6+
from configuration import available_ports
7+
from common import ProviderOptions, Protocols, random_str
8+
from fixtures import managed_process # lgtm [py/unused-import]
9+
from providers import Provider, S2N
10+
from utils import invalid_test_parameters, get_parameter_name, to_bytes
11+
12+
SERVER_STATE_FILE = 'server_state'
13+
CLIENT_STATE_FILE = 'client_state'
14+
15+
SERVER_DATA = f"Some random data from the server:" + random_str(10)
16+
CLIENT_DATA = f"Some random data from the client:" + random_str(10)
17+
18+
19+
class MainlineRole(Enum):
20+
Serialize = auto()
21+
Deserialize = auto()
22+
23+
24+
class Mode(Enum):
25+
Server = auto()
26+
Client = auto()
27+
28+
29+
"""
30+
This test file checks that a serialized connection can be deserialized by an older version of
31+
s2n-tls and vice versa. This ensures that any future changes we make to the handshake are backwards-compatible
32+
with an older version of s2n-tls.
33+
34+
This feature requires an uninterrupted TCP connection with the peer in-between serialization and
35+
deserialization. Our integration test setup can't easily provide that while also using two different
36+
s2n-tls versions. To get around that we do a hack and serialize/deserialize both peers in the TLS connection.
37+
This prevents one peer from receiving a TCP FIN message and shutting the connection down early.
38+
"""
39+
40+
41+
@pytest.mark.uncollect_if(func=invalid_test_parameters)
42+
@pytest.mark.parametrize("protocol", [Protocols.TLS13, Protocols.TLS12], ids=get_parameter_name)
43+
@pytest.mark.parametrize("mainline_role", [MainlineRole.Serialize, MainlineRole.Deserialize], ids=get_parameter_name)
44+
@pytest.mark.parametrize("version_change", [Mode.Server, Mode.Client], ids=get_parameter_name)
45+
def test_server_serialization_backwards_compat(managed_process, tmp_path, protocol, mainline_role, version_change):
46+
server_state_file = str(tmp_path / SERVER_STATE_FILE)
47+
client_state_file = str(tmp_path / CLIENT_STATE_FILE)
48+
assert not os.path.exists(server_state_file)
49+
assert not os.path.exists(client_state_file)
50+
51+
options = ProviderOptions(
52+
port=next(available_ports),
53+
protocol=protocol,
54+
insecure=True,
55+
)
56+
57+
client_options = copy.copy(options)
58+
client_options.mode = Provider.ClientMode
59+
client_options.extra_flags = ['--serialize-out', client_state_file]
60+
61+
server_options = copy.copy(options)
62+
server_options.mode = Provider.ServerMode
63+
server_options.extra_flags = ['--serialize-out', server_state_file]
64+
65+
if mainline_role is MainlineRole.Serialize:
66+
if version_change == Mode.Server:
67+
server_options.use_mainline_version = True
68+
else:
69+
client_options.use_mainline_version = True
70+
71+
server = managed_process(
72+
S2N, server_options, send_marker=S2N.get_send_marker())
73+
client = managed_process(S2N, client_options, send_marker=S2N.get_send_marker())
74+
75+
for results in client.get_results():
76+
results.assert_success()
77+
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout
78+
79+
for results in server.get_results():
80+
results.assert_success()
81+
assert to_bytes("Actual protocol version: {}".format(protocol.value)) in results.stdout
82+
83+
assert os.path.exists(server_state_file)
84+
assert os.path.exists(client_state_file)
85+
86+
client_options.extra_flags = ['--deserialize-in', client_state_file]
87+
server_options.extra_flags = ['--deserialize-in', server_state_file]
88+
if mainline_role is MainlineRole.Deserialize:
89+
if version_change == Mode.Server:
90+
server_options.use_mainline_version = True
91+
else:
92+
client_options.use_mainline_version = True
93+
94+
server_options.data_to_send = SERVER_DATA.encode()
95+
client_options.data_to_send = CLIENT_DATA.encode()
96+
97+
server = managed_process(S2N, server_options, send_marker=CLIENT_DATA)
98+
client = managed_process(S2N, client_options, send_marker="Connected to localhost", close_marker=SERVER_DATA)
99+
100+
for results in server.get_results():
101+
results.assert_success()
102+
# No protocol version printout since deserialization means skipping the handshake
103+
assert to_bytes("Actual protocol version:") not in results.stdout
104+
assert CLIENT_DATA.encode() in results.stdout
105+
106+
for results in client.get_results():
107+
results.assert_success()
108+
assert to_bytes("Actual protocol version:") not in results.stdout
109+
assert SERVER_DATA.encode() in results.stdout

0 commit comments

Comments
 (0)