Skip to content

Commit 4b374f6

Browse files
authored
fix: fine grained lock in smoltcp (#265)
* Copy smoltcp v0.11.0 into awkernel Signed-off-by: koichiimai <kotty.0704@gmail.com> * build succeed without warnings&errors Signed-off-by: koichiimai <kotty.0704@gmail.com> * apply clippy Signed-off-by: koichiimai <kotty.0704@gmail.com> * apply cargo fmt Signed-off-by: koichiimai <kotty.0704@gmail.com> * WIP Signed-off-by: koichiimai <kotty.0704@gmail.com> * socket level lock, build succeed Signed-off-by: Koichi Imai <koichi.imai.2@tier4.jp> * eliminate the warning & apply clippy Signed-off-by: Koichi Imai <koichi.imai.2@tier4.jp> * delete unnecessary files Signed-off-by: Koichi Imai <koichi.imai.2@tier4.jp> * avoid deadlock in socket_meta.rs Signed-off-by: Koichi Imai <koichi.imai.2@tier4.jp> * avoid deadlock in awkernel_lib/src/net/udp_socket.rs Signed-off-by: koichiimai <kotty.0704@gmail.com> * avoid deadlock in awkernel_lib/src/net/tcp_*.rs Signed-off-by: koichiimai <kotty.0704@gmail.com> --------- Signed-off-by: koichiimai <kotty.0704@gmail.com> Signed-off-by: Koichi Imai <koichi.imai.2@tier4.jp>
1 parent 5901d2c commit 4b374f6

File tree

14 files changed

+1372
-1265
lines changed

14 files changed

+1372
-1265
lines changed

awkernel_lib/src/net/if_net.rs

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@ use smoltcp::{
2727
wire::HardwareAddress,
2828
};
2929

30-
use crate::sync::{mcs::MCSNode, mutex::Mutex};
30+
use crate::sync::{mcs::MCSNode, mutex::Mutex, rwlock::RwLock};
3131

3232
use super::{
3333
ether::{extract_headers, NetworkHdr, TransportHdr, ETHER_ADDR_LEN},
@@ -163,6 +163,7 @@ impl Device for NetDriverRef<'_> {
163163
pub(super) struct IfNet {
164164
vlan: Option<u16>,
165165
pub(super) inner: Mutex<IfNetInner>,
166+
pub(super) socket_set: RwLock<SocketSet<'static>>,
166167
rx_irq_to_driver: BTreeMap<u16, NetDriver>,
167168
tx_only_ringq: Vec<Mutex<RingQ<Vec<u8>>>>,
168169
pub(super) net_device: Arc<dyn NetDevice + Sync + Send>,
@@ -173,7 +174,6 @@ pub(super) struct IfNet {
173174

174175
pub(super) struct IfNetInner {
175176
pub(super) interface: Interface,
176-
pub(super) socket_set: SocketSet<'static>,
177177
pub(super) default_gateway_ipv4: Option<smoltcp::wire::Ipv4Address>,
178178

179179
multicast_addr_ipv4: BTreeSet<Ipv4Addr>,
@@ -182,8 +182,8 @@ pub(super) struct IfNetInner {
182182

183183
impl IfNetInner {
184184
#[inline(always)]
185-
pub fn split(&mut self) -> (&mut Interface, &mut SocketSet<'static>) {
186-
(&mut self.interface, &mut self.socket_set)
185+
pub fn get_interface(&mut self) -> &mut Interface {
186+
&mut self.interface
187187
}
188188

189189
#[inline(always)]
@@ -277,11 +277,11 @@ impl IfNet {
277277
vlan,
278278
inner: Mutex::new(IfNetInner {
279279
interface,
280-
socket_set,
281280
default_gateway_ipv4: None,
282281
multicast_addr_ipv4: BTreeSet::new(),
283282
multicast_addr_mac: BTreeMap::new(),
284283
}),
284+
socket_set: RwLock::new(socket_set),
285285
rx_irq_to_driver,
286286
net_device,
287287
tx_only_ringq,
@@ -488,8 +488,8 @@ impl IfNet {
488488
let mut node = MCSNode::new();
489489
let mut inner = self.inner.lock(&mut node);
490490

491-
let (interface, socket_set) = inner.split();
492-
interface.poll(timestamp, &mut device_ref, socket_set)
491+
let interface = inner.get_interface();
492+
interface.poll(timestamp, &mut device_ref, &self.socket_set)
493493
};
494494

495495
// send packets from the queue.
@@ -547,9 +547,8 @@ impl IfNet {
547547
let mut node = MCSNode::new();
548548
let mut inner = self.inner.lock(&mut node);
549549

550-
let (interface, socket_set) = inner.split();
551-
552-
interface.poll(timestamp, &mut device_ref, socket_set)
550+
let interface = inner.get_interface();
551+
interface.poll(timestamp, &mut device_ref, &self.socket_set)
553552
};
554553

555554
// send packets from the queue.

awkernel_lib/src/net/tcp_listener.rs

Lines changed: 46 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -76,12 +76,7 @@ impl TcpListener {
7676
// Create a TCP socket.
7777
let socket = create_listen_socket(&addr, port.port(), rx_buffer_size, tx_buffer_size);
7878

79-
let handle = {
80-
let mut node = MCSNode::new();
81-
let mut if_net_inner = if_net.inner.lock(&mut node);
82-
83-
if_net_inner.socket_set.add(socket)
84-
};
79+
let handle = if_net.socket_set.write().add(socket);
8580

8681
handles.push(handle);
8782
}
@@ -128,43 +123,54 @@ impl TcpListener {
128123
let if_net = if_net.clone();
129124
drop(net_manager);
130125

131-
let mut node = MCSNode::new();
132-
let mut interface = if_net.inner.lock(&mut node);
133-
134126
for handle in self.handles.iter_mut() {
135-
let socket: &mut smoltcp::socket::tcp::Socket = interface.socket_set.get_mut(*handle);
136-
if socket.may_send() {
127+
let (may_send, is_not_open) = {
128+
let socket_set = if_net.socket_set.read();
129+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
130+
let socket = socket_set
131+
.get::<smoltcp::socket::tcp::Socket>(*handle)
132+
.lock(&mut node);
133+
if socket.may_send() {
134+
(true, false)
135+
} else if !socket.is_open() {
136+
(false, true)
137+
} else {
138+
(false, false)
139+
}
140+
};
141+
142+
if may_send {
137143
// If the connection is established, create a new socket and add it to the interface.
138144
let new_socket = create_listen_socket(
139145
&self.addr,
140146
self.port.port(),
141147
self.rx_buffer_size,
142148
self.tx_buffer_size,
143149
);
144-
let mut new_handle = interface.socket_set.add(new_socket);
150+
let mut socket_set = if_net.socket_set.write();
151+
let mut new_handle = socket_set.add(new_socket);
145152

146153
// Swap the new handle with the old handle.
147154
core::mem::swap(handle, &mut new_handle);
148155

149156
// The old handle is now a connected socket.
150157
self.connected_sockets.push_back(new_handle);
151-
} else if !socket.is_open() {
158+
} else if is_not_open {
152159
// If the socket is closed, create a new socket and add it to the interface.
153160
let new_socket = create_listen_socket(
154161
&self.addr,
155162
self.port.port(),
156163
self.rx_buffer_size,
157164
self.tx_buffer_size,
158165
);
159-
interface.socket_set.remove(*handle);
160-
*handle = interface.socket_set.add(new_socket);
166+
let mut socket_set = if_net.socket_set.write();
167+
socket_set.remove(*handle);
168+
*handle = socket_set.add(new_socket);
161169
}
162170
}
163171

164172
// If there is a connected socket, return it.
165173
if let Some(handle) = self.connected_sockets.pop_front() {
166-
drop(interface);
167-
168174
let port = {
169175
let mut net_manager = NET_MANAGER.write();
170176
if self.addr.is_ipv4() {
@@ -183,14 +189,16 @@ impl TcpListener {
183189
}));
184190
}
185191

192+
let socket_set = if_net.socket_set.read();
186193
// Register the waker for the listening sockets.
187194
for handle in self.handles.iter() {
188-
let socket: &mut smoltcp::socket::tcp::Socket = interface.socket_set.get_mut(*handle);
195+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
196+
let mut socket = socket_set
197+
.get::<smoltcp::socket::tcp::Socket>(*handle)
198+
.lock(&mut node);
189199
socket.register_send_waker(waker);
190200
}
191201

192-
drop(interface);
193-
194202
Ok(None)
195203
}
196204
}
@@ -203,23 +211,28 @@ impl Drop for TcpListener {
203211
let if_net = if_net.clone();
204212
drop(net_manager);
205213

206-
let mut node = MCSNode::new();
207-
let mut inner = if_net.inner.lock(&mut node);
214+
{
215+
let socket_set = if_net.socket_set.read();
208216

209-
// Close listening sockets.
210-
for handle in self.handles.iter() {
211-
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(*handle);
212-
socket.abort();
213-
}
217+
// Close listening sockets.
218+
for handle in self.handles.iter() {
219+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
220+
let mut socket = socket_set
221+
.get::<smoltcp::socket::tcp::Socket>(*handle)
222+
.lock(&mut node);
223+
socket.abort();
224+
}
214225

215-
// Close connected sockets.
216-
for handle in self.connected_sockets.iter() {
217-
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(*handle);
218-
socket.abort();
226+
// Close connected sockets.
227+
for handle in self.connected_sockets.iter() {
228+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
229+
let mut socket = socket_set
230+
.get::<smoltcp::socket::tcp::Socket>(*handle)
231+
.lock(&mut node);
232+
socket.abort();
233+
}
219234
}
220235

221-
drop(inner);
222-
223236
let que_id = crate::cpu::raw_cpu_id() & (if_net.net_device.num_queues() - 1);
224237
if_net.poll_tx_only(que_id);
225238
}

awkernel_lib/src/net/tcp_stream.rs

Lines changed: 65 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -41,19 +41,29 @@ impl Drop for TcpStream {
4141
drop(net_manager);
4242

4343
{
44-
let mut node = MCSNode::new();
45-
let mut inner = if_net.inner.lock(&mut node);
44+
let socket_set = if_net.socket_set.read();
45+
let closed = {
46+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
47+
let socket = socket_set
48+
.get::<smoltcp::socket::tcp::Socket>(self.handle)
49+
.lock(&mut node);
4650

47-
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
51+
matches!(socket.state(), smoltcp::socket::tcp::State::Closed)
52+
};
4853

4954
// If the socket is already closed, remove it from the socket set.
50-
if matches!(socket.state(), smoltcp::socket::tcp::State::Closed) {
51-
inner.socket_set.remove(self.handle);
52-
55+
if closed {
56+
drop(socket_set);
57+
let mut socket_set = if_net.socket_set.write();
58+
socket_set.remove(self.handle);
5359
return;
5460
}
5561

5662
// Otherwise, close the socket.
63+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
64+
let mut socket = socket_set
65+
.get::<smoltcp::socket::tcp::Socket>(self.handle)
66+
.lock(&mut node);
5767
socket.close();
5868
}
5969

@@ -98,16 +108,25 @@ pub fn close_connections() {
98108
let mut remain_v = VecDeque::new();
99109

100110
{
101-
let mut node = MCSNode::new();
102-
let mut inner = if_net.inner.lock(&mut node);
103-
104111
while let Some((handle, port)) = v.pop_front() {
105-
let socket: &mut smoltcp::socket::tcp::Socket =
106-
inner.socket_set.get_mut(handle);
107-
if socket.state() == smoltcp::socket::tcp::State::Closed {
112+
let socket_set = if_net.socket_set.read();
113+
let closed = {
114+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
115+
let socket = socket_set
116+
.get::<smoltcp::socket::tcp::Socket>(handle)
117+
.lock(&mut node);
118+
socket.state() == smoltcp::socket::tcp::State::Closed
119+
};
120+
if closed {
121+
drop(socket_set);
122+
let mut socket_set = if_net.socket_set.write();
108123
// If the socket is already closed, remove it from the socket set.
109-
inner.socket_set.remove(handle);
124+
socket_set.remove(handle);
110125
} else {
126+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
127+
let mut socket = socket_set
128+
.get::<smoltcp::socket::tcp::Socket>(handle)
129+
.lock(&mut node);
111130
socket.close();
112131
remain_v.push_back((handle, port));
113132
}
@@ -180,20 +199,27 @@ impl TcpStream {
180199
let mut node = MCSNode::new();
181200
let mut inner = if_net.inner.lock(&mut node);
182201

183-
let (interface, socket_set) = inner.split();
202+
let interface = inner.get_interface();
184203

204+
let mut socket_set = if_net.socket_set.write();
185205
handle = socket_set.add(socket);
186206

187-
let socket: &mut smoltcp::socket::tcp::Socket = socket_set.get_mut(handle);
188-
189-
if socket
190-
.connect(
191-
interface.context(),
192-
(remote_addr.addr, remote_port),
193-
local_port.port(),
194-
)
195-
.is_err()
196-
{
207+
let connect_is_err = {
208+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
209+
let mut socket = socket_set
210+
.get::<smoltcp::socket::tcp::Socket>(handle)
211+
.lock(&mut node);
212+
213+
socket
214+
.connect(
215+
interface.context(),
216+
(remote_addr.addr, remote_port),
217+
local_port.port(),
218+
)
219+
.is_err()
220+
};
221+
222+
if connect_is_err {
197223
socket_set.remove(handle);
198224
return Err(NetManagerError::InvalidState);
199225
}
@@ -227,10 +253,12 @@ impl TcpStream {
227253
let if_net = if_net.clone();
228254
drop(net_manager);
229255

230-
let mut node = MCSNode::new();
231-
let mut inner = if_net.inner.lock(&mut node);
256+
let socket_set = if_net.socket_set.read();
232257

233-
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
258+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
259+
let mut socket = socket_set
260+
.get::<smoltcp::socket::tcp::Socket>(self.handle)
261+
.lock(&mut node);
234262

235263
if socket.state() == smoltcp::socket::tcp::State::SynSent {
236264
socket.register_recv_waker(waker);
@@ -271,10 +299,12 @@ impl TcpStream {
271299
let if_net = if_net.clone();
272300
drop(net_manager);
273301

274-
let mut node = MCSNode::new();
275-
let mut inner = if_net.inner.lock(&mut node);
302+
let socket_set = if_net.socket_set.read();
276303

277-
let socket: &mut smoltcp::socket::tcp::Socket = inner.socket_set.get_mut(self.handle);
304+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
305+
let mut socket = socket_set
306+
.get::<smoltcp::socket::tcp::Socket>(self.handle)
307+
.lock(&mut node);
278308

279309
if socket.state() == smoltcp::socket::tcp::State::SynSent {
280310
socket.register_recv_waker(waker);
@@ -308,10 +338,12 @@ impl TcpStream {
308338
let if_net = if_net.clone();
309339
drop(net_manager);
310340

311-
let mut node = MCSNode::new();
312-
let inner = if_net.inner.lock(&mut node);
341+
let socket_set = if_net.socket_set.read();
313342

314-
let socket: &smoltcp::socket::tcp::Socket = inner.socket_set.get(self.handle);
343+
let mut node: MCSNode<smoltcp::socket::tcp::Socket> = MCSNode::new();
344+
let socket = socket_set
345+
.get::<smoltcp::socket::tcp::Socket>(self.handle)
346+
.lock(&mut node);
315347

316348
if let Some(endpoint) = socket.remote_endpoint() {
317349
Ok((

0 commit comments

Comments
 (0)