mirror of
https://github.com/yt-dlp/yt-dlp.git
synced 2024-12-22 18:17:17 -05:00
[utils] traverse_obj
: Support xml.etree.ElementTree.Element
(#8911)
Authored by: Grub4K
This commit is contained in:
parent
292d60b1ed
commit
ffbd4f2a02
@ -2340,6 +2340,58 @@ Line 1
|
|||||||
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
|
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
|
||||||
msg='function on a `re.Match` should give group name as well')
|
msg='function on a `re.Match` should give group name as well')
|
||||||
|
|
||||||
|
# Test xml.etree.ElementTree.Element as input obj
|
||||||
|
etree = xml.etree.ElementTree.fromstring('''<?xml version="1.0"?>
|
||||||
|
<data>
|
||||||
|
<country name="Liechtenstein">
|
||||||
|
<rank>1</rank>
|
||||||
|
<year>2008</year>
|
||||||
|
<gdppc>141100</gdppc>
|
||||||
|
<neighbor name="Austria" direction="E"/>
|
||||||
|
<neighbor name="Switzerland" direction="W"/>
|
||||||
|
</country>
|
||||||
|
<country name="Singapore">
|
||||||
|
<rank>4</rank>
|
||||||
|
<year>2011</year>
|
||||||
|
<gdppc>59900</gdppc>
|
||||||
|
<neighbor name="Malaysia" direction="N"/>
|
||||||
|
</country>
|
||||||
|
<country name="Panama">
|
||||||
|
<rank>68</rank>
|
||||||
|
<year>2011</year>
|
||||||
|
<gdppc>13600</gdppc>
|
||||||
|
<neighbor name="Costa Rica" direction="W"/>
|
||||||
|
<neighbor name="Colombia" direction="E"/>
|
||||||
|
</country>
|
||||||
|
</data>''')
|
||||||
|
self.assertEqual(traverse_obj(etree, ''), etree,
|
||||||
|
msg='empty str key should return the element itself')
|
||||||
|
self.assertEqual(traverse_obj(etree, 'country'), list(etree),
|
||||||
|
msg='str key should lead all children with that tag name')
|
||||||
|
self.assertEqual(traverse_obj(etree, ...), list(etree),
|
||||||
|
msg='`...` as key should return all children')
|
||||||
|
self.assertEqual(traverse_obj(etree, lambda _, x: x[0].text == '4'), [etree[1]],
|
||||||
|
msg='function as key should get element as value')
|
||||||
|
self.assertEqual(traverse_obj(etree, lambda i, _: i == 1), [etree[1]],
|
||||||
|
msg='function as key should get index as key')
|
||||||
|
self.assertEqual(traverse_obj(etree, 0), etree[0],
|
||||||
|
msg='int key should return the nth child')
|
||||||
|
self.assertEqual(traverse_obj(etree, './/neighbor/@name'),
|
||||||
|
['Austria', 'Switzerland', 'Malaysia', 'Costa Rica', 'Colombia'],
|
||||||
|
msg='`@<attribute>` at end of path should give that attribute')
|
||||||
|
self.assertEqual(traverse_obj(etree, '//neighbor/@fail'), [None, None, None, None, None],
|
||||||
|
msg='`@<nonexistant>` at end of path should give `None`')
|
||||||
|
self.assertEqual(traverse_obj(etree, ('//neighbor/@', 2)), {'name': 'Malaysia', 'direction': 'N'},
|
||||||
|
msg='`@` should give the full attribute dict')
|
||||||
|
self.assertEqual(traverse_obj(etree, '//year/text()'), ['2008', '2011', '2011'],
|
||||||
|
msg='`text()` at end of path should give the inner text')
|
||||||
|
self.assertEqual(traverse_obj(etree, '//*[@direction]/@direction'), ['E', 'W', 'N', 'W', 'E'],
|
||||||
|
msg='full python xpath features should be supported')
|
||||||
|
self.assertEqual(traverse_obj(etree, (0, '@name')), 'Liechtenstein',
|
||||||
|
msg='special transformations should act on current element')
|
||||||
|
self.assertEqual(traverse_obj(etree, ('country', 0, ..., 'text()', {int_or_none})), [1, 2008, 141100],
|
||||||
|
msg='special transformations should act on current element')
|
||||||
|
|
||||||
def test_http_header_dict(self):
|
def test_http_header_dict(self):
|
||||||
headers = HTTPHeaderDict()
|
headers = HTTPHeaderDict()
|
||||||
headers['ytdl-test'] = b'0'
|
headers['ytdl-test'] = b'0'
|
||||||
|
@ -3,6 +3,7 @@ import contextlib
|
|||||||
import inspect
|
import inspect
|
||||||
import itertools
|
import itertools
|
||||||
import re
|
import re
|
||||||
|
import xml.etree.ElementTree
|
||||||
|
|
||||||
from ._utils import (
|
from ._utils import (
|
||||||
IDENTITY,
|
IDENTITY,
|
||||||
@ -118,7 +119,7 @@ def traverse_obj(
|
|||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
result = obj.values()
|
result = obj.values()
|
||||||
elif is_iterable_like(obj):
|
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
|
||||||
result = obj
|
result = obj
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
result = obj.groups()
|
result = obj.groups()
|
||||||
@ -132,7 +133,7 @@ def traverse_obj(
|
|||||||
branching = True
|
branching = True
|
||||||
if isinstance(obj, collections.abc.Mapping):
|
if isinstance(obj, collections.abc.Mapping):
|
||||||
iter_obj = obj.items()
|
iter_obj = obj.items()
|
||||||
elif is_iterable_like(obj):
|
elif is_iterable_like(obj) or isinstance(obj, xml.etree.ElementTree.Element):
|
||||||
iter_obj = enumerate(obj)
|
iter_obj = enumerate(obj)
|
||||||
elif isinstance(obj, re.Match):
|
elif isinstance(obj, re.Match):
|
||||||
iter_obj = itertools.chain(
|
iter_obj = itertools.chain(
|
||||||
@ -168,7 +169,7 @@ def traverse_obj(
|
|||||||
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
result = next((v for k, v in obj.groupdict().items() if casefold(k) == key), None)
|
||||||
|
|
||||||
elif isinstance(key, (int, slice)):
|
elif isinstance(key, (int, slice)):
|
||||||
if is_iterable_like(obj, collections.abc.Sequence):
|
if is_iterable_like(obj, (collections.abc.Sequence, xml.etree.ElementTree.Element)):
|
||||||
branching = isinstance(key, slice)
|
branching = isinstance(key, slice)
|
||||||
with contextlib.suppress(IndexError):
|
with contextlib.suppress(IndexError):
|
||||||
result = obj[key]
|
result = obj[key]
|
||||||
@ -176,6 +177,34 @@ def traverse_obj(
|
|||||||
with contextlib.suppress(IndexError):
|
with contextlib.suppress(IndexError):
|
||||||
result = str(obj)[key]
|
result = str(obj)[key]
|
||||||
|
|
||||||
|
elif isinstance(obj, xml.etree.ElementTree.Element) and isinstance(key, str):
|
||||||
|
xpath, _, special = key.rpartition('/')
|
||||||
|
if not special.startswith('@') and special != 'text()':
|
||||||
|
xpath = key
|
||||||
|
special = None
|
||||||
|
|
||||||
|
# Allow abbreviations of relative paths, absolute paths error
|
||||||
|
if xpath.startswith('/'):
|
||||||
|
xpath = f'.{xpath}'
|
||||||
|
elif xpath and not xpath.startswith('./'):
|
||||||
|
xpath = f'./{xpath}'
|
||||||
|
|
||||||
|
def apply_specials(element):
|
||||||
|
if special is None:
|
||||||
|
return element
|
||||||
|
if special == '@':
|
||||||
|
return element.attrib
|
||||||
|
if special.startswith('@'):
|
||||||
|
return try_call(element.attrib.get, args=(special[1:],))
|
||||||
|
if special == 'text()':
|
||||||
|
return element.text
|
||||||
|
assert False, f'apply_specials is missing case for {special!r}'
|
||||||
|
|
||||||
|
if xpath:
|
||||||
|
result = list(map(apply_specials, obj.iterfind(xpath)))
|
||||||
|
else:
|
||||||
|
result = apply_specials(obj)
|
||||||
|
|
||||||
return branching, result if branching else (result,)
|
return branching, result if branching else (result,)
|
||||||
|
|
||||||
def lazy_last(iterable):
|
def lazy_last(iterable):
|
||||||
|
Loading…
Reference in New Issue
Block a user