Edgewall Software

source: branches/stable/0.5.x/genshi/template/eval.py

Last change on this file was 1010, checked in by cmlenz, 15 years ago

Ported [1008] and [1009] to 0.5.x branch.

  • Property svn:eol-style set to native
File size: 27.3 KB
Line 
1# -*- coding: utf-8 -*-
2#
3# Copyright (C) 2006-2008 Edgewall Software
4# All rights reserved.
5#
6# This software is licensed as described in the file COPYING, which
7# you should have received as part of this distribution. The terms
8# are also available at http://genshi.edgewall.org/wiki/License.
9#
10# This software consists of voluntary contributions made by many
11# individuals. For the exact contribution history, see the revision
12# history and logs, available at http://genshi.edgewall.org/log/.
13
14"""Support for "safe" evaluation of Python expressions."""
15
16import __builtin__
17from compiler import ast, parse
18from compiler.pycodegen import ExpressionCodeGenerator, ModuleCodeGenerator
19import new
20try:
21    set
22except NameError:
23    from sets import ImmutableSet as frozenset
24    from sets import Set as set
25from textwrap import dedent
26from types import CodeType
27
28from genshi.core import Markup
29from genshi.template.base import TemplateRuntimeError
30from genshi.util import flatten
31
32__all__ = ['Code', 'Expression', 'Suite', 'LenientLookup', 'StrictLookup',
33           'Undefined', 'UndefinedError']
34__docformat__ = 'restructuredtext en'
35
36# Check for a Python 2.4 bug in the eval loop
37has_star_import_bug = False
38try:
39    class _FakeMapping(object):
40        __getitem__ = __setitem__ = lambda *a: None
41    exec 'from sys import *' in {}, _FakeMapping()
42except SystemError:
43    has_star_import_bug = True
44except TypeError:
45    pass # Python 2.3
46del _FakeMapping
47
48def _star_import_patch(mapping, modname):
49    """This function is used as helper if a Python version with a broken
50    star-import opcode is in use.
51    """
52    module = __import__(modname, None, None, ['__all__'])
53    if hasattr(module, '__all__'):
54        members = module.__all__
55    else:
56        members = [x for x in module.__dict__ if not x.startswith('_')]
57    mapping.update([(name, getattr(module, name)) for name in members])
58
59
60class Code(object):
61    """Abstract base class for the `Expression` and `Suite` classes."""
62    __slots__ = ['source', 'code', 'ast', '_globals']
63
64    def __init__(self, source, filename=None, lineno=-1, lookup='strict',
65                 xform=None):
66        """Create the code object, either from a string, or from an AST node.
67       
68        :param source: either a string containing the source code, or an AST
69                       node
70        :param filename: the (preferably absolute) name of the file containing
71                         the code
72        :param lineno: the number of the line on which the code was found
73        :param lookup: the lookup class that defines how variables are looked
74                       up in the context; can be either "strict" (the default),
75                       "lenient", or a custom lookup class
76        :param xform: the AST transformer that should be applied to the code;
77                      if `None`, the appropriate transformation is chosen
78                      depending on the mode
79        """
80        if isinstance(source, basestring):
81            self.source = source
82            node = _parse(source, mode=self.mode)
83        else:
84            assert isinstance(source, ast.Node), \
85                'Expected string or AST node, but got %r' % source
86            self.source = '?'
87            if self.mode == 'eval':
88                node = ast.Expression(source)
89            else:
90                node = ast.Module(None, source)
91
92        self.ast = node
93        self.code = _compile(node, self.source, mode=self.mode,
94                             filename=filename, lineno=lineno, xform=xform)
95        if lookup is None:
96            lookup = LenientLookup
97        elif isinstance(lookup, basestring):
98            lookup = {'lenient': LenientLookup, 'strict': StrictLookup}[lookup]
99        self._globals = lookup.globals
100
101    def __getstate__(self):
102        state = {'source': self.source, 'ast': self.ast,
103                 'lookup': self._globals.im_self}
104        c = self.code
105        state['code'] = (c.co_nlocals, c.co_stacksize, c.co_flags, c.co_code,
106                         c.co_consts, c.co_names, c.co_varnames, c.co_filename,
107                         c.co_name, c.co_firstlineno, c.co_lnotab, (), ())
108        return state
109
110    def __setstate__(self, state):
111        self.source = state['source']
112        self.ast = state['ast']
113        self.code = new.code(0, *state['code'])
114        self._globals = state['lookup'].globals
115
116    def __eq__(self, other):
117        return (type(other) == type(self)) and (self.code == other.code)
118
119    def __hash__(self):
120        return hash(self.code)
121
122    def __ne__(self, other):
123        return not self == other
124
125    def __repr__(self):
126        return '%s(%r)' % (self.__class__.__name__, self.source)
127
128
129class Expression(Code):
130    """Evaluates Python expressions used in templates.
131
132    >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
133    >>> Expression('test').evaluate(data)
134    'Foo'
135
136    >>> Expression('items[0]').evaluate(data)
137    1
138    >>> Expression('items[-1]').evaluate(data)
139    3
140    >>> Expression('dict["some"]').evaluate(data)
141    'thing'
142   
143    Similar to e.g. Javascript, expressions in templates can use the dot
144    notation for attribute access to access items in mappings:
145   
146    >>> Expression('dict.some').evaluate(data)
147    'thing'
148   
149    This also works the other way around: item access can be used to access
150    any object attribute:
151   
152    >>> class MyClass(object):
153    ...     myattr = 'Bar'
154    >>> data = dict(mine=MyClass(), key='myattr')
155    >>> Expression('mine.myattr').evaluate(data)
156    'Bar'
157    >>> Expression('mine["myattr"]').evaluate(data)
158    'Bar'
159    >>> Expression('mine[key]').evaluate(data)
160    'Bar'
161   
162    All of the standard Python operators are available to template expressions.
163    Built-in functions such as ``len()`` are also available in template
164    expressions:
165   
166    >>> data = dict(items=[1, 2, 3])
167    >>> Expression('len(items)').evaluate(data)
168    3
169    """
170    __slots__ = []
171    mode = 'eval'
172
173    def evaluate(self, data):
174        """Evaluate the expression against the given data dictionary.
175       
176        :param data: a mapping containing the data to evaluate against
177        :return: the result of the evaluation
178        """
179        __traceback_hide__ = 'before_and_this'
180        _globals = self._globals(data)
181        return eval(self.code, _globals, {'__data__': data})
182
183
184class Suite(Code):
185    """Executes Python statements used in templates.
186
187    >>> data = dict(test='Foo', items=[1, 2, 3], dict={'some': 'thing'})
188    >>> Suite("foo = dict['some']").execute(data)
189    >>> data['foo']
190    'thing'
191    """
192    __slots__ = []
193    mode = 'exec'
194
195    def execute(self, data):
196        """Execute the suite in the given data dictionary.
197       
198        :param data: a mapping containing the data to execute in
199        """
200        __traceback_hide__ = 'before_and_this'
201        _globals = self._globals(data)
202        exec self.code in _globals, data
203
204
205UNDEFINED = object()
206
207
208class UndefinedError(TemplateRuntimeError):
209    """Exception thrown when a template expression attempts to access a variable
210    not defined in the context.
211   
212    :see: `LenientLookup`, `StrictLookup`
213    """
214    def __init__(self, name, owner=UNDEFINED):
215        if owner is not UNDEFINED:
216            message = '%s has no member named "%s"' % (repr(owner), name)
217        else:
218            message = '"%s" not defined' % name
219        TemplateRuntimeError.__init__(self, message)
220
221
222class Undefined(object):
223    """Represents a reference to an undefined variable.
224   
225    Unlike the Python runtime, template expressions can refer to an undefined
226    variable without causing a `NameError` to be raised. The result will be an
227    instance of the `Undefined` class, which is treated the same as ``False`` in
228    conditions, but raise an exception on any other operation:
229   
230    >>> foo = Undefined('foo')
231    >>> bool(foo)
232    False
233    >>> list(foo)
234    []
235    >>> print foo
236    undefined
237   
238    However, calling an undefined variable, or trying to access an attribute
239    of that variable, will raise an exception that includes the name used to
240    reference that undefined variable.
241   
242    >>> foo('bar')
243    Traceback (most recent call last):
244        ...
245    UndefinedError: "foo" not defined
246
247    >>> foo.bar
248    Traceback (most recent call last):
249        ...
250    UndefinedError: "foo" not defined
251   
252    :see: `LenientLookup`
253    """
254    __slots__ = ['_name', '_owner']
255
256    def __init__(self, name, owner=UNDEFINED):
257        """Initialize the object.
258       
259        :param name: the name of the reference
260        :param owner: the owning object, if the variable is accessed as a member
261        """
262        self._name = name
263        self._owner = owner
264
265    def __iter__(self):
266        return iter([])
267
268    def __nonzero__(self):
269        return False
270
271    def __repr__(self):
272        return '<%s %r>' % (self.__class__.__name__, self._name)
273
274    def __str__(self):
275        return 'undefined'
276
277    def _die(self, *args, **kwargs):
278        """Raise an `UndefinedError`."""
279        __traceback_hide__ = True
280        raise UndefinedError(self._name, self._owner)
281    __call__ = __getattr__ = __getitem__ = _die
282
283
284class LookupBase(object):
285    """Abstract base class for variable lookup implementations."""
286
287    def globals(cls, data):
288        """Construct the globals dictionary to use as the execution context for
289        the expression or suite.
290        """
291        return {
292            '__data__': data,
293            '_lookup_name': cls.lookup_name,
294            '_lookup_attr': cls.lookup_attr,
295            '_lookup_item': cls.lookup_item,
296            '_star_import_patch': _star_import_patch,
297            'UndefinedError': UndefinedError,
298        }
299    globals = classmethod(globals)
300
301    def lookup_name(cls, data, name):
302        __traceback_hide__ = True
303        val = data.get(name, UNDEFINED)
304        if val is UNDEFINED:
305            val = BUILTINS.get(name, val)
306            if val is UNDEFINED:
307                val = cls.undefined(name)
308        return val
309    lookup_name = classmethod(lookup_name)
310
311    def lookup_attr(cls, obj, key):
312        __traceback_hide__ = True
313        try:
314            val = getattr(obj, key)
315        except AttributeError:
316            if hasattr(obj.__class__, key):
317                raise
318            else:
319                try:
320                    val = obj[key]
321                except (KeyError, TypeError):
322                    val = cls.undefined(key, owner=obj)
323        return val
324    lookup_attr = classmethod(lookup_attr)
325
326    def lookup_item(cls, obj, key):
327        __traceback_hide__ = True
328        if len(key) == 1:
329            key = key[0]
330        try:
331            return obj[key]
332        except (AttributeError, KeyError, IndexError, TypeError), e:
333            if isinstance(key, basestring):
334                val = getattr(obj, key, UNDEFINED)
335                if val is UNDEFINED:
336                    val = cls.undefined(key, owner=obj)
337                return val
338            raise
339    lookup_item = classmethod(lookup_item)
340
341    def undefined(cls, key, owner=UNDEFINED):
342        """Can be overridden by subclasses to specify behavior when undefined
343        variables are accessed.
344       
345        :param key: the name of the variable
346        :param owner: the owning object, if the variable is accessed as a member
347        """
348        raise NotImplementedError
349    undefined = classmethod(undefined)
350
351
352class LenientLookup(LookupBase):
353    """Default variable lookup mechanism for expressions.
354   
355    When an undefined variable is referenced using this lookup style, the
356    reference evaluates to an instance of the `Undefined` class:
357   
358    >>> expr = Expression('nothing', lookup='lenient')
359    >>> undef = expr.evaluate({})
360    >>> undef
361    <Undefined 'nothing'>
362   
363    The same will happen when a non-existing attribute or item is accessed on
364    an existing object:
365   
366    >>> expr = Expression('something.nil', lookup='lenient')
367    >>> expr.evaluate({'something': dict()})
368    <Undefined 'nil'>
369   
370    See the documentation of the `Undefined` class for details on the behavior
371    of such objects.
372   
373    :see: `StrictLookup`
374    """
375    def undefined(cls, key, owner=UNDEFINED):
376        """Return an ``Undefined`` object."""
377        __traceback_hide__ = True
378        return Undefined(key, owner=owner)
379    undefined = classmethod(undefined)
380
381
382class StrictLookup(LookupBase):
383    """Strict variable lookup mechanism for expressions.
384   
385    Referencing an undefined variable using this lookup style will immediately
386    raise an ``UndefinedError``:
387   
388    >>> expr = Expression('nothing', lookup='strict')
389    >>> expr.evaluate({})
390    Traceback (most recent call last):
391        ...
392    UndefinedError: "nothing" not defined
393   
394    The same happens when a non-existing attribute or item is accessed on an
395    existing object:
396   
397    >>> expr = Expression('something.nil', lookup='strict')
398    >>> expr.evaluate({'something': dict()})
399    Traceback (most recent call last):
400        ...
401    UndefinedError: {} has no member named "nil"
402    """
403    def undefined(cls, key, owner=UNDEFINED):
404        """Raise an ``UndefinedError`` immediately."""
405        __traceback_hide__ = True
406        raise UndefinedError(key, owner=owner)
407    undefined = classmethod(undefined)
408
409
410def _parse(source, mode='eval'):
411    source = source.strip()
412    if mode == 'exec':
413        lines = [line.expandtabs() for line in source.splitlines()]
414        if lines:
415            first = lines[0]
416            rest = dedent('\n'.join(lines[1:])).rstrip()
417            if first.rstrip().endswith(':') and not rest[0].isspace():
418                rest = '\n'.join(['    %s' % line for line in rest.splitlines()])
419            source = '\n'.join([first, rest])
420    if isinstance(source, unicode):
421        source = '\xef\xbb\xbf' + source.encode('utf-8')
422    return parse(source, mode)
423
424def _compile(node, source=None, mode='eval', filename=None, lineno=-1,
425             xform=None):
426    if xform is None:
427        xform = {'eval': ExpressionASTTransformer}.get(mode,
428                                                       TemplateASTTransformer)
429    tree = xform().visit(node)
430    if isinstance(filename, unicode):
431        # unicode file names not allowed for code objects
432        filename = filename.encode('utf-8', 'replace')
433    elif not filename:
434        filename = '<string>'
435    tree.filename = filename
436    if lineno <= 0:
437        lineno = 1
438
439    if mode == 'eval':
440        gen = ExpressionCodeGenerator(tree)
441        name = '<Expression %r>' % (source or '?')
442    else:
443        gen = ModuleCodeGenerator(tree)
444        lines = source.splitlines()
445        if not lines:
446            extract = ''
447        else:
448            extract = lines[0]
449        if len(lines) > 1:
450            extract += ' ...'
451        name = '<Suite %r>' % (extract)
452    gen.optimized = True
453    code = gen.getCode()
454
455    # We'd like to just set co_firstlineno, but it's readonly. So we need to
456    # clone the code object while adjusting the line number
457    return CodeType(0, code.co_nlocals, code.co_stacksize,
458                    code.co_flags | 0x0040, code.co_code, code.co_consts,
459                    code.co_names, code.co_varnames, filename, name, lineno,
460                    code.co_lnotab, (), ())
461
462BUILTINS = __builtin__.__dict__.copy()
463BUILTINS.update({'Markup': Markup, 'Undefined': Undefined})
464CONSTANTS = frozenset(['False', 'True', 'None', 'NotImplemented', 'Ellipsis'])
465
466
467class ASTTransformer(object):
468    """General purpose base class for AST transformations.
469   
470    Every visitor method can be overridden to return an AST node that has been
471    altered or replaced in some way.
472    """
473
474    def visit(self, node):
475        if node is None:
476            return None
477        if type(node) is tuple:
478            return tuple([self.visit(n) for n in node])
479        visitor = getattr(self, 'visit%s' % node.__class__.__name__,
480                          self._visitDefault)
481        return visitor(node)
482
483    def _clone(self, node, *args):
484        lineno = getattr(node, 'lineno', None)
485        node = node.__class__(*args)
486        if lineno is not None:
487            node.lineno = lineno
488        if isinstance(node, (ast.Class, ast.Function, ast.Lambda)) or \
489                hasattr(ast, 'GenExpr') and isinstance(node, ast.GenExpr):
490            node.filename = '<string>' # workaround for bug in pycodegen
491        return node
492
493    def _visitDefault(self, node):
494        return node
495
496    def visitExpression(self, node):
497        return self._clone(node, self.visit(node.node))
498
499    def visitModule(self, node):
500        return self._clone(node, node.doc, self.visit(node.node))
501
502    def visitStmt(self, node):
503        return self._clone(node, [self.visit(x) for x in node.nodes])
504
505    # Classes, Functions & Accessors
506
507    def visitCallFunc(self, node):
508        return self._clone(node, self.visit(node.node),
509            [self.visit(x) for x in node.args],
510            node.star_args and self.visit(node.star_args) or None,
511            node.dstar_args and self.visit(node.dstar_args) or None
512        )
513
514    def visitClass(self, node):
515        return self._clone(node, node.name, [self.visit(x) for x in node.bases],
516            node.doc, self.visit(node.code)
517        )
518
519    def visitFrom(self, node):
520        if not has_star_import_bug or node.names != [('*', None)]:
521            # This is a Python 2.4 bug. Only if we have a broken Python
522            # version we have to apply the hack
523            return node
524        new_node = ast.Discard(ast.CallFunc(
525            ast.Name('_star_import_patch'),
526            [ast.Name('__data__'), ast.Const(node.modname)], None, None
527        ))
528        if hasattr(node, 'lineno'): # No lineno in Python 2.3
529            new_node.lineno = node.lineno
530        return new_node
531
532    def visitFunction(self, node):
533        args = []
534        if hasattr(node, 'decorators'):
535            args.append(self.visit(node.decorators))
536        return self._clone(node, *args + [
537            node.name,
538            node.argnames,
539            [self.visit(x) for x in node.defaults],
540            node.flags,
541            node.doc,
542            self.visit(node.code)
543        ])
544
545    def visitGetattr(self, node):
546        return self._clone(node, self.visit(node.expr), node.attrname)
547
548    def visitLambda(self, node):
549        node = self._clone(node, node.argnames,
550            [self.visit(x) for x in node.defaults], node.flags,
551            self.visit(node.code)
552        )
553        return node
554
555    def visitSubscript(self, node):
556        return self._clone(node, self.visit(node.expr), node.flags,
557            [self.visit(x) for x in node.subs]
558        )
559
560    # Statements
561
562    def visitAssert(self, node):
563        return self._clone(node, self.visit(node.test), self.visit(node.fail))
564
565    def visitAssign(self, node):
566        return self._clone(node, [self.visit(x) for x in node.nodes],
567            self.visit(node.expr)
568        )
569
570    def visitAssAttr(self, node):
571        return self._clone(node, self.visit(node.expr), node.attrname,
572            node.flags
573        )
574
575    def visitAugAssign(self, node):
576        return self._clone(node, self.visit(node.node), node.op,
577            self.visit(node.expr)
578        )
579
580    def visitDecorators(self, node):
581        return self._clone(node, [self.visit(x) for x in node.nodes])
582
583    def visitExec(self, node):
584        return self._clone(node, self.visit(node.expr), self.visit(node.locals),
585            self.visit(node.globals)
586        )
587
588    def visitFor(self, node):
589        return self._clone(node, self.visit(node.assign), self.visit(node.list),
590            self.visit(node.body), self.visit(node.else_)
591        )
592
593    def visitIf(self, node):
594        return self._clone(node, [self.visit(x) for x in node.tests],
595            self.visit(node.else_)
596        )
597
598    def visitImport(self, node):
599        return self._clone(node, node.names)
600
601    def _visitPrint(self, node):
602        return self._clone(node, [self.visit(x) for x in node.nodes],
603            self.visit(node.dest)
604        )
605    visitPrint = visitPrintnl = _visitPrint
606
607    def visitRaise(self, node):
608        return self._clone(node, self.visit(node.expr1), self.visit(node.expr2),
609            self.visit(node.expr3)
610        )
611
612    def visitReturn(self, node):
613        return self._clone(node, self.visit(node.value))
614
615    def visitTryExcept(self, node):
616        return self._clone(node, self.visit(node.body), self.visit(node.handlers),
617            self.visit(node.else_)
618        )
619
620    def visitTryFinally(self, node):
621        return self._clone(node, self.visit(node.body), self.visit(node.final))
622
623    def visitWhile(self, node):
624        return self._clone(node, self.visit(node.test), self.visit(node.body),
625            self.visit(node.else_)
626        )
627
628    def visitWith(self, node):
629        return self._clone(node, self.visit(node.expr),
630            [self.visit(x) for x in node.vars], self.visit(node.body)
631        )
632
633    def visitYield(self, node):
634        return self._clone(node, self.visit(node.value))
635
636    # Operators
637
638    def _visitBoolOp(self, node):
639        return self._clone(node, [self.visit(x) for x in node.nodes])
640    visitAnd = visitOr = visitBitand = visitBitor = visitBitxor = _visitBoolOp
641    visitAssTuple = visitAssList = _visitBoolOp
642
643    def _visitBinOp(self, node):
644        return self._clone(node,
645            (self.visit(node.left), self.visit(node.right))
646        )
647    visitAdd = visitSub = _visitBinOp
648    visitDiv = visitFloorDiv = visitMod = visitMul = visitPower = _visitBinOp
649    visitLeftShift = visitRightShift = _visitBinOp
650
651    def visitCompare(self, node):
652        return self._clone(node, self.visit(node.expr),
653            [(op, self.visit(n)) for op, n in  node.ops]
654        )
655
656    def _visitUnaryOp(self, node):
657        return self._clone(node, self.visit(node.expr))
658    visitUnaryAdd = visitUnarySub = visitNot = visitInvert = _visitUnaryOp
659    visitBackquote = visitDiscard = _visitUnaryOp
660
661    def visitIfExp(self, node):
662        return self._clone(node, self.visit(node.test), self.visit(node.then),
663            self.visit(node.else_)
664        )
665
666    # Identifiers, Literals and Comprehensions
667
668    def visitDict(self, node):
669        return self._clone(node, 
670            [(self.visit(k), self.visit(v)) for k, v in node.items]
671        )
672
673    def visitGenExpr(self, node):
674        return self._clone(node, self.visit(node.code))
675
676    def visitGenExprFor(self, node):
677        return self._clone(node, self.visit(node.assign), self.visit(node.iter),
678            [self.visit(x) for x in node.ifs]
679        )
680
681    def visitGenExprIf(self, node):
682        return self._clone(node, self.visit(node.test))
683
684    def visitGenExprInner(self, node):
685        quals = [self.visit(x) for x in node.quals]
686        return self._clone(node, self.visit(node.expr), quals)
687
688    def visitKeyword(self, node):
689        return self._clone(node, node.name, self.visit(node.expr))
690
691    def visitList(self, node):
692        return self._clone(node, [self.visit(n) for n in node.nodes])
693
694    def visitListComp(self, node):
695        quals = [self.visit(x) for x in node.quals]
696        return self._clone(node, self.visit(node.expr), quals)
697
698    def visitListCompFor(self, node):
699        return self._clone(node, self.visit(node.assign), self.visit(node.list),
700            [self.visit(x) for x in node.ifs]
701        )
702
703    def visitListCompIf(self, node):
704        return self._clone(node, self.visit(node.test))
705
706    def visitSlice(self, node):
707        return self._clone(node, self.visit(node.expr), node.flags,
708            node.lower and self.visit(node.lower) or None,
709            node.upper and self.visit(node.upper) or None
710        )
711
712    def visitSliceobj(self, node):
713        return self._clone(node, [self.visit(x) for x in node.nodes])
714
715    def visitTuple(self, node):
716        return self._clone(node, [self.visit(n) for n in node.nodes])
717
718
719class TemplateASTTransformer(ASTTransformer):
720    """Concrete AST transformer that implements the AST transformations needed
721    for code embedded in templates.
722    """
723
724    def __init__(self):
725        self.locals = [CONSTANTS]
726
727    def visitConst(self, node):
728        if isinstance(node.value, str):
729            try: # If the string is ASCII, return a `str` object
730                node.value.decode('ascii')
731            except ValueError: # Otherwise return a `unicode` object
732                return ast.Const(node.value.decode('utf-8'))
733        return node
734
735    def visitAssName(self, node):
736        if len(self.locals) > 1:
737            self.locals[-1].add(node.name)
738        return node
739
740    def visitAugAssign(self, node):
741        if isinstance(node.node, ast.Name) \
742                and node.node.name not in flatten(self.locals):
743            name = node.node.name
744            node.node = ast.Subscript(ast.Name('__data__'), 'OP_APPLY',
745                                      [ast.Const(name)])
746            node.expr = self.visit(node.expr)
747            return ast.If([
748                (ast.Compare(ast.Const(name), [('in', ast.Name('__data__'))]),
749                 ast.Stmt([node]))],
750                ast.Stmt([ast.Raise(ast.CallFunc(ast.Name('UndefinedError'),
751                                                 [ast.Const(name)]),
752                                    None, None)]))
753        else:
754            return ASTTransformer.visitAugAssign(self, node)
755
756    def visitClass(self, node):
757        if len(self.locals) > 1:
758            self.locals[-1].add(node.name)
759        self.locals.append(set())
760        try:
761            return ASTTransformer.visitClass(self, node)
762        finally:
763            self.locals.pop()
764
765    def visitFrom(self, node):
766        if node.names != [('*', None)]:
767            if len(self.locals) > 1:
768                self.locals[-1].update([n[1] or n[0] for n in node.names])
769        return ASTTransformer.visitFrom(self, node)
770
771    def visitFunction(self, node):
772        if len(self.locals) > 1:
773            self.locals[-1].add(node.name)
774        self.locals.append(set(node.argnames))
775        try:
776            return ASTTransformer.visitFunction(self, node)
777        finally:
778            self.locals.pop()
779
780    def visitGenExpr(self, node):
781        self.locals.append(set())
782        try:
783            return ASTTransformer.visitGenExpr(self, node)
784        finally:
785            self.locals.pop()
786
787    def visitImport(self, node):
788        if len(self.locals) > 1:
789            self.locals[-1].update([n.asname or n.name for n in node.names])
790        return ASTTransformer.visitImport(self, node)
791
792    def visitLambda(self, node):
793        self.locals.append(set(flatten(node.argnames)))
794        try:
795            return ASTTransformer.visitLambda(self, node)
796        finally:
797            self.locals.pop()
798
799    def visitListComp(self, node):
800        self.locals.append(set())
801        try:
802            return ASTTransformer.visitListComp(self, node)
803        finally:
804            self.locals.pop()
805
806    def visitName(self, node):
807        # If the name refers to a local inside a lambda, list comprehension, or
808        # generator expression, leave it alone
809        if node.name not in flatten(self.locals):
810            # Otherwise, translate the name ref into a context lookup
811            func_args = [ast.Name('__data__'), ast.Const(node.name)]
812            node = ast.CallFunc(ast.Name('_lookup_name'), func_args)
813        return node
814
815
816class ExpressionASTTransformer(TemplateASTTransformer):
817    """Concrete AST transformer that implements the AST transformations needed
818    for code embedded in templates.
819    """
820
821    def visitGetattr(self, node):
822        return ast.CallFunc(ast.Name('_lookup_attr'), [
823            self.visit(node.expr),
824            ast.Const(node.attrname)
825        ])
826
827    def visitSubscript(self, node):
828        return ast.CallFunc(ast.Name('_lookup_item'), [
829            self.visit(node.expr),
830            ast.Tuple([self.visit(sub) for sub in node.subs])
831        ])
Note: See TracBrowser for help on using the repository browser.