Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
93 changes: 57 additions & 36 deletions pybars/_compiler.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
except NameError:
# Python 3 support
str_class = str
basestring = str


# Flag for testing
Expand Down Expand Up @@ -163,31 +164,43 @@ class PybarsError(Exception):
pass


class strlist(list):
class strlist(object):
__slots__ = ['value']
def __init__(self, default=None):
self.value = u''
if default:
self.grow(default)

"""A quasi-list to let the template code avoid special casing."""
def __str__(self):
return self.value

def __unicode__(self):
return self.value

def __str__(self): # Python 3
return ''.join(self)
def append(self, other):
self.value += other

def __unicode__(self): # Python 2
return u''.join(self)
def extend(self, other):
if type(other) is strlist:
self.value += other.value
else:
self.value += ''.join(other)

def grow(self, thing):
"""Make the list longer, appending for unicode, extending otherwise."""
if type(thing) == str_class:
self.append(thing)
def __iter__(self):
return iter([self])

# This will only ever match in Python 2 since str_class is str in
# Python 3.
elif type(thing) == str:
self.append(unicode(thing))
def __add__(self, other):
self.value += other.value
return self

def grow(self, other):
if isinstance(other, basestring):
self.value += other
elif type(other) is strlist:
self.value += other.value
else:
# Recursively expand to a flat list; may deserve a C accelerator at
# some point.
for element in thing:
self.grow(element)
for item in other:
self.grow(item)


_map = {
Expand All @@ -212,14 +225,15 @@ def escape(something, _escape_re=_escape_re, substitute=substitute):


def pick(context, name, default=None):
if isinstance(name, str) and hasattr(context, name):
if type(name) is str and hasattr(context, name):
return getattr(context, name)
if hasattr(context, 'get'):
return context.get(name)
try:
return context[name]
except (KeyError, TypeError):
return default
pass
return default


sentinel = object()
Expand Down Expand Up @@ -270,34 +284,40 @@ def __unicode__(self):
return unicode(self.context)


ITERABLE_TYPES = (list, tuple)

def resolve(context, *segments):
carryover_data = False

context_type = type(context)
# This makes sure that bare "this" paths don't return a Scope object
if segments == ('',) and isinstance(context, Scope):
if segments == ('',) and context_type is Scope:
return context.get('this')

carryover_data = False
for segment in segments:

if context is None:
return None

# Handle @../index syntax by popping the extra @ along the segment path
if carryover_data:
segment = u'@' + segment
carryover_data = False
segment = u'@%s' % segment
if len(segment) > 1 and segment[0:2] == '@@':

if segment[:2] == '@@':
segment = segment[1:]
carryover_data = True

if context is None:
return None
if segment in (None, ""):
if not segment:
continue
if type(context) in (list, tuple):
offset = int(segment)
context = context[offset]
elif isinstance(context, Scope):

if context_type is Scope:
context = context.get(segment)
elif context_type in ITERABLE_TYPES:
context = context[int(segment)]
else:
context = pick(context, segment)
context_type = type(context)
return context


Expand Down Expand Up @@ -336,7 +356,7 @@ def prepare(value, should_escape):


def ensure_scope(context, root):
return context if isinstance(context, Scope) else Scope(context, context, root)
return context if type(context) is Scope else Scope(context, context, root)


def _each(this, options, context):
Expand Down Expand Up @@ -509,8 +529,8 @@ def start(self):
])
else:
self._result.grow(u"def %s(context, helpers, partials, root):\n" % function_name)
self._result.grow(u" result = strlist()\n")
self._result.grow(u" context = ensure_scope(context, root)\n")
self._result.grow(u" result = strlist()\n")

def finish(self):
lines, ns, function_name = self.stack.pop(-1)
Expand All @@ -521,7 +541,7 @@ def finish(self):
self._result.grow(u" result = %s(result)\n" % str_class.__name__)
self._result.grow(u" return result\n")

source = str_class(u"".join(lines))
source = str_class(lines)

self._result = self.stack and self.stack[-1][0]
self._locals = self.stack and self.stack[-1][1]
Expand Down Expand Up @@ -578,11 +598,12 @@ def add_block(self, symbol, arguments, nested, alt_nested):
u" value = helper(context, options%s\n" % call,
u" else:\n"
u" value = helpers['blockHelperMissing'](context, options, value)\n"
u" result.grow(value or '')\n"
u" if value:\n"
u" result.grow(value)\n"
])

def add_literal(self, value):
self._result.grow(u" result.append(%s)\n" % repr(value))
self._result.grow(u" result.value += %s\n" % (repr(value),))

def _lookup_arg(self, arg):
if not arg:
Expand Down