Skip to content

Commit 752534b

Browse files
authored
use getptr a lot less (#618)
* use getptr a lot less * Don't fail fast * typo * Fix: Use GC.@preserve in pyconvert_rule_ctypessimplevalue * preserve in incref/decref (use-after-free) * more preserve * more preserve --------- Co-authored-by: Christopher Doris <github.com/cjdoris>
1 parent 8b8ac73 commit 752534b

File tree

15 files changed

+169
-159
lines changed

15 files changed

+169
-159
lines changed

.github/workflows/tests.yml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@ jobs:
1515
name: Test (${{ matrix.os }}, julia ${{ matrix.jlversion }})
1616
runs-on: ${{ matrix.os }}
1717
strategy:
18-
fail-fast: true
18+
fail-fast: false
1919
matrix:
2020
arch: [x64] # x86 unsupported by MicroMamba
2121
os: [ubuntu-latest, windows-latest, macos-latest]
@@ -53,7 +53,7 @@ jobs:
5353
name: Test (${{ matrix.os }}, python ${{ matrix.pyversion }})
5454
runs-on: ${{ matrix.os }}
5555
strategy:
56-
fail-fast: true
56+
fail-fast: false
5757
matrix:
5858
os: [ubuntu-latest, windows-latest, macos-latest]
5959
pyversion: [">=3.8", "3.8"]

src/C/extras.jl

Lines changed: 16 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,26 @@
1-
Py_Type(x::PyPtr) = PyPtr(UnsafePtr(x).type[!])
1+
asptr(x) = Base.unsafe_convert(PyPtr, x)
22

3-
PyObject_Type(x::PyPtr) = (t = Py_Type(x); Py_IncRef(t); t)
3+
Py_Type(x) = Base.GC.@preserve x PyPtr(UnsafePtr(asptr(x)).type[!])
44

5-
Py_TypeCheck(o::PyPtr, t::PyPtr) = PyType_IsSubtype(Py_Type(o), t)
6-
Py_TypeCheckFast(o::PyPtr, f::Integer) = PyType_IsSubtypeFast(Py_Type(o), f)
5+
PyObject_Type(x) = Base.GC.@preserve x (t = Py_Type(asptr(x)); Py_IncRef(t); t)
76

8-
PyType_IsSubtypeFast(t::PyPtr, f::Integer) =
9-
Cint(!iszero(UnsafePtr{PyTypeObject}(t).flags[] & f))
7+
Py_TypeCheck(o, t) = Base.GC.@preserve o t PyType_IsSubtype(Py_Type(asptr(o)), asptr(t))
8+
Py_TypeCheckFast(o, f::Integer) = Base.GC.@preserve o PyType_IsSubtypeFast(Py_Type(asptr(o)), f)
109

11-
PyMemoryView_GET_BUFFER(m::PyPtr) = Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(m).view)
10+
PyType_IsSubtypeFast(t, f::Integer) =
11+
Base.GC.@preserve t Cint(!iszero(UnsafePtr{PyTypeObject}(asptr(t)).flags[] & f))
1212

13-
PyType_CheckBuffer(t::PyPtr) = begin
14-
p = UnsafePtr{PyTypeObject}(t).as_buffer[]
13+
PyMemoryView_GET_BUFFER(m) = Base.GC.@preserve m Ptr{Py_buffer}(UnsafePtr{PyMemoryViewObject}(asptr(m)).view)
14+
15+
PyType_CheckBuffer(t) = Base.GC.@preserve t begin
16+
p = UnsafePtr{PyTypeObject}(asptr(t)).as_buffer[]
1517
return p != C_NULL && p.get[!] != C_NULL
1618
end
1719

18-
PyObject_CheckBuffer(o::PyPtr) = PyType_CheckBuffer(Py_Type(o))
20+
PyObject_CheckBuffer(o) = Base.GC.@preserve o PyType_CheckBuffer(Py_Type(asptr(o)))
1921

20-
PyObject_GetBuffer(o::PyPtr, b, flags) = begin
22+
PyObject_GetBuffer(_o, b, flags) = Base.GC.@preserve _o begin
23+
o = asptr(_o)
2124
p = UnsafePtr{PyTypeObject}(Py_Type(o)).as_buffer[]
2225
if p == C_NULL || p.get[!] == C_NULL
2326
PyErr_SetString(
@@ -61,8 +64,8 @@ function PyOS_RunInputHook()
6164
end
6265
end
6366

64-
function PySimpleObject_GetValue(::Type{T}, o::PyPtr) where {T}
65-
UnsafePtr{PySimpleObject{T}}(o).value[!]
67+
function PySimpleObject_GetValue(::Type{T}, o) where {T}
68+
Base.GC.@preserve o UnsafePtr{PySimpleObject{T}}(asptr(o)).value[!]
6669
end
6770

6871
# FAST REFCOUNTING

src/Compat/pycall.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,6 @@ function init_pycall(PyCall::Module)
1616
end
1717
@eval function PyCall.PyObject(x::Py)
1818
C.CTX.matches_pycall::Bool || error($errmsg)
19-
return $PyCall.PyObject($PyCall.PyPtr(incref(getptr(x))))
19+
return $PyCall.PyObject($PyCall.PyPtr(getptr(incref(x))))
2020
end
2121
end

src/Convert/ctypes.jl

Lines changed: 8 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,14 @@
11
struct pyconvert_rule_ctypessimplevalue{R,S} <: Function end
22

33
function (::pyconvert_rule_ctypessimplevalue{R,SAFE})(::Type{T}, x::Py) where {R,SAFE,T}
4-
ptr = Base.GC.@preserve x C.PySimpleObject_GetValue(Ptr{R}, getptr(x))
5-
ans = unsafe_load(ptr)
6-
if SAFE
7-
pyconvert_return(convert(T, ans))
8-
else
9-
pyconvert_tryconvert(T, ans)
4+
Base.GC.@preserve x begin
5+
ptr = C.PySimpleObject_GetValue(Ptr{R}, x)
6+
ans = unsafe_load(ptr)
7+
if SAFE
8+
pyconvert_return(convert(T, ans))
9+
else
10+
pyconvert_tryconvert(T, ans)
11+
end
1012
end
1113
end
1214

src/Convert/numpy.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
struct pyconvert_rule_numpysimplevalue{R,S} <: Function end
22

33
function (::pyconvert_rule_numpysimplevalue{R,SAFE})(::Type{T}, x::Py) where {R,SAFE,T}
4-
ans = Base.GC.@preserve x C.PySimpleObject_GetValue(R, getptr(x))
4+
ans = C.PySimpleObject_GetValue(R, x)
55
if SAFE
66
pyconvert_return(convert(T, ans))
77
else

src/Convert/pyconvert.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,7 @@ function _pyconvert_get_rules(pytype::Py)
238238
end
239239
end
240240
for (t, x) in reverse(collect(zip(mro, xmro)))
241-
if C.PyType_CheckBuffer(getptr(t))
241+
if C.PyType_CheckBuffer(t)
242242
push!(x, "<buffer>")
243243
break
244244
end
@@ -350,7 +350,7 @@ function pytryconvert(::Type{T}, x_) where {T}
350350

351351
# get rules from the cache
352352
# TODO: we should hold weak references and clear the cache if types get deleted
353-
tptr = C.Py_Type(getptr(x))
353+
tptr = C.Py_Type(x)
354354
trules = pyconvert_rules_cache(T)
355355
rules = get!(trules, tptr) do
356356
t = pynew(incref(tptr))

src/Convert/rules.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -61,8 +61,8 @@ pyconvert_rule_bytes(::Type{Base.CodeUnits{UInt8,String}}, x::Py) =
6161
pyconvert_rule_int(::Type{T}, x::Py) where {T<:Number} = begin
6262
# first try to convert to Clonglong (or Culonglong if unsigned)
6363
v =
64-
T <: Unsigned ? C.PyLong_AsUnsignedLongLong(getptr(x)) :
65-
C.PyLong_AsLongLong(getptr(x))
64+
T <: Unsigned ? C.PyLong_AsUnsignedLongLong(x) :
65+
C.PyLong_AsLongLong(x)
6666
if !iserrset_ambig(v)
6767
# success
6868
return pyconvert_tryconvert(T, v)

src/Core/Py.jl

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,11 @@ pyconvert(::Type{Py}, x::Py) = x
5151

5252
setptr!(x::Py, ptr::C.PyPtr) = (setfield!(x, :ptr, ptr); x)
5353

54+
incref(x::Py) = Base.GC.@preserve x (incref(getptr(x)); x)
55+
decref(x::Py) = Base.GC.@preserve x (decref(getptr(x)); x)
56+
57+
Base.unsafe_convert(::Type{C.PyPtr}, x::Py) = getptr(x)
58+
5459
const PYNULL_CACHE = Py[]
5560

5661
"""
@@ -75,7 +80,7 @@ const PyNULL = pynew()
7580

7681
pynew(ptr::C.PyPtr) = setptr!(pynew(), ptr)
7782

78-
pynew(x::Py) = pynew(incref(getptr(x)))
83+
pynew(x::Py) = Base.GC.@preserve x pynew(incref(getptr(x)))
7984

8085
"""
8186
pycopy!(dst::Py, src)
@@ -164,13 +169,13 @@ Base.print(io::IO, x::Py) = print(io, string(x))
164169

165170
function Base.show(io::IO, x::Py)
166171
if get(io, :typeinfo, Any) == Py
167-
if getptr(x) == C.PyNULL
172+
if pyisnull(x)
168173
print(io, "NULL")
169174
else
170175
print(io, pyrepr(String, x))
171176
end
172177
else
173-
if getptr(x) == C.PyNULL
178+
if pyisnull(x)
174179
print(io, "<py NULL>")
175180
else
176181
s = pyrepr(String, x)

0 commit comments

Comments
 (0)