@@ -22,13 +22,20 @@ class Root(object):
22
22
def __init__ (self , et ):
23
23
self .elementtree = et
24
24
self .children = []
25
- if et .docinfo .internalDTD :
26
- self .children .append (Doctype (self ,
27
- ensure_str (et .docinfo .root_name ),
28
- ensure_str (et .docinfo .public_id ),
29
- ensure_str (et .docinfo .system_url )))
30
- root = et .getroot ()
31
- node = root
25
+
26
+ try :
27
+ if et .docinfo .internalDTD :
28
+ self .children .append (Doctype (self ,
29
+ ensure_str (et .docinfo .root_name ),
30
+ ensure_str (et .docinfo .public_id ),
31
+ ensure_str (et .docinfo .system_url )))
32
+ except AttributeError :
33
+ pass
34
+
35
+ try :
36
+ node = et .getroot ()
37
+ except AttributeError :
38
+ node = et
32
39
33
40
while node .getprevious () is not None :
34
41
node = node .getprevious ()
@@ -118,12 +125,12 @@ def __len__(self):
118
125
class TreeWalker (_base .NonRecursiveTreeWalker ):
119
126
def __init__ (self , tree ):
120
127
# pylint:disable=redefined-variable-type
121
- if hasattr (tree , "getroot" ):
122
- self .fragmentChildren = set ()
123
- tree = Root (tree )
124
- elif isinstance (tree , list ):
128
+ if isinstance (tree , list ):
125
129
self .fragmentChildren = set (tree )
126
130
tree = FragmentRoot (tree )
131
+ else :
132
+ self .fragmentChildren = set ()
133
+ tree = Root (tree )
127
134
_base .NonRecursiveTreeWalker .__init__ (self , tree )
128
135
self .filter = ihatexml .InfosetFilter ()
129
136
0 commit comments