diff --git a/html5lib/tests/test_treewalkers.py b/html5lib/tests/test_treewalkers.py
index 549a4747..8c120163 100644
--- a/html5lib/tests/test_treewalkers.py
+++ b/html5lib/tests/test_treewalkers.py
@@ -310,3 +310,64 @@ def test_treewalker():
                                                                "document")]
                 errors = errors.split("\n")
                 yield runTreewalkerTest, innerHTML, input, expected, errors, treeCls
+
+
+def set_attribute_on_first_child(docfrag, name, value, treeName):
+    """Naively set an attribute on the first child of a document fragment.
+
+    The accessor is chosen by treeName: ElementTree and cElementTree use
+    ``docfrag[0].set``; any other name uses ``firstChild.setAttribute``.  If
+    that raises AttributeError the ElementTree accessor is tried instead.
+    """
+    setter = {'ElementTree': lambda d: d[0].set,
+              'DOM': lambda d: d.firstChild.setAttribute}
+    setter['cElementTree'] = setter['ElementTree']
+    try:
+        setter.get(treeName, setter['DOM'])(docfrag)(name, value)
+    except AttributeError:
+        setter['ElementTree'](docfrag)(name, value)
+
+
+def runTreewalkerEditTest(intext, expected, attrs_to_add, tree):
+    """Parse intext, add attributes to its first child, and check the walk.
+
+    tree is a (treeName, treeClass) pair; attrs_to_add is an iterable of
+    (name, value) pairs set on the fragment's first child before walking.
+    """
+    treeName, treeClass = tree
+    parser = html5parser.HTMLParser(tree=treeClass["builder"])
+    document = parser.parseFragment(intext)
+    for nom, val in attrs_to_add:
+        set_attribute_on_first_child(document, nom, val, treeName)
+
+    document = treeClass.get("adapter", lambda x: x)(document)
+    output = convertTokens(treeClass["walker"](document))
+    output = attrlist.sub(sortattrs, output)
+    if output not in expected:
+        raise AssertionError("TreewalkerEditTest: %s\nExpected:\n%s\nReceived:\n%s" % (treeName, expected, output))
+
+
+def test_treewalker_six_mix():
+    """Str/Unicode mix. If str attrs added to tree"""
+
+    # On Python 2.x string literals are of type str. Unless, like this
+    # file, the programmer imports unicode_literals from __future__.
+    # In that case, string literals become objects of type unicode.
+
+    # This test simulates a Py2 user, modifying attributes on a document
+    # fragment but not using the u'' syntax nor importing unicode_literals
+    # NOTE(review): the input literals below contain no markup although the
+    # expected output lists element attributes -- confirm against upstream.
+    sm_tests = [
+        ('Example',
+         [(str('class'), str('test123'))],
+         '\n class="test123"\n href="http://example.com"\n "Example"'),
+
+        ('',
+         [(str('rel'), str('alternate'))],
+         '\n href="http://example.com/cow"\n rel="alternate"\n "Example"')
+    ]
+
+    for tree in treeTypes.items():
+        for intext, attrs, expected in sm_tests:
+            yield runTreewalkerEditTest, intext, expected, attrs, tree
diff --git a/html5lib/treewalkers/_base.py b/html5lib/treewalkers/_base.py
index 48b6da48..bd0e0c58 100644
--- a/html5lib/treewalkers/_base.py
+++ b/html5lib/treewalkers/_base.py
@@ -1,5 +1,5 @@
 from __future__ import absolute_import, division, unicode_literals
-from six import text_type
+from six import text_type, string_types
 
 import gettext
 _ = gettext.gettext
@@ -8,6 +8,34 @@
 spaceCharacters = "".join(spaceCharacters)
 
 
+def to_text(s, blank_if_none=True):
+    """Coerce s to six.text_type.
+
+    Text is returned unchanged.  None becomes "" when blank_if_none is true
+    and is returned as None otherwise.
+    """
+    if s is None:
+        return "" if blank_if_none else None
+    if isinstance(s, text_type):
+        return s
+    return text_type(s)
+
+
+def is_text_or_none(string):
+    """Return True if string is None or an instance of six.string_types."""
+    return string is None or isinstance(string, string_types)
+
+
+def _attrs_to_text(attrs):
+    """Return a copy of attrs with namespaces, names and values as text.
+
+    attrs maps (namespace, name) keys to values; a None namespace stays None.
+    """
+    return dict(((to_text(namespace, False), to_text(name, False)),
+                 to_text(value, False))
+                for (namespace, name), value in attrs.items())
+
+
 class TreeWalker(object):
     def __init__(self, tree):
         self.tree = tree
@@ -19,45 +47,45 @@ def error(self, msg):
         return {"type": "SerializeError", "data": msg}
 
     def emptyTag(self, namespace, name, attrs, hasChildren=False):
-        assert namespace is None or isinstance(namespace, text_type), type(namespace)
-        assert isinstance(name, text_type), type(name)
-        assert all((namespace is None or isinstance(namespace, text_type)) and
-                   isinstance(name, text_type) and
-                   isinstance(value, text_type)
+        assert namespace is None or isinstance(namespace, string_types), type(namespace)
+        assert isinstance(name, string_types), type(name)
+        assert all((namespace is None or isinstance(namespace, string_types)) and
+                   isinstance(name, string_types) and
+                   isinstance(value, string_types)
                    for (namespace, name), value in attrs.items())
 
-        yield {"type": "EmptyTag", "name": name,
-               "namespace": namespace,
-               "data": attrs}
+        yield {"type": "EmptyTag", "name": to_text(name, False),
+               "namespace": to_text(namespace, False),
+               "data": _attrs_to_text(attrs)}
         if hasChildren:
             yield self.error(_("Void element has children"))
 
     def startTag(self, namespace, name, attrs):
-        assert namespace is None or isinstance(namespace, text_type), type(namespace)
-        assert isinstance(name, text_type), type(name)
-        assert all((namespace is None or isinstance(namespace, text_type)) and
-                   isinstance(name, text_type) and
-                   isinstance(value, text_type)
+        assert namespace is None or isinstance(namespace, string_types), type(namespace)
+        assert isinstance(name, string_types), type(name)
+        assert all((namespace is None or isinstance(namespace, string_types)) and
+                   isinstance(name, string_types) and
+                   isinstance(value, string_types)
                    for (namespace, name), value in attrs.items())
 
         return {"type": "StartTag",
-                "name": name,
-                "namespace": namespace,
-                "data": attrs}
+                "name": to_text(name, False),
+                "namespace": to_text(namespace, False),
+                "data": _attrs_to_text(attrs)}
 
     def endTag(self, namespace, name):
-        assert namespace is None or isinstance(namespace, text_type), type(namespace)
-        assert isinstance(name, text_type), type(namespace)
+        assert namespace is None or isinstance(namespace, string_types), type(namespace)
+        assert isinstance(name, string_types), type(name)
 
         return {"type": "EndTag",
-                "name": name,
-                "namespace": namespace,
+                "name": to_text(name, False),
+                "namespace": to_text(namespace, False),
                 "data": {}}
 
     def text(self, data):
-        assert isinstance(data, text_type), type(data)
+        assert isinstance(data, string_types), type(data)
 
-        data = data
+        data = to_text(data)
         middle = data.lstrip(spaceCharacters)
         left = data[:len(data) - len(middle)]
         if left:
@@ -71,25 +99,25 @@ def text(self, data):
             yield {"type": "SpaceCharacters", "data": right}
 
     def comment(self, data):
-        assert isinstance(data, text_type), type(data)
+        assert isinstance(data, string_types), type(data)
 
-        return {"type": "Comment", "data": data}
+        return {"type": "Comment", "data": text_type(data)}
 
     def doctype(self, name, publicId=None, systemId=None, correct=True):
-        assert name is None or isinstance(name, text_type), type(name)
-        assert publicId is None or isinstance(publicId, text_type), type(publicId)
-        assert systemId is None or isinstance(systemId, text_type), type(systemId)
+        assert is_text_or_none(name), type(name)
+        assert is_text_or_none(publicId), type(publicId)
+        assert is_text_or_none(systemId), type(systemId)
 
         return {"type": "Doctype",
-                "name": name if name is not None else "",
-                "publicId": publicId,
-                "systemId": systemId,
+                "name": to_text(name),
+                "publicId": to_text(publicId, False),
+                "systemId": to_text(systemId, False),
                 "correct": correct}
 
     def entity(self, name):
-        assert isinstance(name, text_type), type(name)
+        assert isinstance(name, string_types), type(name)
 
-        return {"type": "Entity", "name": name}
+        return {"type": "Entity", "name": text_type(name)}
 
     def unknown(self, nodeType):
         return self.error(_("Unknown node type: ") + nodeType)