Skip to content

Commit 2f1d040

Browse files
committed
Fix #228: make sure the lxml treewalker works with trees from lxml
1 parent 5288737 commit 2f1d040

File tree

2 files changed

+39
-11
lines changed

2 files changed

+39
-11
lines changed

html5lib/tests/test_treewalkers.py

+21
Original file line number | Diff line number | Diff line change
@@ -2,6 +2,11 @@
22

33
import pytest
44

5+
try:
6+
import lxml.etree
7+
except ImportError:
8+
pass
9+
510
from .support import treeTypes
611

712
from html5lib import html5parser, treewalkers
@@ -93,3 +98,19 @@ def test_treewalker_six_mix():
9398
for tree in sorted(treeTypes.items()):
9499
for intext, attrs, expected in sm_tests:
95100
yield runTreewalkerEditTest, intext, expected, attrs, tree
101+
102+
103+
@pytest.mark.skipif(treeTypes["lxml"] is None, reason="lxml not importable")
104+
def test_lxml_xml():
105+
expected = [
106+
{'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'},
107+
{'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'},
108+
{'name': 'div', 'namespace': None, 'type': 'EndTag'},
109+
{'name': 'div', 'namespace': None, 'type': 'EndTag'}
110+
]
111+
112+
lxmltree = lxml.etree.fromstring('<div><div></div></div>')
113+
walker = treewalkers.getTreeWalker('lxml')
114+
output = Lint(walker(lxmltree))
115+
116+
assert list(output) == expected

html5lib/treewalkers/lxmletree.py

+18 -11
Original file line number | Diff line number | Diff line change
@@ -22,13 +22,20 @@ class Root(object):
2222
def __init__(self, et):
2323
self.elementtree = et
2424
self.children = []
25-
if et.docinfo.internalDTD:
26-
self.children.append(Doctype(self,
27-
ensure_str(et.docinfo.root_name),
28-
ensure_str(et.docinfo.public_id),
29-
ensure_str(et.docinfo.system_url)))
30-
root = et.getroot()
31-
node = root
25+
26+
try:
27+
if et.docinfo.internalDTD:
28+
self.children.append(Doctype(self,
29+
ensure_str(et.docinfo.root_name),
30+
ensure_str(et.docinfo.public_id),
31+
ensure_str(et.docinfo.system_url)))
32+
except AttributeError:
33+
pass
34+
35+
try:
36+
node = et.getroot()
37+
except AttributeError:
38+
node = et
3239

3340
while node.getprevious() is not None:
3441
node = node.getprevious()
@@ -118,12 +125,12 @@ def __len__(self):
118125
class TreeWalker(_base.NonRecursiveTreeWalker):
119126
def __init__(self, tree):
120127
# pylint:disable=redefined-variable-type
121-
if hasattr(tree, "getroot"):
122-
self.fragmentChildren = set()
123-
tree = Root(tree)
124-
elif isinstance(tree, list):
128+
if isinstance(tree, list):
125129
self.fragmentChildren = set(tree)
126130
tree = FragmentRoot(tree)
131+
else:
132+
self.fragmentChildren = set()
133+
tree = Root(tree)
127134
_base.NonRecursiveTreeWalker.__init__(self, tree)
128135
self.filter = ihatexml.InfosetFilter()
129136

0 commit comments

Comments
 (0)