Add extractor helpers (#10653)

Authored by: Grub4K
This commit is contained in:
Simon Sawicki 2024-10-13 05:14:32 +02:00 committed by GitHub
parent 85b87c991a
commit d710a6ca7c
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 261 additions and 11 deletions

View File

@ -4,8 +4,18 @@ import xml.etree.ElementTree
import pytest import pytest
from yt_dlp.utils import dict_get, int_or_none, str_or_none from yt_dlp.utils import (
from yt_dlp.utils.traversal import traverse_obj ExtractorError,
determine_ext,
dict_get,
int_or_none,
str_or_none,
)
from yt_dlp.utils.traversal import (
traverse_obj,
require,
subs_list_to_dict,
)
_TEST_DATA = { _TEST_DATA = {
100: 100, 100: 100,
@ -420,6 +430,71 @@ class TestTraversal:
assert traverse_obj(morsel, [(None,), any]) == morsel, \ assert traverse_obj(morsel, [(None,), any]) == morsel, \
'Morsel should not be implicitly changed to dict on usage' 'Morsel should not be implicitly changed to dict on usage'
def test_traversal_filter(self):
data = [None, False, True, 0, 1, 0.0, 1.1, '', 'str', {}, {0: 0}, [], [1]]
assert traverse_obj(data, [..., filter]) == [True, 1, 1.1, 'str', {0: 0}, [1]], \
'`filter` should filter falsy values'
class TestTraversalHelpers:
def test_traversal_require(self):
with pytest.raises(ExtractorError):
traverse_obj(_TEST_DATA, ['None', {require('value')}])
assert traverse_obj(_TEST_DATA, ['str', {require('value')}]) == 'str', \
'`require` should pass through non `None` values'
def test_subs_list_to_dict(self):
assert traverse_obj([
{'name': 'de', 'url': 'https://example.com/subs/de.vtt'},
{'name': 'en', 'url': 'https://example.com/subs/en1.ass'},
{'name': 'en', 'url': 'https://example.com/subs/en2.ass'},
], [..., {
'id': 'name',
'url': 'url',
}, all, {subs_list_to_dict}]) == {
'de': [{'url': 'https://example.com/subs/de.vtt'}],
'en': [
{'url': 'https://example.com/subs/en1.ass'},
{'url': 'https://example.com/subs/en2.ass'},
],
}, 'function should build subtitle dict from list of subtitles'
assert traverse_obj([
{'name': 'de', 'url': 'https://example.com/subs/de.ass'},
{'name': 'de'},
{'name': 'en', 'content': 'content'},
{'url': 'https://example.com/subs/en'},
], [..., {
'id': 'name',
'data': 'content',
'url': 'url',
}, all, {subs_list_to_dict}]) == {
'de': [{'url': 'https://example.com/subs/de.ass'}],
'en': [{'data': 'content'}],
}, 'subs with mandatory items missing should be filtered'
assert traverse_obj([
{'url': 'https://example.com/subs/de.ass', 'name': 'de'},
{'url': 'https://example.com/subs/en', 'name': 'en'},
], [..., {
'id': 'name',
'ext': ['url', {lambda x: determine_ext(x, default_ext=None)}],
'url': 'url',
}, all, {subs_list_to_dict(ext='ext')}]) == {
'de': [{'url': 'https://example.com/subs/de.ass', 'ext': 'ass'}],
'en': [{'url': 'https://example.com/subs/en', 'ext': 'ext'}],
}, '`ext` should set default ext but leave existing value untouched'
assert traverse_obj([
{'name': 'en', 'url': 'https://example.com/subs/en2', 'prio': True},
{'name': 'en', 'url': 'https://example.com/subs/en1', 'prio': False},
], [..., {
'id': 'name',
'quality': ['prio', {int}],
'url': 'url',
}, all, {subs_list_to_dict(ext='ext')}]) == {'en': [
{'url': 'https://example.com/subs/en1', 'ext': 'ext'},
{'url': 'https://example.com/subs/en2', 'ext': 'ext'},
]}, '`quality` key should sort subtitle list accordingly'
class TestDictGet: class TestDictGet:
def test_dict_get(self): def test_dict_get(self):

View File

@ -573,13 +573,13 @@ class InfoExtractor:
def _login_hint(self, method=NO_DEFAULT, netrc=None): def _login_hint(self, method=NO_DEFAULT, netrc=None):
password_hint = f'--username and --password, --netrc-cmd, or --netrc ({netrc or self._NETRC_MACHINE}) to provide account credentials' password_hint = f'--username and --password, --netrc-cmd, or --netrc ({netrc or self._NETRC_MACHINE}) to provide account credentials'
cookies_hint = 'See https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp for how to manually pass cookies'
return { return {
None: '', None: '',
'any': f'Use --cookies, --cookies-from-browser, {password_hint}', 'any': f'Use --cookies, --cookies-from-browser, {password_hint}. {cookies_hint}',
'password': f'Use {password_hint}', 'password': f'Use {password_hint}',
'cookies': ( 'cookies': f'Use --cookies-from-browser or --cookies for the authentication. {cookies_hint}',
'Use --cookies-from-browser or --cookies for the authentication. ' 'session_cookies': f'Use --cookies for the authentication (--cookies-from-browser might not work). {cookies_hint}',
'See https://github.com/yt-dlp/yt-dlp/wiki/FAQ#how-do-i-pass-cookies-to-yt-dlp for how to manually pass cookies'),
}[method if method is not NO_DEFAULT else 'any' if self.supports_login() else 'cookies'] }[method if method is not NO_DEFAULT else 'any' if self.supports_login() else 'cookies']
def __init__(self, downloader=None): def __init__(self, downloader=None):

View File

@ -1984,11 +1984,30 @@ def urljoin(base, path):
return urllib.parse.urljoin(base, path) return urllib.parse.urljoin(base, path)
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1): def partial_application(func):
sig = inspect.signature(func)
@functools.wraps(func)
def wrapped(*args, **kwargs):
try:
sig.bind(*args, **kwargs)
except TypeError:
return functools.partial(func, *args, **kwargs)
else:
return func(*args, **kwargs)
return wrapped
@partial_application
def int_or_none(v, scale=1, default=None, get_attr=None, invscale=1, base=None):
if get_attr and v is not None: if get_attr and v is not None:
v = getattr(v, get_attr, None) v = getattr(v, get_attr, None)
if invscale == 1 and scale < 1:
invscale = int(1 / scale)
scale = 1
try: try:
return int(v) * invscale // scale return (int(v) if base is None else int(v, base=base)) * invscale // scale
except (ValueError, TypeError, OverflowError): except (ValueError, TypeError, OverflowError):
return default return default
@ -2006,9 +2025,13 @@ def str_to_int(int_str):
return int_or_none(int_str) return int_or_none(int_str)
@partial_application
def float_or_none(v, scale=1, invscale=1, default=None): def float_or_none(v, scale=1, invscale=1, default=None):
if v is None: if v is None:
return default return default
if invscale == 1 and scale < 1:
invscale = int(1 / scale)
scale = 1
try: try:
return float(v) * invscale / scale return float(v) * invscale / scale
except (ValueError, TypeError): except (ValueError, TypeError):

View File

@ -1,18 +1,35 @@
from __future__ import annotations
import collections
import collections.abc import collections.abc
import contextlib import contextlib
import functools
import http.cookies import http.cookies
import inspect import inspect
import itertools import itertools
import re import re
import typing
import xml.etree.ElementTree import xml.etree.ElementTree
from ._utils import ( from ._utils import (
IDENTITY, IDENTITY,
NO_DEFAULT, NO_DEFAULT,
ExtractorError,
LazyList, LazyList,
deprecation_warning, deprecation_warning,
get_elements_html_by_class,
get_elements_html_by_attribute,
get_elements_by_attribute,
get_element_html_by_attribute,
get_element_by_attribute,
get_element_html_by_id,
get_element_by_id,
get_element_html_by_class,
get_elements_by_class,
get_element_text_and_html_by_tag,
is_iterable_like, is_iterable_like,
try_call, try_call,
url_or_none,
variadic, variadic,
) )
@ -54,6 +71,7 @@ def traverse_obj(
Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`. Read as: `{key: traverse_obj(obj, path) for key, path in dct.items()}`.
- `any`-builtin: Take the first matching object and return it, resetting branching. - `any`-builtin: Take the first matching object and return it, resetting branching.
- `all`-builtin: Take all matching objects and return them as a list, resetting branching. - `all`-builtin: Take all matching objects and return them as a list, resetting branching.
- `filter`-builtin: Return the value if it is truthy, `None` otherwise.
`tuple`, `list`, and `dict` all support nested paths and branches. `tuple`, `list`, and `dict` all support nested paths and branches.
@ -247,6 +265,10 @@ def traverse_obj(
objs = (list(filtered_objs),) objs = (list(filtered_objs),)
continue continue
if key is filter:
objs = filter(None, objs)
continue
if __debug__ and callable(key): if __debug__ and callable(key):
# Verify function signature # Verify function signature
inspect.signature(key).bind(None, None) inspect.signature(key).bind(None, None)
@ -277,13 +299,143 @@ def traverse_obj(
return results[0] if results else {} if allow_empty and is_dict else None return results[0] if results else {} if allow_empty and is_dict else None
for index, path in enumerate(paths, 1): for index, path in enumerate(paths, 1):
result = _traverse_obj(obj, path, index == len(paths), True) is_last = index == len(paths)
if result is not None: try:
return result result = _traverse_obj(obj, path, is_last, True)
if result is not None:
return result
except _RequiredError as e:
if is_last:
# Reraise to get cleaner stack trace
raise ExtractorError(e.orig_msg, expected=e.expected) from None
return None if default is NO_DEFAULT else default return None if default is NO_DEFAULT else default
def value(value, /):
return lambda _: value
def require(name, /, *, expected=False):
def func(value):
if value is None:
raise _RequiredError(f'Unable to extract {name}', expected=expected)
return value
return func
class _RequiredError(ExtractorError):
pass
@typing.overload
def subs_list_to_dict(*, ext: str | None = None) -> collections.abc.Callable[[list[dict]], dict[str, list[dict]]]: ...
@typing.overload
def subs_list_to_dict(subs: list[dict] | None, /, *, ext: str | None = None) -> dict[str, list[dict]]: ...
def subs_list_to_dict(subs: list[dict] | None = None, /, *, ext=None):
"""
Convert subtitles from a traversal into a subtitle dict.
The path should have an `all` immediately before this function.
Arguments:
`ext` The default value for `ext` in the subtitle dict
In the dict you can set the following additional items:
`id` The subtitle id to sort the dict into
`quality` The sort order for each subtitle
"""
if subs is None:
return functools.partial(subs_list_to_dict, ext=ext)
result = collections.defaultdict(list)
for sub in subs:
if not url_or_none(sub.get('url')) and not sub.get('data'):
continue
sub_id = sub.pop('id', None)
if sub_id is None:
continue
if ext is not None and not sub.get('ext'):
sub['ext'] = ext
result[sub_id].append(sub)
result = dict(result)
for subs in result.values():
subs.sort(key=lambda x: x.pop('quality', 0) or 0)
return result
@typing.overload
def find_element(*, attr: str, value: str, tag: str | None = None, html=False): ...
@typing.overload
def find_element(*, cls: str, html=False): ...
@typing.overload
def find_element(*, id: str, tag: str | None = None, html=False): ...
@typing.overload
def find_element(*, tag: str, html=False): ...
def find_element(*, tag=None, id=None, cls=None, attr=None, value=None, html=False):
# deliberately using `id=` and `cls=` for ease of readability
assert tag or id or cls or (attr and value), 'One of tag, id, cls or (attr AND value) is required'
if not tag:
tag = r'[\w:.-]+'
if attr and value:
assert not cls, 'Cannot match both attr and cls'
assert not id, 'Cannot match both attr and id'
func = get_element_html_by_attribute if html else get_element_by_attribute
return functools.partial(func, attr, value, tag=tag)
elif cls:
assert not id, 'Cannot match both cls and id'
assert tag is None, 'Cannot match both cls and tag'
func = get_element_html_by_class if html else get_elements_by_class
return functools.partial(func, cls)
elif id:
func = get_element_html_by_id if html else get_element_by_id
return functools.partial(func, id, tag=tag)
index = int(bool(html))
return lambda html: get_element_text_and_html_by_tag(tag, html)[index]
@typing.overload
def find_elements(*, cls: str, html=False): ...
@typing.overload
def find_elements(*, attr: str, value: str, tag: str | None = None, html=False): ...
def find_elements(*, tag=None, cls=None, attr=None, value=None, html=False):
# deliberately using `cls=` for ease of readability
assert cls or (attr and value), 'One of cls or (attr AND value) is required'
if attr and value:
assert not cls, 'Cannot match both attr and cls'
func = get_elements_html_by_attribute if html else get_elements_by_attribute
return functools.partial(func, attr, value, tag=tag or r'[\w:.-]+')
assert not tag, 'Cannot match both cls and tag'
func = get_elements_html_by_class if html else get_elements_by_class
return functools.partial(func, cls)
def get_first(obj, *paths, **kwargs): def get_first(obj, *paths, **kwargs):
return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False) return traverse_obj(obj, *((..., *variadic(keys)) for keys in paths), **kwargs, get_all=False)