from collections.abc import Mapping
from functools import singledispatchmethod
from fumus.queries.itertools_mixin import ItertoolsMixin
from fumus.queries.query_generator import QueryGenerator
from fumus.utils import Optional, DictItem
from fumus.decorators.handler import pre_call, handle_consumed
from fumus.exceptions.exception import NoneTypeError, UnsupportedTypeError, IllegalStateError
[docs]
@pre_call(handle_consumed)
class Query(ItertoolsMixin):
"""Abstraction over a sequence of elements supporting sequential aggregate operations"""
def __init__(self, iterable):
if iterable is None:
raise NoneTypeError("Cannot create Query from None")
self._iterable = iterable
self._is_consumed = False
self._on_close_handler = None
[docs]
def __iter__(self):
return iter(self.iterable)
[docs]
@classmethod
def of(cls, *iterable):
"""Creates Query from args"""
return cls(iterable)
[docs]
@classmethod
def of_nullable(cls, iterable):
"""Creates Query from args if iterable is not None; otherwise returns empty Query"""
if iterable is None:
return cls.empty()
return cls(iterable)
[docs]
@classmethod
def empty(cls):
"""Creates empty Query"""
return cls([])
[docs]
@classmethod
def iterate(cls, seed, operation, condition=None):
"""Creates infinite ordered Query"""
return cls(QueryGenerator.iterate(seed, operation, condition))
[docs]
@classmethod
def generate(cls, supplier):
"""Creates infinite unordered Query with values generated by given supplier function"""
return cls(QueryGenerator.generate(supplier))
[docs]
@classmethod
def constant(cls, element):
"""Creates infinite Query with given value"""
return cls.generate(lambda: element)
[docs]
@singledispatchmethod # noqa
@classmethod
def from_range(cls, *range_list: int):
"""Creates Query from start (inclusive) to stop (exclusive) by an incremental step"""
return cls(QueryGenerator.range(*range_list))
[docs]
@from_range.register(range) # noqa
@classmethod
def _(cls, range_obj: range):
"""Creates Query range object"""
return cls(QueryGenerator.range(range_obj.start, range_obj.stop, range_obj.step))
@property
def iterable(self):
if isinstance(self._iterable, Mapping):
return (DictItem(k, v) for k, v in self._iterable.items())
return self._iterable
@iterable.setter
def iterable(self, value):
self._iterable = value
[docs]
def concat(self, *queries):
"""Concatenates several queries together or adds new queries/collections to the current one"""
self.iterable = QueryGenerator.concat(self.iterable, *queries)
return self
[docs]
def prepend(self, iterable):
"""Prepends iterable to current query"""
self.iterable = QueryGenerator.concat(iterable, self.iterable)
return self
[docs]
def filter(self, predicate):
"""Filters values in query based on given predicate function"""
self.iterable = QueryGenerator.filter(self.iterable, predicate)
return self
[docs]
def map(self, mapper):
"""Returns a query consisting of the results of applying the given function to the elements of this query"""
self.iterable = QueryGenerator.map(self.iterable, mapper)
return self
[docs]
def filter_map(self, mapper, *, discard_falsy=False):
"""Filters out all None or falsy values and applies mapper function to the elements of the query"""
self.iterable = QueryGenerator.filter_map(self.iterable, mapper, discard_falsy)
return self
[docs]
def flat_map(self, mapper):
"""Maps each element of the query and yields the elements of the produced iterators"""
self.iterable = QueryGenerator.flat_map(self.iterable, mapper)
return self
[docs]
def flatten(self):
"""Converts a Query of multidimensional collection into a one-dimensional"""
self.iterable = QueryGenerator.flatten(self.iterable)
return self
[docs]
def peek(self, operation):
"""Performs the provided operation on each element of the query without consuming it"""
self.iterable = QueryGenerator.peek(self.iterable, operation)
return self
[docs]
def distinct(self):
"""Returns a query with the distinct elements of the current one"""
self.iterable = QueryGenerator.distinct(self.iterable)
return self
[docs]
def count(self):
"""Returns the count of elements in the query"""
return len(tuple(self.iterable))
[docs]
def sum(self):
"""Sums the elements of the query"""
if len(self.iterable) == 0:
return 0
if not any(isinstance(x, (int | float | None)) for x in self.iterable):
raise ValueError("Cannot apply sum on non-number elements")
return sum(self.iterable)
[docs]
def average(self):
"""Returns the average value of elements in the query"""
if (query_len := len(self.iterable)) == 0:
return 0
return self.sum() / query_len
[docs]
def skip(self, count):
"""Discards the first n elements of the query and returns a new query with the remaining ones"""
if count < 0:
raise ValueError("Skip count cannot be negative")
self.iterable = QueryGenerator.skip(self.iterable, count)
return self
[docs]
def limit(self, count):
"""Returns a query with the first n elements, or fewer if the underlying iterator ends sooner"""
if count < 0:
raise ValueError("Limit count cannot be negative")
self.iterable = QueryGenerator.limit(self.iterable, count)
return self
[docs]
def head(self, count):
"""Alias for 'limit'"""
if count < 0:
raise ValueError("Head count cannot be negative")
self.iterable = QueryGenerator.limit(self.iterable, count)
return self
[docs]
def tail(self, count):
"""Returns a query with the last n elements, or fewer if the underlying iterator ends sooner"""
if count < 0:
raise ValueError("Tail count cannot be negative")
self.iterable = QueryGenerator.tail(self.iterable, count)
return self
[docs]
def take_while(self, predicate):
"""Returns a query that yields elements based on a predicate"""
self.iterable = QueryGenerator.take_while(self.iterable, predicate)
return self
[docs]
def drop_while(self, predicate):
"""Returns a query that skips elements based on a predicate and yields the remaining ones"""
self.iterable = QueryGenerator.drop_while(self.iterable, predicate)
return self
[docs]
def take_first(self, default=None):
"""Returns Optional with the first element of the query or a default value"""
return Optional.of_nullable(next(iter(self.iterable), default))
[docs]
def take_last(self, default=None):
"""Returns Optional with the last element of the query or a default value"""
if self.iterable:
*_, last = self.iterable
return Optional.of_nullable(last)
return Optional.of_nullable(default)
[docs]
def sort(self, comparator=None, *, reverse=False):
"""
Sorts the elements of the current query according to natural order or based on the given comparator.
If 'reverse' flag is True, the elements are sorted in descending order
"""
self.iterable = QueryGenerator.sort(self.iterable, comparator, reverse)
return self
[docs]
def reverse(self, comparator=None):
"""
Sorts the elements of the current query in descending order.
Alias for 'sort(comparator, reverse=True)'
"""
self.iterable = QueryGenerator.sort(self.iterable, comparator, reverse=True)
return self
[docs]
def find_first(self, predicate=None):
"""
Searches for an element of the query that satisfies a predicate.
Returns an Optional with the first found value, if any, or None
"""
return Optional.of_nullable(next(filter(predicate, self.iterable), None))
[docs]
def find_any(self, predicate=None):
"""
Searches for an element of the query that satisfies a predicate.
Returns an Optional with some of the found values, if any, or None
"""
import random
if predicate:
self.filter(predicate)
try:
return Optional.of(random.choice(list(self.iterable)))
except IndexError:
return Optional.of_nullable(None)
[docs]
def any_match(self, predicate):
"""Returns whether any elements of the query match the given predicate"""
return any(predicate(i) for i in self.iterable)
[docs]
def all_match(self, predicate):
"""Returns whether all elements of the query match the given predicate"""
return all(predicate(i) for i in self.iterable)
[docs]
def none_match(self, predicate):
"""Returns whether no elements of the query match the given predicate"""
return any(not predicate(i) for i in self.iterable)
[docs]
def min(self, comparator=None, default=None):
"""Returns the minimum element of the query according to the given comparator"""
return Optional.of_nullable(min(self.iterable, key=comparator, default=default))
[docs]
def max(self, comparator=None, default=None):
"""Returns the maximum element of the query according to the given comparator"""
return Optional.of_nullable(max(self.iterable, key=comparator, default=default))
[docs]
def for_each(self, operation):
"""Performs an action for each element of this query"""
for i in self.iterable:
operation(i)
[docs]
def enumerate(self, start=0):
"""
Returns each element of the Query preceded by his corresponding index
(by default starting from 0 if not specified otherwise)
"""
self.iterable = QueryGenerator.enumerate(self.iterable, start)
return self
[docs]
def reduce(self, accumulator, identity=None):
"""
Reduces the elements to a single one, by repeatedly applying a reducing operation.
Returns Optional with the result, if any, or None
"""
if len(self.iterable) == 0:
return Optional.of_nullable(identity)
curr_iter = iter(self.iterable)
if identity is None:
identity = next(curr_iter)
for i in curr_iter:
identity = accumulator(identity, i)
return Optional.of_nullable(identity)
[docs]
def compare_with(self, other, comparator=None):
"""Compares current query with another one based on a given comparator"""
return not any(
(comparator and not comparator(i, j)) or i != j for i, j in zip(self.iterable, other)
)
# ### collectors ###
[docs]
def collect(self, collection_type, dict_collector=None, dict_merger=None, str_delimiter=", "):
"""
Returns a collection from the query.
In case of dict:
The 'dict_collector' function receives an element from the query and returns a (key, value) pair or a DictItem
specifying how the dict should be constructed.
The 'dict_merger' functions indicates in the case of a collision (duplicate keys), which entry should be kept.
E.g. lambda old, new: new
In case of str:
Concatenates the elements of the Query, separated by the specified 'str_delimiter'
"""
import builtins
match collection_type:
case builtins.tuple:
return self.to_tuple()
case builtins.list:
return self.to_list()
case builtins.set:
return self.to_set()
case builtins.dict:
return self.to_dict(dict_collector, dict_merger)
case builtins.str:
return self.to_string(str_delimiter)
case _:
raise ValueError("Invalid collection type")
[docs]
def to_list(self):
"""Returns a list of the elements of the current query"""
return list(self.iterable)
[docs]
def to_tuple(self):
"""Returns a tuple of the elements of the current query"""
return tuple(self.iterable)
[docs]
def to_set(self):
"""Returns a set of the elements of the current query"""
return set(self.iterable)
[docs]
def to_dict(self, collector=None, merger=None):
"""
Returns a dict of the elements of the current query.
The 'collector' function receives an element from the query and returns a (key, value) pair or a DictItem
specifying how the dict should be constructed.
The 'merger' functions indicates in the case of a collision (duplicate keys), which entry should be kept.
E.g. lambda old, new: new
"""
result = {}
source = (collector(i) for i in self.iterable) if collector else self.iterable
for item in source:
k, v = self._unpack_dict_item(item)
if k in result:
if merger is None:
raise IllegalStateError(f"Key '{k}' already exists")
v = merger(result[k], v)
result[k] = v
return result
[docs]
def _unpack_dict_item(self, item): # noqa
match item:
case tuple():
return item[0], item[1]
case DictItem():
# let's not make unnecessary calls to property getters
return item._key, item._value # noqa
case _:
raise UnsupportedTypeError(
f"Cannot create dict items from '{item.__class__.__name__}' type"
)
[docs]
def to_string(self, delimiter=", "):
"""Concatenates the elements of the Query, separated by the specified delimiter"""
return self._join(delimiter)
[docs]
def group_by(self, classifier=None, collector=None):
"""
Performs a "group by" operation on the elements of the query according to a classification function.
Returns the results in a dict built using collector function
(optionally provided by the user or via a default one)
"""
if collector is None:
return {key: list(group) for key, group in self._group_by(classifier)}
result = {}
for key, group in self._group_by(classifier):
key, group = collector(key, list(group))
if hasattr(group, "__iter__"):
if key not in result:
result[key] = []
result[key] += group
else:
result[key] = group
return result
[docs]
def _group_by(self, classifier=None):
# https://docs.python.org/3/library/itertools.html#itertools.groupby
classifier = (lambda x: x) if classifier is None else classifier
iterator = iter(self.iterable)
exhausted = False
def _grouper(target_key): # noqa
nonlocal curr_value, curr_key, exhausted
yield curr_value
for curr_value in iterator:
curr_key = classifier(curr_value)
if curr_key != target_key:
return
yield curr_value
exhausted = True
try:
curr_value = next(iterator)
except StopIteration:
return
curr_key = classifier(curr_value)
while not exhausted:
target_key = curr_key
curr_group = _grouper(target_key)
yield curr_key, curr_group
if curr_key == target_key:
for _ in curr_group:
pass
[docs]
def quantify(self, predicate=bool):
"""Count how many of the elements are Truthy or evaluate to True based on a given predicate"""
return sum(self.map(predicate))
[docs]
def close(self):
"""Closes the query, causing the provided close handler to be called"""
if self._on_close_handler:
self._on_close_handler()
self._is_consumed = True
[docs]
def on_close(self, handler):
"""Returns an equivalent query with an additional close handler"""
self._on_close_handler = handler
return self
# ### let's look nice ###
[docs]
def __repr__(self):
return f"{self.__class__.__name__}.of({self._join()})"
[docs]
def _join(self, delimiter=", "):
return delimiter.join(str(i) for i in self.iterable)
# NB: handle_consumed decorator needs access to toggle flag
[docs]
def take_nth(self, idx, default=None):
"""Returns Optional with the nth element of the query or a default value"""
return super().take_nth(idx, default)
[docs]
def all_equal(self, key=None):
"""Returns True if all elements of the query are equal to each other"""
return super().all_equal(key)