1
0
mirror of https://github.com/ihabunek/toot.git synced 2025-04-18 00:48:47 -04:00

Add batched helper

This commit is contained in:
Ivan Habunek 2025-01-12 12:00:59 +01:00
parent c98af14a84
commit 38dfd747f6
No known key found for this signature in database
GPG Key ID: 01DB3DD0D824504C
2 changed files with 31 additions and 2 deletions

View File

@ -7,7 +7,7 @@ from toot.wcstring import wc_wrap, trunc, pad, fit_text
from toot.tui.utils import LRUCache from toot.tui.utils import LRUCache
from PIL import Image from PIL import Image
from collections import namedtuple from collections import namedtuple
from toot.utils import urlencode_url from toot.utils import batched, urlencode_url
def test_pad(): def test_pad():
@ -319,3 +319,15 @@ def test_urlencode_url():
assert urlencode_url("https://www.example.com") == "https://www.example.com" assert urlencode_url("https://www.example.com") == "https://www.example.com"
assert urlencode_url("https://www.example.com/url%20with%20spaces") == "https://www.example.com/url%20with%20spaces" assert urlencode_url("https://www.example.com/url%20with%20spaces") == "https://www.example.com/url%20with%20spaces"
def test_batched():
assert list(batched("", 2)) == []
assert list(batched("a", 2)) == [["a"]]
assert list(batched("ab", 2)) == [["a", "b"]]
assert list(batched("abc", 2)) == [["a", "b"], ["c"]]
assert list(batched("abcd", 2)) == [["a", "b"], ["c", "d"]]
assert list(batched("abcde", 2)) == [["a", "b"], ["c", "d"], ["e"]]
assert list(batched("abcdef", 2)) == [["a", "b"], ["c", "d"], ["e", "f"]]
with pytest.raises(ValueError):
list(batched("foo", 0))

View File

@ -9,7 +9,8 @@ import warnings
from bs4 import BeautifulSoup from bs4 import BeautifulSoup
from importlib.metadata import version from importlib.metadata import version
from typing import Any, Dict, Generator, List, Optional from itertools import islice
from typing import Any, Dict, Generator, Iterable, List, Optional, TypeVar
from urllib.parse import urlparse, urlencode, quote, unquote from urllib.parse import urlparse, urlencode, quote, unquote
@ -164,3 +165,19 @@ def get_version(name):
return version(name) return version(name)
except Exception: except Exception:
return None return None
T = TypeVar("T")
def batched(iterable: Iterable[T], n: int) -> Generator[List[T], None, None]:
"""Batch data from the iterable into lists of length n. The last batch may
be shorter than n."""
if n < 1:
raise ValueError("n must be positive")
iterator = iter(iterable)
while True:
batch = list(islice(iterator, n))
if batch:
yield batch
else:
break