"""
Base idea:

- Provide a queryset method that makes it easy to build some sort of
  "record type" instance for every row of a queryset.
- Make it able to seamlessly use annotation expressions.
- Allow modification of the initial values with callbacks, which is needed
  if your dataclass is frozen.

By default the RecordManager and RecordQuerySet use RecordDataclass as
handler, expecting your record to be a ``dataclasses.dataclass`` type.

An example::

    @dataclass
    class MyDataClass:
        id: int
        some_relation: str
        next_id: int

    SomeModel.objects.filter(...).records(
        MyDataClass,
        some_relation=F('model__relation'),
        next_id=Lambda(lambda x: x.get('id') + 1),
    )

This allows you to move an iterator into another layer, where it can either
be consumed or be used as an input for another queryset call, while still
guaranteeing that regular usage of the iterator will not yield any smart
(model) object.

It is useful if you want a best-of-both-worlds approach for a
subquery-capable repository pattern: keeping business logic out of the
repository layer is one of the biggest issues with that pattern, yet you
still want to use Django's queryset mechanics properly.

records() takes anything values() would take, but additionally it allows:

- passing the record type as first argument
- passing Adjunct instances as keyword argument values
"""
import functools
import inspect

from django.db.models import QuerySet
# BaseExpression/Combinable are no longer referenced here (records() forwards
# every non-adjunct keyword argument to values() alike); the import is kept.
from django.db.models.expressions import BaseExpression, Combinable  # noqa: F401
from django.db.models.manager import Manager
from django.db.models.query import ValuesIterable


## Useful for queryset function records()


class BaseAdjunct:
    """
    Any adjunct data which does not translate into SQL, but rather adds data
    programmatically while the records are built.
    """

    # If True, records() drops this adjunct entirely: its key is neither
    # fetched from the database nor resolved.
    skip = False
    # If True, resolve() is called for this adjunct's key on every row and
    # its return value is stored under that key.
    resolves_field = True
    # If True, post_process() is called with the row dictionary after all
    # fields are resolved, and may manipulate (or replace) the whole dict.
    post_processing = False

    def resolve(self, model, dbdata):
        """Return the value for this adjunct's key; subclasses must override."""
        raise NotImplementedError

    def post_process(self, model, dbdata):
        """Return a replacement dict, or None to keep *dbdata*; subclasses must override."""
        raise NotImplementedError


class Adjunct(BaseAdjunct):
    """
    Adjunct that supplies a fixed value, without any SQL handling.
    """

    def __init__(self, value=None):
        self.value = value

    def resolve(self, model, dbdata):
        return self.value


class Lambda(Adjunct):
    """
    Adjunct whose field value is computed as ``callback(dbdata)``.

    A non-callable ``callback`` is ignored and the field resolves to None.
    """

    def __init__(self, callback):
        self.callback = callback if callable(callback) else None

    def resolve(self, model, dbdata):
        # TODO: inspect the callback's arity and support 0-2 arguments.
        if self.callback:
            return self.callback(dbdata)
        return None


class Skip(BaseAdjunct):
    """
    Skips this key: it is neither retrieved from the database nor passed to
    the record instantiation (a dataclass field skipped this way therefore
    needs a default value).
    """

    skip = True
    resolves_field = False


class Callback(BaseAdjunct):
    """
    Calls ``callback(dbdata)`` after all fields are resolved. The callback
    can modify the whole initialization dictionary in place, or return a
    replacement dictionary (returning None keeps the current one).
    """

    resolves_field = False
    post_processing = True

    def __init__(self, callback):
        self.callback = callback

    def post_process(self, model, dbdata):
        if self.callback:
            return self.callback(dbdata)
        return None


## Wrapper to handle some sort of record base class


class RecordHandler:
    """
    Handler for a record type.

    Defines how a record is created, how to retrieve all its field names,
    and which of them are required (records() fetches the required ones
    automatically when they are not named explicitly).
    """

    @classmethod
    def wrap(cls, klass):
        return cls(klass)

    def __init__(self, klass):
        self.klass = klass

    def create(self, **kwargs):
        return self.klass(**kwargs)

    def get_field_names(self):
        # A class namespace also holds interpreter bookkeeping (__module__,
        # __doc__, __dict__, ...), methods and descriptors; asking values()
        # for those would fail, so only plain public attributes count.
        return [
            name
            for name, value in vars(self.klass).items()
            if not name.startswith('_')
            and not callable(value)
            and not isinstance(value, (classmethod, staticmethod, property))
        ]

    @property
    def record(self):
        return self.klass

    @property
    def required_keys(self):
        return self.get_field_names()


class RecordDict(RecordHandler):
    """
    RecordHandler that outputs a dictionary (``dict`` by default; another
    mapping type accepting keyword arguments, e.g. OrderedDict, also works).
    """

    def __init__(self, klass=None):
        self.klass = klass or dict

    def get_field_names(self):
        # A dictionary has no required fields.
        return []


@functools.lru_cache(maxsize=128)
def _init_parameter_names(klass):
    """Return the names ``klass(...)`` accepts as keyword arguments, in order."""
    parameters = inspect.signature(klass).parameters.values()
    return tuple(
        param.name
        for param in parameters
        if param.kind in (param.POSITIONAL_OR_KEYWORD, param.KEYWORD_ONLY)
    )


class RecordDataclass(RecordHandler):
    """
    Handles ``dataclasses.dataclass`` derivatives.

    Field names are taken from the constructor signature rather than from
    ``__dataclass_fields__``: the latter also lists ``ClassVar`` pseudo-fields
    and ``init=False`` fields, which are not constructor arguments (passing
    them raises TypeError). ``InitVar`` pseudo-fields are constructor
    arguments and are therefore kept.
    """

    def create(self, **kwargs):
        # Drop keys the constructor does not accept, e.g. helper columns that
        # were only fetched for Lambda/Callback adjuncts.
        accepted = set(self.get_field_names())
        return self.klass(**{k: v for k, v in kwargs.items() if k in accepted})

    def get_field_names(self):
        return list(_init_parameter_names(self.klass))


# TODO: RecordPydantic
# TODO: RecordAttrs


###### QuerySet Plugin.


class RecordIterable(ValuesIterable):
    """
    Iterable returned by records() that yields one record per row.

    Replaces the standard iterable of the queryset: each row is turned into
    a dict (as values() would), adjuncts are applied, and the result is
    passed to the record handler's create().
    """

    def __iter__(self):
        queryset = self.queryset
        model = queryset.model
        query = queryset.query
        compiler = query.get_compiler(queryset.db)
        record_data = getattr(queryset, '_record_extra', {})
        record_handler = queryset._record

        # extra(select=...) cols are always at the start of the row.
        names = [
            *query.extra_select,
            *query.values_select,
            *query.annotation_select,
        ]

        # The adjunct flags do not change per row, so partition them once.
        # Order is preserved: field resolvers run first (in declaration
        # order, each seeing earlier results), then the post-processors.
        resolvers = [
            (key, adjunct)
            for key, adjunct in record_data.items()
            if adjunct.resolves_field
        ]
        post_processors = [
            adjunct for adjunct in record_data.values() if adjunct.post_processing
        ]

        for row in compiler.results_iter(
            chunked_fetch=self.chunked_fetch, chunk_size=self.chunk_size
        ):
            dbdata = dict(zip(names, row))
            # Resolved adjunct values overwrite database values of the same key.
            for key, adjunct in resolvers:
                dbdata[key] = adjunct.resolve(model, dbdata)
            # Post-processors may rewrite the whole dictionary.
            for processor in post_processors:
                processed = processor.post_process(model, dbdata)
                if processed is not None:
                    dbdata = processed
            yield record_handler.create(**dbdata)


class RecordQuerySetMixin:
    # Handler class used to wrap record types that are not RecordHandler
    # instances already.
    _record_handler = RecordDataclass

    def records(self, *args, **kwargs):
        """
        Generate record objects instead of model instances.

        Acts like values(), however:

        - a record type or RecordHandler instance can be passed as first
          argument.
        - if no record type is passed, ``_record`` is looked up on the
          queryset, then on the model; if neither defines it, a RuntimeError
          is raised.
        - keyword arguments whose value is a BaseAdjunct are not sent to the
          database but resolved in Python for every row (Skip adjuncts are
          dropped entirely).
        - every required key of the record handler that is not named in
          ``args`` or ``kwargs`` is added to the values() fields, in the
          handler's declaration order.
        """
        if args and not isinstance(args[0], str):
            # TODO: better checks; any non-string first argument is taken as
            # the record type / handler.
            handler, args = args[0], args[1:]
        else:
            handler = getattr(self, '_record', getattr(self.model, '_record', None))
        if not handler:
            raise RuntimeError("Trying to records a class without destination class.")
        if not isinstance(handler, RecordHandler):
            handler = self._record_handler.wrap(handler)

        handled_keys = {*args, *kwargs}
        # dict.fromkeys dedupes while keeping declaration order, so the
        # generated SELECT column order is deterministic (a set difference
        # would order columns by string hash, which varies per process).
        missing_keys = [
            key
            for key in dict.fromkeys(handler.required_keys)
            if key not in handled_keys
        ]
        fields = [*args, *missing_keys]

        values_kwargs = {}
        extra = {}
        for key, value in kwargs.items():
            if isinstance(value, BaseAdjunct):
                if not value.skip:
                    extra[key] = value
            else:
                # Expressions and any other values() keyword arguments are
                # forwarded unchanged.
                values_kwargs[key] = value

        # Copy ourselves with values() and store the record state on the clone.
        clone = self.values(*fields, **values_kwargs)
        clone._iterable_class = RecordIterable
        clone._record_extra = extra
        clone._record = handler
        return clone


class RecordQuerySet(RecordQuerySetMixin, QuerySet):
    # State set by records() that must survive cloning (filter(), order_by(),
    # slicing, ...); the base _clone() does not know about these attributes.
    _record_attributes = ('_record', '_record_extra', '_record_handler')

    def _clone(self):
        clone = super()._clone()
        for key in self._record_attributes:
            if hasattr(self, key):
                setattr(clone, key, getattr(self, key))
        return clone


# A mixin is used here for clarity. records() is safe to call on the manager,
# as it does not call _chain(); it only calls self.values().
# Alternatively, the manager can be generated from the queryset:
#   RecordManager = Manager.from_queryset(RecordQuerySet)


class RecordManager(RecordQuerySetMixin, Manager):
    def get_queryset(self):
        # Pass the manager's routing hints along, as Django's Manager does.
        return RecordQuerySet(self.model, using=self._db, hints=self._hints)