diff --git a/html5lib/tests/test_treewalkers.py b/html5lib/tests/test_treewalkers.py
index 332027ac..81ed2778 100644
--- a/html5lib/tests/test_treewalkers.py
+++ b/html5lib/tests/test_treewalkers.py
@@ -2,6 +2,11 @@
import pytest
+try:
+ import lxml.etree
+except ImportError:
+ pass
+
from .support import treeTypes
from html5lib import html5parser, treewalkers
@@ -93,3 +98,19 @@ def test_treewalker_six_mix():
for tree in sorted(treeTypes.items()):
for intext, attrs, expected in sm_tests:
yield runTreewalkerEditTest, intext, expected, attrs, tree
+
+
+@pytest.mark.skipif(treeTypes["lxml"] is None, reason="lxml not importable")
+def test_lxml_xml():
+ expected = [
+ {'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'},
+ {'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'},
+ {'name': 'div', 'namespace': None, 'type': 'EndTag'},
+ {'name': 'div', 'namespace': None, 'type': 'EndTag'}
+ ]
+
+ lxmltree = lxml.etree.fromstring('<div><div></div></div>')
+ walker = treewalkers.getTreeWalker('lxml')
+ output = Lint(walker(lxmltree))
+
+ assert list(output) == expected
diff --git a/html5lib/treewalkers/lxmletree.py b/html5lib/treewalkers/lxmletree.py
index 7d99adc2..ff31a44e 100644
--- a/html5lib/treewalkers/lxmletree.py
+++ b/html5lib/treewalkers/lxmletree.py
@@ -22,13 +22,20 @@ class Root(object):
def __init__(self, et):
self.elementtree = et
self.children = []
- if et.docinfo.internalDTD:
- self.children.append(Doctype(self,
- ensure_str(et.docinfo.root_name),
- ensure_str(et.docinfo.public_id),
- ensure_str(et.docinfo.system_url)))
- root = et.getroot()
- node = root
+
+ try:
+ if et.docinfo.internalDTD:
+ self.children.append(Doctype(self,
+ ensure_str(et.docinfo.root_name),
+ ensure_str(et.docinfo.public_id),
+ ensure_str(et.docinfo.system_url)))
+ except AttributeError:
+ pass
+
+ try:
+ node = et.getroot()
+ except AttributeError:
+ node = et
while node.getprevious() is not None:
node = node.getprevious()
@@ -118,12 +125,12 @@ def __len__(self):
class TreeWalker(_base.NonRecursiveTreeWalker):
def __init__(self, tree):
# pylint:disable=redefined-variable-type
- if hasattr(tree, "getroot"):
- self.fragmentChildren = set()
- tree = Root(tree)
- elif isinstance(tree, list):
+ if isinstance(tree, list):
self.fragmentChildren = set(tree)
tree = FragmentRoot(tree)
+ else:
+ self.fragmentChildren = set()
+ tree = Root(tree)
_base.NonRecursiveTreeWalker.__init__(self, tree)
self.filter = ihatexml.InfosetFilter()