[utils] traverse_obj: Various improvements

- Add `set` key for transformations/filters
- Add `re.Match` group names
- Fix behavior for `expected_type` with `dict` key
- Raise for filter function signature mismatch in debug

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2023-02-02 06:40:19 +01:00 committed by GitHub
parent 8b008d6254
commit 776995bc10
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 88 additions and 10 deletions

View File

@ -105,6 +105,7 @@ from yt_dlp.utils import (
sanitized_Request, sanitized_Request,
shell_quote, shell_quote,
smuggle_url, smuggle_url,
str_or_none,
str_to_int, str_to_int,
strip_jsonp, strip_jsonp,
strip_or_none, strip_or_none,
@ -2015,6 +2016,29 @@ Line 1
msg='function as query key should perform a filter based on (key, value)') msg='function as query key should perform a filter based on (key, value)')
self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'}, self.assertCountEqual(traverse_obj(_TEST_DATA, lambda _, x: isinstance(x[0], str)), {'str'},
msg='exceptions in the query function should be catched') msg='exceptions in the query function should be catched')
if __debug__:
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a: ...)
with self.assertRaises(Exception, msg='Wrong function signature should raise in debug'):
traverse_obj(_TEST_DATA, lambda a, b, c: ...)
# Test set as key (transformation/type, like `expected_type`)
self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper}, )), ['STR'],
msg='Function in set should be a transformation')
self.assertEqual(traverse_obj(_TEST_DATA, (..., {str})), ['str'],
msg='Type in set should be a type filter')
self.assertEqual(traverse_obj(_TEST_DATA, {dict}), _TEST_DATA,
msg='A single set should be wrapped into a path')
self.assertEqual(traverse_obj(_TEST_DATA, (..., {str.upper})), ['STR'],
msg='Transformation function should not raise')
self.assertEqual(traverse_obj(_TEST_DATA, (..., {str_or_none})),
[item for item in map(str_or_none, _TEST_DATA.values()) if item is not None],
msg='Function in set should be a transformation')
if __debug__:
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, set())
with self.assertRaises(Exception, msg='Sets with length != 1 should raise in debug'):
traverse_obj(_TEST_DATA, {str.upper, str})
# Test alternative paths # Test alternative paths
self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str', self.assertEqual(traverse_obj(_TEST_DATA, 'fail', 'str'), 'str',
@ -2106,6 +2130,20 @@ Line 1
msg='wrap expected_type fuction in try_call') msg='wrap expected_type fuction in try_call')
self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'], self.assertEqual(traverse_obj(_EXPECTED_TYPE_DATA, ..., expected_type=str), ['str'],
msg='eliminate items that expected_type fails on') msg='eliminate items that expected_type fails on')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2}, expected_type=int), {0: 100},
msg='type as expected_type should filter dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: 100, 1: 1.2, 2: 'None'}, expected_type=str_or_none), {0: '100', 1: '1.2'},
msg='function as expected_type should transform dict values')
self.assertEqual(traverse_obj(_TEST_DATA, ({0: 1.2}, 0, {int_or_none}), expected_type=int), 1,
msg='expected_type should not filter non final dict values')
self.assertEqual(traverse_obj(_TEST_DATA, {0: {0: 100, 1: 'str'}}, expected_type=int), {0: {0: 100}},
msg='expected_type should transform deep dict values')
self.assertEqual(traverse_obj(_TEST_DATA, [({0: '...'}, {0: '...'})], expected_type=type(...)), [{0: ...}, {0: ...}],
msg='expected_type should transform branched dict values')
self.assertEqual(traverse_obj({1: {3: 4}}, [(1, 2), 3], expected_type=int), [4],
msg='expected_type regression for type matching in tuple branching')
self.assertEqual(traverse_obj(_TEST_DATA, ['data', ...], expected_type=int), [],
msg='expected_type regression for type matching in dict result')
# Test get_all behavior # Test get_all behavior
_GET_ALL_DATA = {'key': [0, 1, 2]} _GET_ALL_DATA = {'key': [0, 1, 2]}
@ -2189,6 +2227,8 @@ Line 1
msg='failing str key on a `re.Match` should return `default`') msg='failing str key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, 8), None, self.assertEqual(traverse_obj(mobj, 8), None,
msg='failing int key on a `re.Match` should return `default`') msg='failing int key on a `re.Match` should return `default`')
self.assertEqual(traverse_obj(mobj, lambda k, _: k in (0, 'group')), ['0123', '3'],
msg='function on a `re.Match` should give group name as well')
if __name__ == '__main__': if __name__ == '__main__':

View File

@ -5424,6 +5424,9 @@ def traverse_obj(
The keys in the path can be one of: The keys in the path can be one of:
- `None`: Return the current object. - `None`: Return the current object.
- `set`: Requires the only item in the set to be a type or function,
like `{type}`/`{func}`. If a `type`, returns only values
of this type. If a function, returns `func(obj)`.
- `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`. - `str`/`int`: Return `obj[key]`. For `re.Match`, return `obj.group(key)`.
- `slice`: Branch out and return all values in `obj[key]`. - `slice`: Branch out and return all values in `obj[key]`.
- `Ellipsis`: Branch out and return a list of all values. - `Ellipsis`: Branch out and return a list of all values.
@ -5432,6 +5435,8 @@ def traverse_obj(
- `function`: Branch out and return values filtered by the function. - `function`: Branch out and return values filtered by the function.
Read as: `[value for key, value in obj if function(key, value)]`. Read as: `[value for key, value in obj if function(key, value)]`.
For `Sequence`s, `key` is the index of the value. For `Sequence`s, `key` is the index of the value.
For `re.Match`es, `key` is the group number (0 = full match)
as well as additionally any group names, if given.
- `dict` Transform the current object and return a matching dict. - `dict` Transform the current object and return a matching dict.
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
@ -5441,6 +5446,8 @@ def traverse_obj(
@param default Value to return if the paths do not match. @param default Value to return if the paths do not match.
@param expected_type If a `type`, only accept final values of this type. @param expected_type If a `type`, only accept final values of this type.
If any other callable, try to call the function on each result. If any other callable, try to call the function on each result.
If the last key in the path is a `dict`, it will apply to each value inside
the dict instead, recursively. This does respect branching paths.
@param get_all If `False`, return the first matching result, otherwise all matching ones. @param get_all If `False`, return the first matching result, otherwise all matching ones.
@param casesense If `False`, consider string dictionary keys as case insensitive. @param casesense If `False`, consider string dictionary keys as case insensitive.
@ -5466,16 +5473,25 @@ def traverse_obj(
else: else:
type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,)) type_test = lambda val: try_call(expected_type or IDENTITY, args=(val,))
def apply_key(key, obj): def apply_key(key, test_type, obj):
if obj is None: if obj is None:
return return
elif key is None: elif key is None:
yield obj yield obj
elif isinstance(key, set):
assert len(key) == 1, 'Set should only be used to wrap a single item'
item = next(iter(key))
if isinstance(item, type):
if isinstance(obj, item):
yield obj
else:
yield try_call(item, args=(obj,))
elif isinstance(key, (list, tuple)): elif isinstance(key, (list, tuple)):
for branch in key: for branch in key:
_, result = apply_path(obj, branch) _, result = apply_path(obj, branch, test_type)
yield from result yield from result
elif key is ...: elif key is ...:
@ -5494,7 +5510,9 @@ def traverse_obj(
elif isinstance(obj, collections.abc.Mapping): elif isinstance(obj, collections.abc.Mapping):
iter_obj = obj.items() iter_obj = obj.items()
elif isinstance(obj, re.Match): elif isinstance(obj, re.Match):
iter_obj = enumerate((obj.group(), *obj.groups())) iter_obj = itertools.chain(
enumerate((obj.group(), *obj.groups())),
obj.groupdict().items())
elif traverse_string: elif traverse_string:
iter_obj = enumerate(str(obj)) iter_obj = enumerate(str(obj))
else: else:
@ -5502,7 +5520,7 @@ def traverse_obj(
yield from (v for k, v in iter_obj if try_call(key, args=(k, v))) yield from (v for k, v in iter_obj if try_call(key, args=(k, v)))
elif isinstance(key, dict): elif isinstance(key, dict):
iter_obj = ((k, _traverse_obj(obj, v)) for k, v in key.items()) iter_obj = ((k, _traverse_obj(obj, v, test_type=test_type)) for k, v in key.items())
yield {k: v if v is not None else default for k, v in iter_obj yield {k: v if v is not None else default for k, v in iter_obj
if v is not None or default is not NO_DEFAULT} if v is not None or default is not NO_DEFAULT}
@ -5537,11 +5555,24 @@ def traverse_obj(
with contextlib.suppress(IndexError): with contextlib.suppress(IndexError):
yield obj[key] yield obj[key]
def apply_path(start_obj, path): def lazy_last(iterable):
iterator = iter(iterable)
prev = next(iterator, NO_DEFAULT)
if prev is NO_DEFAULT:
return
for item in iterator:
yield False, prev
prev = item
yield True, prev
def apply_path(start_obj, path, test_type=False):
objs = (start_obj,) objs = (start_obj,)
has_branched = False has_branched = False
for key in variadic(path): key = None
for last, key in lazy_last(variadic(path, (str, bytes, dict, set))):
if is_user_input and key == ':': if is_user_input and key == ':':
key = ... key = ...
@ -5551,14 +5582,21 @@ def traverse_obj(
if key is ... or isinstance(key, (list, tuple)) or callable(key): if key is ... or isinstance(key, (list, tuple)) or callable(key):
has_branched = True has_branched = True
key_func = functools.partial(apply_key, key) if __debug__ and callable(key):
# Verify function signature
inspect.signature(key).bind(None, None)
key_func = functools.partial(apply_key, key, last)
objs = itertools.chain.from_iterable(map(key_func, objs)) objs = itertools.chain.from_iterable(map(key_func, objs))
if test_type and not isinstance(key, (dict, list, tuple)):
objs = map(type_test, objs)
return has_branched, objs return has_branched, objs
def _traverse_obj(obj, path, use_list=True): def _traverse_obj(obj, path, use_list=True, test_type=True):
has_branched, results = apply_path(obj, path) has_branched, results = apply_path(obj, path, test_type)
results = LazyList(x for x in map(type_test, results) if x is not None) results = LazyList(x for x in results if x is not None)
if get_all and has_branched: if get_all and has_branched:
return results.exhaust() if results or use_list else None return results.exhaust() if results or use_list else None