# purify
# Gabor Körber edited this page 2021-05-22 22:23:30 +02:00
from django.db.models import QuerySet
from django.db.models.expressions import BaseExpression, Combinable
from django.db.models.query import ValuesIterable
from django.db.models.manager import Manager

"""

    Base idea:
        - provide a queryset function that easily builds instances of some sort of "pure class" (e.g. a dataclass) from a queryset.
        - make it seamlessly usable together with annotation functions.
        - allow callbacks to modify the row data, which is needed if your dataclass is frozen.
    
    By default the PureManager and PureQuerySet will use PureDataclass as handler, expecting your pureclass to be a dataclasses.dataclass type.
    
    An Example:
    @dataclass
    class MyClass:
        id: int
        some_relation: str
        next_id: int
    
    SomeModel.objects.filter(...).exclude(...).purify(MyClass, some_relation=F('some__model__relation'), next_id=Lambda(lambda x: x.get('id')+1))
    
    This allows you to hand the resulting iterator to another layer, where it can either be consumed or used as input for another queryset call,
    while still guaranteeing that regular iteration will not yield any smart (model) objects.
    It is useful if you want a best-of-both-worlds, subquery-capable repository pattern: keeping your business logic out of the repository layer
    while still using Django's queryset mechanics properly is one of the biggest challenges of that pattern.
    
"""

## Useful for queryset function purify()
class BaseLeap:
    """Base class for purify() keyword values that bypass SQL ("leaps").

    Subclasses tune three class-level switches:

    * ``skip`` -- when true, this leap is not actually processed.
    * ``resolves_field`` -- when true, ``resolve()`` is called for a single
      field and its result is stored under the keyword's name.
    * ``post_processing`` -- when true, ``post_process()`` is called at the end
      with the whole row dictionary and may return a replacement for it.
    """

    post_processing = False
    resolves_field = True
    skip = False

    def resolve(self, model, dbdata):
        """Return the value for this leap's key; subclasses must override."""
        raise NotImplementedError

    def post_process(self, model, dbdata):
        """Manipulate (or replace) the row dictionary; subclasses must override."""
        raise NotImplementedError

class Leap(BaseLeap):
    """Value injected into every row without going through SQL.

    The stored ``value`` (``None`` by default) is returned unchanged for each row.
    """

    def __init__(self, value=None):
        self.value = value

    def resolve(self, model, dbdata):
        # The row content is irrelevant: every row receives the same value.
        constant = self.value
        return constant

class Lambda(Leap):
    """Leap whose value is computed per row by calling ``callback(dbdata)``.

    A non-callable ``callback`` is replaced by ``None``; ``resolve()`` then
    returns ``None`` for every row.
    """

    def __init__(self, callback):
        # Leap.__init__ is not called, so this leap has no ``value`` attribute.
        self.callback = None
        if callable(callback):
            self.callback = callback

    def resolve(self, model, dbdata):
        # The callback always receives exactly one argument (the row dict);
        # inspecting its arity to also pass ``model`` would be a possible extension.
        func = self.callback
        if not func:
            return None
        return func(dbdata)
        
class Skip(BaseLeap):
    """Marker asking purify() to keep this key out of the database query and out of the pure-class instantiation."""

    resolves_field = False  # nothing to compute for a skipped key
    skip = True

class Callback(BaseLeap):
    """Row post-processor: calls ``callback(dbdata)`` once all fields are resolved.

    Does not resolve a field of its own. A falsy ``callback`` makes
    ``post_process()`` return ``None``; otherwise the callback's return value
    is handed back to the caller.
    """

    post_processing = True
    resolves_field = False

    def __init__(self, callback):
        self.callback = callback

    def post_process(self, model, dbdata):
        hook = self.callback
        if not hook:
            return None
        return hook(dbdata)

## Wrapper to handle some sort of pure baseclass
class PureHandler:
    """ handler for a pure baseclass

        Defines how a pureclass is created (``create``), how to retrieve its
        field names (``get_field_names``), and which of them are required,
        i.e. fetched via values() by purify() (``required_keys``).
    """

    @classmethod
    def wrap(cls, klass):
        """Return a handler of this type wrapping ``klass``."""
        return cls(klass)

    def __init__(self, klass):
        self.klass = klass

    def create(self, **kwargs):
        """Instantiate the wrapped class with the row data as keyword arguments."""
        return self.klass(**kwargs)

    def get_field_names(self):
        """Return the non-dunder attribute names declared directly on the wrapped class.

        A class ``__dict__`` always contains Python internals such as
        ``__module__``, ``__doc__``, ``__dict__`` or ``__init__``; they are not
        data fields and cannot be resolved by values(), so they are skipped.
        """
        return [
            name for name in vars(self.klass)
            if not (name.startswith('__') and name.endswith('__'))
        ]

    @property
    def pureclass(self):
        """The wrapped pure class."""
        return self.klass

    @property
    def required_keys(self):
        """Keys purify() must request from the database; by default all field names."""
        return self.get_field_names()


class PureDict(PureHandler):
    """ PureHandler that outputs a mapping (``dict`` unless another type is given) """

    def __init__(self, klass=None):
        # Any mapping type built from keyword arguments works, e.g. OrderedDict.
        if klass:
            self.klass = klass
        else:
            self.klass = dict

    def get_field_names(self):
        # A mapping takes arbitrary keys, so no field is required.
        return []


class PureDataclass(PureHandler):
    """ handles dataclasses.dataclass derivatives

        Only names the dataclass constructor accepts are requested from the
        database and passed on creation. ``__dataclass_fields__`` alone is not
        enough for that: it also lists ``ClassVar`` pseudo-fields and fields
        declared with ``field(init=False)``, and passing either of them to
        ``__init__`` raises ``TypeError``. ``InitVar`` pseudo-fields are kept
        because the constructor takes them.
    """

    def create(self, **kwargs):
        """Instantiate the dataclass, dropping keys its constructor does not accept."""
        _, accepted = self._init_field_names()
        return self.klass(**{k: v for k, v in kwargs.items() if k in accepted})

    def get_field_names(self):
        """Return the constructor-accepted field names, in signature order."""
        names, _ = self._init_field_names()
        return list(names)

    def _init_field_names(self):
        """Return ``(names, name_set)`` of keywords ``klass(**kwargs)`` accepts, cached per wrapped class."""
        # create() runs once per fetched row, so the introspection is cached;
        # the cache is keyed on the class in case ``self.klass`` is reassigned.
        cache = getattr(self, '_init_field_names_cache', None)
        if cache is not None and cache[0] is self.klass:
            return cache[1], cache[2]

        import inspect

        declared = list(self.klass.__dataclass_fields__)
        try:
            params = list(inspect.signature(self.klass).parameters.values())
        except (TypeError, ValueError):
            # No introspectable signature: fall back to every declared name.
            names = declared
        else:
            if any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params):
                # A hand-written __init__(**kwargs) may take any declared name.
                names = declared
            else:
                keyword_kinds = (
                    inspect.Parameter.POSITIONAL_OR_KEYWORD,
                    inspect.Parameter.KEYWORD_ONLY,
                )
                names = [p.name for p in params if p.kind in keyword_kinds]

        names = tuple(names)
        name_set = frozenset(names)
        self._init_field_names_cache = (self.klass, names, name_set)
        return names, name_set

# @TODO: PurePydantic

###### QuerySet Plugin.

class PureIterable(ValuesIterable):
    """
    Iterable returned by purify() that yields a pure class for each row.
    Replaces the standard iterable of the queryset.

    Each row is first read like values() reads it (a ``{name: value}`` dict).
    Then every leap with ``resolves_field`` writes its result into that dict,
    and every leap with ``post_processing`` may replace the dict. Finally the
    queryset's pure handler turns the dict into an object.
    """

    def __iter__(self):
        queryset = self.queryset
        model = queryset.model
        query = queryset.query
        compiler = query.get_compiler(queryset.db)
        pure_data = getattr(queryset, '_pureclass_extra', {})
        pure_handler = queryset._pureclass

        # Which leaps resolve fields and which post-process depends only on the
        # queryset, not on the row, so split them once instead of once per row.
        resolvers = [(key, leap) for key, leap in pure_data.items() if leap.resolves_field]
        post_processors = [leap for leap in pure_data.values() if leap.post_processing]

        # extra(select=...) cols are always at the start of the row.
        names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]

        for row in compiler.results_iter(chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size):
            dbdata = dict(zip(names, row))
            # Resolvers run in declaration order; later ones see values set by earlier ones.
            for key, leap in resolvers:
                dbdata[key] = leap.resolve(model, dbdata)
            # Post-processors may replace the whole dictionary; returning None keeps it.
            for processor in post_processors:
                processed = processor.post_process(model, dbdata)
                if processed is not None:
                    dbdata = processed
            yield pure_handler.create(**dbdata)


class PureQuerySetMixin:
    """Adds purify() to a queryset or manager."""

    # Handler type used to wrap a bare pure class (anything that is not a PureHandler).
    _pureclass_handler = PureDataclass

    def purify(self, *args, **kwargs):
        """Return a values()-based queryset whose iteration yields pure objects.

        Acts like values(), however:
            - the first positional argument, if it is not a string, is the pure
              class or a PureHandler; a bare class is wrapped with
              ``_pureclass_handler``.
            - if no pure class is passed, ``_pureclass`` is looked up on the
              queryset/manager, then on the model; if neither is set, a
              RuntimeError is raised.
            - keyword arguments that are leaps (any BaseLeap: Leap, Lambda,
              Callback, Skip) are kept out of the SQL and resolved per row in
              Python; leaps with ``skip`` set (Skip) are dropped entirely.
            - every other keyword argument is passed to values() as an expression.
            - values() is additionally called with every required key of the
              handler that is not mentioned in args or kwargs, in the
              handler's order.
        """
        if args and not isinstance(args[0], str):
            # We assume this is our pure class or handler.
            # @TODO better checks.
            handler, args = args[0], args[1:]
        else:
            handler = getattr(self, '_pureclass', getattr(self.model, '_pureclass', None))
            if not handler:
                raise RuntimeError("Trying to purify a class without destination class.")

        if not isinstance(handler, PureHandler):
            handler = self._pureclass_handler.wrap(handler)

        # Every key the caller mentioned is handled, either as a column/expression or as a leap.
        handled = {*args, *kwargs}
        # Keep the handler's order so the generated SELECT is deterministic.
        unhandled_keys = [key for key in handler.required_keys if key not in handled]
        args = [*args, *unhandled_keys]

        expressions = {}
        extra = {}
        for key, value in kwargs.items():
            if isinstance(value, BaseLeap):
                # Skip and Callback derive from BaseLeap but not from Leap; no leap
                # may reach values(), which only accepts expressions.
                if not value.skip:
                    extra[key] = value
            else:
                # Anything else (F(), Value(), aggregates, ...) is handed to values().
                expressions[key] = value

        # Copy ourselves with values() and store the purify configuration on the clone.
        values = self.values(*args, **expressions)
        values._iterable_class = PureIterable
        values._pureclass_extra = extra
        values._pureclass = handler
        return values


class PureQuerySet(PureQuerySetMixin, QuerySet):
    """QuerySet offering purify() whose clones keep the purify() configuration."""

    def _clone(self):
        # QuerySet._clone() does not copy these custom attributes, so chained
        # calls (filter(), order_by(), ...) after purify() would lose them.
        clone = super()._clone()
        missing = object()
        for attr in ('_pureclass', '_pureclass_extra', '_pureclass_handler'):
            value = getattr(self, attr, missing)
            if value is not missing:
                setattr(clone, attr, value)
        return clone

# A mixin is used instead for better clarity; purify() is safe on the manager because it does not call _chain().
# Alternatively (after importing BaseManager from django.db.models.manager) you can simply do:
#class PureManager(BaseManager.from_queryset(PureQuerySet)):
#    pass


class PureManager(PureQuerySetMixin, Manager):
    """Manager whose querysets are PureQuerySets, so purify() is available on them."""

    def get_queryset(self):
        # Mirror Django's Manager.get_queryset() and forward the manager's hints
        # as well, so routing hints set via db_manager(hints=...) are not lost.
        return PureQuerySet(model=self.model, using=self._db, hints=self._hints)