diff --git a/brian2tools/mdexport/expander.py b/brian2tools/mdexport/expander.py index 500ae746..d5bcd19f 100644 --- a/brian2tools/mdexport/expander.py +++ b/brian2tools/mdexport/expander.py @@ -20,7 +20,7 @@ select_autoescape, ) from markdown_strings import * -from sympy import Derivative, symbols +from sympy import Derivative, Symbol, symbols from sympy.abc import * from sympy.printing import latex @@ -304,6 +304,40 @@ def render_expression(self, expression, differential=False): # to remove `$` (in most md compiler single $ is used) return rend_exp[1:][:-1] + def _rand_uniform_annotation(self, expression): + """ + Return a simple Uniform(lower, upper) annotation for expressions + containing exactly one rand() call. Otherwise return None. + """ + if not isinstance(expression, str): + return None + + if expression.count('rand()') != 1: + return None + + if 'randn()' in expression: + return None + + try: + expr_str = expression.replace('rand()', 'RANDX') + sym_expr = str_to_sympy(expr_str) + + rand_symbols = [symbol for symbol in sym_expr.free_symbols + if str(symbol) == 'RANDX'] + if len(rand_symbols) != 1: + return None + + rand_sym = rand_symbols[0] + lower = sym_expr.subs(rand_sym, 0) + upper = sym_expr.subs(rand_sym, 1) + + lower_str = self.render_expression(lower) + upper_str = self.render_expression(upper) + + return f' (approximately Uniform({lower_str}, {upper_str}))' + except Exception: + return None + def create_md_string(self, net_dict, template_name): """ Create markdown text by checking the standard dictionary and call @@ -697,29 +731,36 @@ def expand_initializer(self, initializer): ' initialized with ') else: init_str += '= ' - init_str += self.render_expression(initializer['value']) + rendered_value = self.render_expression(initializer['value']) + init_str += rendered_value + + annotation = self._rand_uniform_annotation(initializer['value']) + if annotation is not None: + init_str += annotation # not a good checking - if (isinstance(initializer['index'], str) and - (initializer['index'] != 'True' and initializer['index'] != 'False')): - init_str += ' if ' + self.render_expression(initializer['index']) - elif (isinstance(initializer['index'], bool) or - (isinstance(initializer['index'], str) and - (initializer['index'] == 'True' or - initializer['index'] == 'False'))): - if initializer['index'] is True or initializer['index'] == 'True': + index_value = initializer['index'] + if isinstance(index_value, str): + if index_value not in ('True', 'False'): + init_str += ' if ' + self.render_expression(index_value) + elif index_value == 'True': + init_str += '' # "to all members" implied + else: + raise AssertionError('Initialization with \'False\' as index?') + elif isinstance(index_value, bool): + if index_value: init_str += '' # "to all members" implied else: raise AssertionError('Initialization with \'False\' as index?') else: init_str += (' to member' + - self.check_plural(initializer['index']) + ' ') - if not hasattr(initializer['index'], '__iter___'): - init_str += str(initializer['index']) + self.check_plural(index_value) + ' ') + if not hasattr(index_value, '__iter__'): + init_str += str(index_value) else: init_str += ','.join( - [str(ind) for ind in initializer['index']] - ) + [str(ind) for ind in index_value] + ) if 'identifiers' in initializer: init_str += (', where ' + self.expand_identifiers(initializer['identifiers']) + '.') # pad new line if ordered in list