Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 73 additions & 1 deletion brian2tools/mdexport/expander.py
Original file line number Diff line number Diff line change
Expand Up @@ -679,6 +679,73 @@ def expand_equations(self, equations):
rend_eqns += self.expand_equation(var, equation)
return rend_eqns

def _annotate_random_expr(self, value_str, variable_str):
"""
Attempt to annotate an initializer expression that contains a single
``rand()`` or ``randn()`` call with human-readable bounds or
distribution parameters.

For a single ``rand()`` call the expression is evaluated at 0 and 1 to
produce lower/upper bounds: ``variable ∈ [lower, upper]``.

For a single ``randn()`` call the coefficient is recovered via
substitution (works for linear expressions) and the mean/variance of
the implied normal distribution are shown: ``μ=…, σ²=…``.

Returns an annotation string (already wrapped in ``$…$`` or an img
tag) or an empty string when no annotation can be produced.
"""
try:
import sympy
# Count occurrences in the raw string first: sympy deduplicates
# identical function calls (both rand() share the same placeholder
# argument), so expr.atoms() alone cannot distinguish one rand()
# from two rand() calls.
# Use a negative lookahead so that randn() is not counted as rand().
n_rand = len(re.findall(r'\brand(?!n)\s*\(', value_str))
n_randn = len(re.findall(r'\brandn\s*\(', value_str))

expr = str_to_sympy(value_str)
funcs = expr.atoms(sympy.Function)

rand_funcs = [f for f in funcs if f.func.__name__ == 'rand']
randn_funcs = [f for f in funcs if f.func.__name__ == 'randn']

if n_rand == 1 and n_randn == 0 and len(rand_funcs) == 1:
r = rand_funcs[0]
min_val = sympy.simplify(expr.subs(r, 0))
max_val = sympy.simplify(expr.subs(r, 1))

var_tex = sympy.latex(str_to_sympy(variable_str), mode='plain')
min_tex = sympy.latex(min_val, mode='plain')
max_tex = sympy.latex(max_val, mode='plain')
annot_tex = fr"{var_tex} \in [{min_tex}, {max_tex}]"

elif n_randn == 1 and n_rand == 0 and len(randn_funcs) == 1:
r = randn_funcs[0]
# Use substitution rather than sympy.diff, which cannot
# differentiate w.r.t. a Function application reliably.
mean_val = sympy.simplify(expr.subs(r, 0))
# For a linear expression f(randn()) = a*randn() + b the
# coefficient 'a' equals f(1) − f(0), i.e. the std deviation.
coeff = sympy.simplify(expr.subs(r, 1) - expr.subs(r, 0))
var_val = sympy.simplify(coeff ** 2)

mean_tex = sympy.latex(mean_val, mode='plain')
var_tex = sympy.latex(var_val, mode='plain')
annot_tex = fr"\mu={mean_tex},\ \sigma^2={var_tex}"

else:
return ''

if self.github_md:
return (f'<img src="https://render.githubusercontent.com/'
f'render/math?math={annot_tex}">')
return f'${annot_tex}$'

except Exception:
return ''

def expand_initializer(self, initializer):
"""
Expand initializer from initializer dictionary
Expand All @@ -694,11 +761,16 @@ def expand_initializer(self, initializer):
self.render_expression(initializer['variable']))
if self.keep_initializer_order:
init_str += (' of ' + self.expand_SpikeSource(initializer['source']) +
' initialized with ')
' initialized with ')
else:
init_str += '= '
init_str += self.render_expression(initializer['value'])

annot = self._annotate_random_expr(str(initializer['value']),
str(initializer['variable']))
if annot:
init_str += f' (implies {annot})'

# not a good checking
if (isinstance(initializer['index'], str) and
(initializer['index'] != 'True' and initializer['index'] != 'False')):
Expand Down
71 changes: 69 additions & 2 deletions brian2tools/tests/test_mdexport.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ def test_simple_syntax():
condition = 'abs(i-j)<=5'
syn.connect(condition=condition, p=0.999, n=2)
syn.w = '1 * mV'
net = Network(group, spikegen, po_grp, syn)
net = Network(group, spikegen, po_grp, syn)
mon = SpikeMonitor(po_grp)
mon2 = EventMonitor(group, 'custom')
net.add(mon, mon2)
Expand Down Expand Up @@ -316,10 +316,77 @@ def test_user_options():
device.reinit()


def test_rand_annotation():
"""
Verify that a single rand() initializer is annotated with lower/upper
bounds derived by substituting rand() → 0 and rand() → 1.

Example: ``v = El + (V_th - El)*rand()``
should produce an annotation containing ``\\in`` with ``El`` as the lower
and ``V_{th}`` (or equivalent) as the upper bound.
"""
set_device('markdown')
El = -60 * mV
V_th = -50 * mV
group = NeuronGroup(10, 'dv/dt = -v/(10*ms) : volt',
threshold='v > V_th', reset='v = El', method='euler')
group.v = 'El + (V_th - El)*rand()'
run(0 * ms)
md_str = device.md_text
assert _markdown_lint(md_str)
# The annotation must appear and must contain the interval notation.
assert 'implies' in md_str
assert r'\in' in md_str
device.reinit()


def test_randn_annotation():
"""
Verify that a single randn() initializer is annotated with the mean and
variance of the implied normal distribution, computed via substitution.

Example: ``v = El + (randn() * 5 - 5)*mV``
→ mean = El − 5 mV, σ² = (5 mV)² = 25 mV²
The annotation must include ``\\mu`` and ``\\sigma``.
"""
set_device('markdown')
El = -60 * mV
group = NeuronGroup(10, 'dv/dt = -v/(10*ms) : volt',
threshold='False', reset='', method='euler')
group.v = 'El + (randn() * 5 - 5)*mV'
run(0 * ms)
md_str = device.md_text
assert _markdown_lint(md_str)
assert 'implies' in md_str
assert r'\mu' in md_str
assert r'\sigma' in md_str
device.reinit()


def test_no_annotation_for_multiple_rand():
"""
When an expression contains more than one rand() call no annotation should
be emitted (the bounds are not well-defined).
"""
set_device('markdown')
group = NeuronGroup(10, 'dv/dt = -v/(10*ms) : volt',
threshold='False', reset='', method='euler')
# Two rand() calls – annotation must be suppressed
group.v = 'rand() * rand() * mV'
run(0 * ms)
md_str = device.md_text
assert _markdown_lint(md_str)
# "implies" should NOT appear for this initializer
assert 'implies' not in md_str
device.reinit()


if __name__ == '__main__':

test_simple_syntax()
test_common_example()
test_from_papers_example()
test_custom_expander()
test_user_options()
test_rand_annotation()
test_randn_annotation()
test_no_annotation_for_multiple_rand()