Source code for fumus.queries.itertools_mixin

import itertools as it
import operator

from fumus.utils import Optional


[docs] class ItertoolsMixin: NO_SIGNATURE_FUNCTIONS = {"chain", "islice", "product", "repeat", "zip_longest"} NO_KWARGS_FUNCTIONS = {"dropwhile", "filterfalse", "starmap", "takewhile", "tee"} iterable = None
[docs] def use(self, it_function, **kwargs): """Provides integration with itertools methods; pass corresponding parameters as kwargs""" import inspect if self._handle_no_signature_functions(it_function, **kwargs): return self signature = inspect.signature(it_function).parameters if self._handle_no_kwargs_functions(signature, it_function, **kwargs): return self return self._handle_default_signature_functions(signature, it_function, **kwargs)
[docs] def _handle_no_signature_functions(self, it_function, **kwargs): if it_function.__name__ not in self.NO_SIGNATURE_FUNCTIONS: return False if it_function.__name__ in ("product", "zip_longest"): if isinstance(self.iterable, range): self.iterable = it_function(self.iterable, **kwargs) else: self.iterable = it_function(*self.iterable, **kwargs) return True # functions like 'chain' don't expect key-word arguments self.iterable = it_function(self.iterable, *kwargs.values()) return True
[docs] def _handle_no_kwargs_functions(self, signature, it_function, **kwargs): # handle functions that take only iterable as arg if len(signature.keys()) == 1 and "iterable" in signature: self.iterable = it_function(self.iterable) return True # handle functions that take no kwargs if it_function.__name__ in self.NO_KWARGS_FUNCTIONS: if it_function.__name__ == "tee": self.iterable = it_function(self.iterable, *kwargs.values()) else: self.iterable = it_function(*kwargs.values(), self.iterable) return True return False
[docs] def _handle_default_signature_functions(self, signature, it_function, **kwargs): if "iterable" in signature: kwargs["iterable"] = self.iterable elif "data" in signature: kwargs["data"] = self.iterable self.iterable = it_function(**kwargs) return self
# ### 'recipes' ### # https://docs.python.org/3/library/itertools.html#itertools-recipes
[docs] def tabulate(self, mapper, start=0): """Returns function(0), function(1), ...""" self.iterable = map(mapper, it.count(start)) return self
[docs] def repeat_func(self, operation, times=None): """Repeats calls to func with specified arguments""" self.iterable = it.starmap(operation, it.repeat(self.iterable, times=times)) return self
[docs] def ncycles(self, count=0): """Returns the query elements n times""" self.iterable = it.chain.from_iterable(it.repeat(tuple(self.iterable), count)) return self
[docs] def consume(self, n=None): """Advances the iterator n-steps ahead. If n is None, consumes query entirely""" import collections if n is None: self.iterable = collections.deque(self.iterable, maxlen=0) return self if n < 0: raise ValueError("Consume boundary cannot be negative") self.iterable = it.islice(self.iterable, n, len(self.iterable)) return self
[docs] def take_nth(self, idx, default=None): """Returns Optional with the nth element of the query or a default value""" if idx < 0: idx = len(self.iterable) + idx return Optional.of_nullable(next(it.islice(self.iterable, idx, None), default))
[docs] def all_equal(self, key=None): """Returns True if all elements of the query are equal to each other""" return len(list(it.islice(it.groupby(self.iterable, key), 2))) <= 1
[docs] def view(self, start=0, stop=None, step=None): """Provides access to a selected part of the query""" if start < 0: start = len(self.iterable) + start if stop and stop < 0: stop = len(self.iterable) + stop if step and step < 0: raise ValueError("Step must be a positive integer or None") self.iterable = it.islice(self.iterable, start, stop, step) return self
# ### unique ###
[docs] def unique(self, key=None, reverse=False): """Yields unique elements in sorted order. Supports unhashable inputs""" self.iterable = self._unique(sorted(self.iterable, key=key, reverse=reverse), key=key) return self
[docs] @staticmethod def _unique(iterable, key=None): return map(next, map(operator.itemgetter(1), it.groupby(iterable, key)))
[docs] def unique_just_seen(self, key=None): """Yields unique elements, preserving order. Remembers only the element just seen""" self.iterable = map(next, map(operator.itemgetter(1), it.groupby(self.iterable, key))) return self
[docs] def unique_ever_seen(self, key=None): """Yields unique elements, preserving order. Remembers all elements ever seen""" self.iterable = self._unique_ever_seen(self.iterable, key) return self
[docs] @staticmethod def _unique_ever_seen(iterable, key=None): seen = set() for element in iterable: k = key(element) if key else element if k not in seen: seen.add(k) yield element
# ### ###
[docs] def sliding_window(self, n): """Collects data into overlapping fixed-length chunks or blocks""" if n < 0: raise ValueError("Window size cannot be negative") self.iterable = self._sliding_window(self.iterable, n) return self
[docs] @staticmethod def _sliding_window(iterable, n): import collections window = collections.deque(it.islice(iterable, n - 1), maxlen=n) for x in it.islice(iterable, n - 1, len(iterable)): window.append(x) yield tuple(window)
[docs] def grouper(self, n, *, incomplete="fill", fill_value=None): """Collects data into non-overlapping fixed-length chunks or blocks""" self.iterable = self._grouper(n, incomplete, fill_value) return self
[docs] def _grouper(self, n, incomplete="fill", fill_value=None): iterators = [iter(self.iterable)] * n match incomplete: case "fill": return it.zip_longest(*iterators, fillvalue=fill_value) case "strict": return zip(*iterators, strict=True) case "ignore": return zip(*iterators) case _: raise ValueError( f"Invalid incomplete flag '{incomplete}', expected: 'fill', 'strict', or 'ignore'" )
[docs] def round_robin(self): """Visits input iterables in a cycle until each is exhausted""" self.iterable = self._round_robin(self.iterable) return self
[docs] @staticmethod def _round_robin(iterable): # Algorithm credited to George Sakkis iterators = map(iter, iterable) for num_active in range(len(iterable), 0, -1): iterators = it.cycle(it.islice(iterators, num_active)) yield from map(next, iterators)
[docs] def partition(self, predicate): """ Partitions entries into true and false entries. Returns a query of two nested generators """ true_iter, false_iter = it.tee(self.iterable) self.iterable = filter(predicate, true_iter), it.filterfalse(predicate, false_iter) return self
[docs] def subslices(self): """Returns all contiguous non-empty sub-slices""" slices = it.starmap(slice, it.combinations(range(len(self.iterable) + 1), 2)) self.iterable = map(operator.getitem, it.repeat(self.iterable), slices) # noqa return self
[docs] def find_indices(self, value, start=0, stop=None): """Returns indices where a value occurs in a sequence or iterable""" self.iterable = self._find_indices(self.iterable, value, start, stop) return self
[docs] @staticmethod def _find_indices(iterable, value, start=0, stop=None): iterator = it.islice(iterable, start, stop) for i, element in enumerate(iterator, start): if element is value or element == value: yield i