From a2ee6e6b0ec015a3a777546192d8af80074d60b0 Mon Sep 17 00:00:00 2001 From: Geoffrey Sneddon Date: Sun, 22 May 2016 04:12:53 +0100 Subject: [PATCH] Fix #228: make sure the lxml treewalker works with trees from lxml --- html5lib/tests/test_treewalkers.py | 21 +++++++++++++++++++++ html5lib/treewalkers/lxmletree.py | 29 ++++++++++++++++++----------- 2 files changed, 39 insertions(+), 11 deletions(-) diff --git a/html5lib/tests/test_treewalkers.py b/html5lib/tests/test_treewalkers.py index 332027ac..81ed2778 100644 --- a/html5lib/tests/test_treewalkers.py +++ b/html5lib/tests/test_treewalkers.py @@ -2,6 +2,11 @@ import pytest +try: + import lxml.etree +except ImportError: + pass + from .support import treeTypes from html5lib import html5parser, treewalkers @@ -93,3 +98,19 @@ def test_treewalker_six_mix(): for tree in sorted(treeTypes.items()): for intext, attrs, expected in sm_tests: yield runTreewalkerEditTest, intext, expected, attrs, tree + + +@pytest.mark.skipif(treeTypes["lxml"] is None, reason="lxml not importable") +def test_lxml_xml(): + expected = [ + {'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'}, + {'data': {}, 'name': 'div', 'namespace': None, 'type': 'StartTag'}, + {'name': 'div', 'namespace': None, 'type': 'EndTag'}, + {'name': 'div', 'namespace': None, 'type': 'EndTag'} + ] + + lxmltree = lxml.etree.fromstring('<div><div></div></div>') + walker = treewalkers.getTreeWalker('lxml') + output = Lint(walker(lxmltree)) + + assert list(output) == expected diff --git a/html5lib/treewalkers/lxmletree.py b/html5lib/treewalkers/lxmletree.py index 7d99adc2..ff31a44e 100644 --- a/html5lib/treewalkers/lxmletree.py +++ b/html5lib/treewalkers/lxmletree.py @@ -22,13 +22,20 @@ class Root(object): def __init__(self, et): self.elementtree = et self.children = [] - if et.docinfo.internalDTD: - self.children.append(Doctype(self, - ensure_str(et.docinfo.root_name), - ensure_str(et.docinfo.public_id), - ensure_str(et.docinfo.system_url))) - root = et.getroot() - node = root + + try: + if et.docinfo.internalDTD: + self.children.append(Doctype(self, + ensure_str(et.docinfo.root_name), + ensure_str(et.docinfo.public_id), + ensure_str(et.docinfo.system_url))) + except AttributeError: + pass + + try: + node = et.getroot() + except AttributeError: + node = et while node.getprevious() is not None: node = node.getprevious() @@ -118,12 +125,12 @@ def __len__(self): class TreeWalker(_base.NonRecursiveTreeWalker): def __init__(self, tree): # pylint:disable=redefined-variable-type - if hasattr(tree, "getroot"): - self.fragmentChildren = set() - tree = Root(tree) - elif isinstance(tree, list): + if isinstance(tree, list): self.fragmentChildren = set(tree) tree = FragmentRoot(tree) + else: + self.fragmentChildren = set() + tree = Root(tree) _base.NonRecursiveTreeWalker.__init__(self, tree) self.filter = ihatexml.InfosetFilter()