[utils] Generalize traverse_dict to traverse_obj

This commit is contained in:
pukkandan 2021-06-08 14:23:56 +05:30
parent beb982bead
commit 324ad82006
No known key found for this signature in database
GPG Key ID: 0F00D95A001F4698
3 changed files with 34 additions and 17 deletions

View File

@ -101,7 +101,7 @@ from .utils import (
strftime_or_none,
subtitles_filename,
to_high_limit_path,
traverse_dict,
traverse_obj,
UnavailableVideoError,
url_basename,
version_tuple,
@ -855,7 +855,7 @@ class YoutubeDL(object):
def get_value(mdict):
# Object traversal
fields = mdict['fields'].split('.')
value = traverse_dict(info_dict, fields)
value = traverse_obj(info_dict, fields)
# Negative
if mdict['negate']:
value = float_or_none(value)
@ -872,7 +872,7 @@ class YoutubeDL(object):
item, multiplier = (item[1:], -1) if item[0] == '-' else (item, 1)
offset = float_or_none(item)
if offset is None:
offset = float_or_none(traverse_dict(info_dict, item.split('.')))
offset = float_or_none(traverse_obj(info_dict, item.split('.')))
try:
value = operator(value, multiplier * offset)
except (TypeError, ZeroDivisionError):

View File

@ -23,7 +23,7 @@ from ..utils import (
ISO639Utils,
process_communicate_or_kill,
replace_extension,
traverse_dict,
traverse_obj,
)
@ -229,7 +229,7 @@ class FFmpegPostProcessor(PostProcessor):
def get_stream_number(self, path, keys, value):
streams = self.get_metadata_object(path)['streams']
num = next(
(i for i, stream in enumerate(streams) if traverse_dict(stream, keys, casesense=False) == value),
(i for i, stream in enumerate(streams) if traverse_obj(stream, keys, casesense=False) == value),
None)
return num, len(streams)

View File

@ -6181,21 +6181,38 @@ def load_plugins(name, suffix, namespace):
return classes
def traverse_dict(dictn, keys, casesense=True):
def traverse_obj(obj, keys, *, casesense=True, is_user_input=False, traverse_string=False):
''' Traverse nested list/dict/tuple
@param casesense Whether to consider dictionary keys as case sensitive
@param is_user_input Whether the keys are generated from user input. If True,
strings are converted to int/slice if necessary
@param traverse_string Whether to traverse inside strings. If True, any
non-compatible object will also be converted into a string
'''
keys = list(keys)[::-1]
while keys:
key = keys.pop()
if isinstance(dictn, dict):
if isinstance(obj, dict):
assert isinstance(key, compat_str)
if not casesense:
dictn = {k.lower(): v for k, v in dictn.items()}
obj = {k.lower(): v for k, v in obj.items()}
key = key.lower()
dictn = dictn.get(key)
elif isinstance(dictn, (list, tuple, compat_str)):
if ':' in key:
key = slice(*map(int_or_none, key.split(':')))
obj = obj.get(key)
else:
key = int_or_none(key)
dictn = try_get(dictn, lambda x: x[key])
if is_user_input:
key = (int_or_none(key) if ':' not in key
else slice(*map(int_or_none, key.split(':'))))
if not isinstance(obj, (list, tuple)):
if traverse_string:
obj = compat_str(obj)
else:
return None
return dictn
assert isinstance(key, (int, slice))
obj = try_get(obj, lambda x: x[key])
return obj
def traverse_dict(dictn, keys, casesense=True):
''' For backward compatibility. Do not use '''
return traverse_obj(dictn, keys, casesense=casesense,
is_user_input=True, traverse_string=True)