diff --git a/html5lib/tests/test_treewalkers.py b/html5lib/tests/test_treewalkers.py
index 549a4747..8c120163 100644
--- a/html5lib/tests/test_treewalkers.py
+++ b/html5lib/tests/test_treewalkers.py
@@ -310,3 +310,54 @@ def test_treewalker():
"document")]
errors = errors.split("\n")
yield runTreewalkerTest, innerHTML, input, expected, errors, treeCls
+
+
+def set_attribute_on_first_child(docfrag, name, value, treeName):
+ """Naively set attribute name=value on the first child of docfrag,
+ via the DOM API unless treeName is (c)ElementTree; on AttributeError
+ the ElementTree API is used instead"""
+ use_etree = treeName in ('ElementTree', 'cElementTree')
+ try:
+ first_set = docfrag[0].set if use_etree else docfrag.firstChild.setAttribute
+ first_set(name, value)
+ except AttributeError:
+ docfrag[0].set(name, value)
+
+
+def runTreewalkerEditTest(intext, expected, attrs_to_add, tree):
+ """Add attrs_to_add to the parsed intext and check the walker output occurs in expected"""
+ treeName, treeClass = tree
+ parser = html5parser.HTMLParser(tree=treeClass["builder"])
+ document = parser.parseFragment(intext)
+ for nom, val in attrs_to_add:
+ set_attribute_on_first_child(document, nom, val, treeName)
+
+ document = treeClass.get("adapter", lambda x: x)(document)
+ output = convertTokens(treeClass["walker"](document))
+ output = attrlist.sub(sortattrs, output)
+ if output not in expected:  # membership test, not equality
+ raise AssertionError("TreewalkerEditTest: %s\nExpected:\n%s\nReceived:\n%s" % (treeName, expected, output))
+
+
+def test_treewalker_six_mix():
+ """Yield edit tests that add native str attributes to parsed trees"""
+
+ # This module imports unicode_literals from __future__, so its plain
+ # string literals are unicode objects even on Python 2.x, where they
+ # would otherwise be of type str.
+
+ # Wrapping the attribute names and values in str() simulates a Py2 user
+ # who edits a fragment without the u'' syntax or unicode_literals.
+ cases = [
+ ('Example',
+ [(str('class'), str('test123'))],
+ '\n class="test123"\n href="http://example.com"\n "Example"'),
+
+ ('',
+ [(str('rel'), str('alternate'))],
+ '\n href="http://example.com/cow"\n rel="alternate"\n "Example"')
+ ]
+
+ for tree_item in treeTypes.items():
+ for markup, new_attrs, wanted in cases:
+ yield runTreewalkerEditTest, markup, wanted, new_attrs, tree_item
diff --git a/html5lib/treewalkers/_base.py b/html5lib/treewalkers/_base.py
index 48b6da48..bd0e0c58 100644
--- a/html5lib/treewalkers/_base.py
+++ b/html5lib/treewalkers/_base.py
@@ -1,5 +1,5 @@
from __future__ import absolute_import, division, unicode_literals
-from six import text_type
+from six import text_type, string_types
import gettext
_ = gettext.gettext
@@ -8,6 +8,24 @@
spaceCharacters = "".join(spaceCharacters)
+def to_text(s, blank_if_none=True):
+ """Coerce s to six.text_type.
+
+ None becomes "" when blank_if_none is true and stays None otherwise;
+ text is returned unchanged; anything else goes through text_type().
+ """
+ if s is None:
+ return "" if blank_if_none else None
+ if isinstance(s, text_type):
+ return s
+ return text_type(s)
+
+
+def is_text_or_none(string):
+ """Return True when string is an instance of six.string_types or is None"""
+ return isinstance(string, string_types) or string is None
+
+
class TreeWalker(object):
def __init__(self, tree):
self.tree = tree
@@ -19,45 +37,47 @@ def error(self, msg):
return {"type": "SerializeError", "data": msg}
def emptyTag(self, namespace, name, attrs, hasChildren=False):
- assert namespace is None or isinstance(namespace, text_type), type(namespace)
- assert isinstance(name, text_type), type(name)
- assert all((namespace is None or isinstance(namespace, text_type)) and
- isinstance(name, text_type) and
- isinstance(value, text_type)
+ assert namespace is None or isinstance(namespace, string_types), type(namespace)
+ assert isinstance(name, string_types), type(name)
+ assert all((namespace is None or isinstance(namespace, string_types)) and
+ isinstance(name, string_types) and
+ isinstance(value, string_types)
for (namespace, name), value in attrs.items())
- yield {"type": "EmptyTag", "name": name,
- "namespace": namespace,
+ yield {"type": "EmptyTag", "name": to_text(name, False),
+ "namespace": to_text(namespace, False),  # a None namespace stays None, as before
"data": attrs}
if hasChildren:
yield self.error(_("Void element has children"))
def startTag(self, namespace, name, attrs):
- assert namespace is None or isinstance(namespace, text_type), type(namespace)
- assert isinstance(name, text_type), type(name)
- assert all((namespace is None or isinstance(namespace, text_type)) and
- isinstance(name, text_type) and
- isinstance(value, text_type)
+ assert namespace is None or isinstance(namespace, string_types), type(namespace)
+ assert isinstance(name, string_types), type(name)
+ assert all((namespace is None or isinstance(namespace, string_types)) and
+ isinstance(name, string_types) and
+ isinstance(value, string_types)
for (namespace, name), value in attrs.items())
return {"type": "StartTag",
- "name": name,
- "namespace": namespace,
- "data": attrs}
+ "name": to_text(name, False),
+ "namespace": to_text(namespace, False),  # a None namespace stays None, as before
+ "data": dict(((to_text(namespace, False), to_text(name, False)),
+ to_text(value, False))
+ for (namespace, name), value in attrs.items())}
def endTag(self, namespace, name):
- assert namespace is None or isinstance(namespace, text_type), type(namespace)
- assert isinstance(name, text_type), type(namespace)
+ assert namespace is None or isinstance(namespace, string_types), type(namespace)
+ assert isinstance(name, string_types), type(name)  # report the offending name's type
return {"type": "EndTag",
- "name": name,
- "namespace": namespace,
+ "name": to_text(name, False),
+ "namespace": to_text(namespace, False),  # a None namespace stays None, as before
"data": {}}
def text(self, data):
- assert isinstance(data, text_type), type(data)
+ assert isinstance(data, string_types), type(data)
- data = data
+ data = to_text(data)
middle = data.lstrip(spaceCharacters)
left = data[:len(data) - len(middle)]
if left:
@@ -71,25 +91,25 @@ def text(self, data):
yield {"type": "SpaceCharacters", "data": right}
def comment(self, data):
- assert isinstance(data, text_type), type(data)
+ assert isinstance(data, string_types), type(data)
- return {"type": "Comment", "data": data}
+ return {"type": "Comment", "data": to_text(data, False)}  # same helper as the tag tokens
def doctype(self, name, publicId=None, systemId=None, correct=True):
- assert name is None or isinstance(name, text_type), type(name)
- assert publicId is None or isinstance(publicId, text_type), type(publicId)
- assert systemId is None or isinstance(systemId, text_type), type(systemId)
+ assert is_text_or_none(name), type(name)
+ assert is_text_or_none(publicId), type(publicId)
+ assert is_text_or_none(systemId), type(systemId)
return {"type": "Doctype",
- "name": name if name is not None else "",
- "publicId": publicId,
- "systemId": systemId,
- "correct": correct}
+ "name": to_text(name),  # only the name is blanked when None
+ "publicId": to_text(publicId, False),  # None stays None, as before
+ "systemId": to_text(systemId, False),  # None stays None, as before
+ "correct": correct}  # a flag, not text: "False" would be truthy
def entity(self, name):
- assert isinstance(name, text_type), type(name)
+ assert isinstance(name, string_types), type(name)
- return {"type": "Entity", "name": name}
+ return {"type": "Entity", "name": to_text(name, False)}  # same helper as the tag tokens
def unknown(self, nodeType):
return self.error(_("Unknown node type: ") + nodeType)