Skip to content

Commit 1296683

Browse files
authored
Added FixedPool abstraction for use in optimizing various types (#7303)
1 parent fdbce86 commit 1296683

File tree

4 files changed

+185
-0
lines changed

4 files changed

+185
-0
lines changed

src/cryptography/hazmat/bindings/_rust/__init__.pyi

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
33
# for complete details.
44

5+
import typing
6+
57
def check_pkcs7_padding(data: bytes) -> bool: ...
68
def check_ansix923_padding(data: bytes) -> bool: ...
79

@@ -11,3 +13,17 @@ class ObjectIdentifier:
1113
def dotted_string(self) -> str: ...
1214
@property
1315
def _name(self) -> str: ...
16+
17+
T = typing.TypeVar("T")
18+
19+
class FixedPool(typing.Generic[T]):
20+
def __init__(
21+
self,
22+
create: typing.Callable[[], T],
23+
destroy: typing.Callable[[T], None],
24+
) -> None: ...
25+
def acquire(self) -> PoolAcquisition[T]: ...
26+
27+
class PoolAcquisition(typing.Generic[T]):
28+
def __enter__(self) -> T: ...
29+
def __exit__(self, exc_type, exc_value, exc_tb) -> None: ...

src/rust/src/lib.rs

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
mod asn1;
88
mod intern;
99
pub(crate) mod oid;
10+
mod pool;
1011
mod x509;
1112

1213
use std::convert::TryInto;
@@ -77,6 +78,7 @@ fn _rust(py: pyo3::Python<'_>, m: &pyo3::types::PyModule) -> pyo3::PyResult<()>
7778
m.add_function(pyo3::wrap_pyfunction!(check_pkcs7_padding, m)?)?;
7879
m.add_function(pyo3::wrap_pyfunction!(check_ansix923_padding, m)?)?;
7980
m.add_class::<oid::ObjectIdentifier>()?;
81+
m.add_class::<pool::FixedPool>()?;
8082

8183
m.add_submodule(asn1::create_submodule(py)?)?;
8284

src/rust/src/pool.rs

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
// This file is dual licensed under the terms of the Apache License, Version
2+
// 2.0, and the BSD License. See the LICENSE file in the root of this repository
3+
// for complete details.
4+
5+
use std::cell::Cell;
6+
7+
// An object pool that can contain a single object and will dynamically
8+
// allocate new objects to fulfill requests if the pool'd object is already in
9+
// use.
10+
#[pyo3::prelude::pyclass]
11+
pub(crate) struct FixedPool {
12+
create_fn: pyo3::PyObject,
13+
destroy_fn: pyo3::PyObject,
14+
15+
value: Cell<Option<pyo3::PyObject>>,
16+
}
17+
18+
#[pyo3::prelude::pyclass]
19+
struct PoolAcquisition {
20+
pool: pyo3::Py<FixedPool>,
21+
22+
value: pyo3::PyObject,
23+
fresh: bool,
24+
}
25+
26+
#[pyo3::pymethods]
27+
impl FixedPool {
28+
#[new]
29+
fn new(
30+
py: pyo3::Python<'_>,
31+
create: pyo3::PyObject,
32+
destroy: pyo3::PyObject,
33+
) -> pyo3::PyResult<Self> {
34+
let value = create.call0(py)?;
35+
36+
Ok(FixedPool {
37+
create_fn: create,
38+
destroy_fn: destroy,
39+
40+
value: Cell::new(Some(value)),
41+
})
42+
}
43+
44+
fn acquire(slf: pyo3::Py<Self>, py: pyo3::Python<'_>) -> pyo3::PyResult<PoolAcquisition> {
45+
let v = slf.as_ref(py).borrow().value.replace(None);
46+
if let Some(value) = v {
47+
Ok(PoolAcquisition {
48+
pool: slf,
49+
value,
50+
fresh: false,
51+
})
52+
} else {
53+
let value = slf.as_ref(py).borrow().create_fn.call0(py)?;
54+
Ok(PoolAcquisition {
55+
pool: slf,
56+
value,
57+
fresh: true,
58+
})
59+
}
60+
}
61+
}
62+
63+
impl Drop for FixedPool {
64+
fn drop(&mut self) {
65+
if let Some(value) = self.value.replace(None) {
66+
let gil = pyo3::Python::acquire_gil();
67+
let py = gil.python();
68+
self.destroy_fn
69+
.call1(py, (value,))
70+
.expect("FixedPool destroy function failed in destructor");
71+
}
72+
}
73+
}
74+
75+
#[pyo3::pymethods]
76+
impl PoolAcquisition {
77+
fn __enter__(&self, py: pyo3::Python<'_>) -> pyo3::PyObject {
78+
self.value.clone_ref(py)
79+
}
80+
81+
fn __exit__(
82+
&self,
83+
py: pyo3::Python<'_>,
84+
_exc_type: &pyo3::PyAny,
85+
_exc_value: &pyo3::PyAny,
86+
_exc_tb: &pyo3::PyAny,
87+
) -> pyo3::PyResult<()> {
88+
let pool = self.pool.as_ref(py).borrow();
89+
if self.fresh {
90+
pool.destroy_fn.call1(py, (self.value.clone_ref(py),))?;
91+
} else {
92+
pool.value.replace(Some(self.value.clone_ref(py)));
93+
}
94+
Ok(())
95+
}
96+
}

tests/test_rust_utils.py

Lines changed: 71 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,71 @@
1+
# This file is dual licensed under the terms of the Apache License, Version
2+
# 2.0, and the BSD License. See the LICENSE file in the root of this repository
3+
# for complete details.
4+
5+
import gc
6+
import threading
7+
8+
from cryptography.hazmat.bindings._rust import FixedPool
9+
10+
11+
class TestFixedPool:
12+
def test_basic(self):
13+
c = 0
14+
events = []
15+
16+
def create():
17+
nonlocal c
18+
c += 1
19+
events.append(("create", c))
20+
return c
21+
22+
def destroy(c):
23+
events.append(("destroy", c))
24+
25+
pool = FixedPool(create, destroy)
26+
assert events == [("create", 1)]
27+
with pool.acquire() as c:
28+
assert c == 1
29+
assert events == [("create", 1)]
30+
31+
with pool.acquire() as c:
32+
assert c == 2
33+
assert events == [("create", 1), ("create", 2)]
34+
35+
assert events == [("create", 1), ("create", 2), ("destroy", 2)]
36+
37+
assert events == [("create", 1), ("create", 2), ("destroy", 2)]
38+
39+
del pool
40+
gc.collect()
41+
gc.collect()
42+
gc.collect()
43+
44+
assert events == [
45+
("create", 1),
46+
("create", 2),
47+
("destroy", 2),
48+
("destroy", 1),
49+
]
50+
51+
def test_thread_stress(self):
52+
def create():
53+
return None
54+
55+
def destroy(c):
56+
pass
57+
58+
pool = FixedPool(create, destroy)
59+
60+
def thread_fn():
61+
with pool.acquire():
62+
pass
63+
64+
threads = []
65+
for i in range(1024):
66+
t = threading.Thread(target=thread_fn)
67+
t.start()
68+
threads.append(t)
69+
70+
for t in threads:
71+
t.join()

0 commit comments

Comments
 (0)