-
-
Notifications
You must be signed in to change notification settings - Fork 33.9k
Expand file tree
/
Copy pathcomposite.py
More file actions
177 lines (140 loc) · 5.6 KB
/
composite.py
File metadata and controls
177 lines (140 loc) · 5.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
import json
from django.core import checks
from django.db.models import NOT_PROVIDED, Field
from django.db.models.expressions import ColPairs
from django.db.models.fields.tuple_lookups import (
TupleExact,
TupleGreaterThan,
TupleGreaterThanOrEqual,
TupleIn,
TupleIsNull,
TupleLessThan,
TupleLessThanOrEqual,
)
from django.utils.functional import cached_property
class AttributeSetter:
    """Minimal object carrying one dynamically named attribute.

    ``AttributeSetter("x", 1).x == 1``. Useful for passing a bare value to an
    API that expects to read it off an object attribute.
    """

    def __init__(self, name, value):
        # The attribute name is only known at runtime, hence setattr().
        setattr(self, name, value)
class CompositeAttribute:
    """
    Descriptor exposing a composite field as a tuple of component values.

    Reading the attribute on an instance yields a tuple built from the
    attributes named by the composite field's component ``fields``. Assigning
    a list/tuple spreads its elements over those component attributes.
    """

    def __init__(self, field):
        # The composite field whose ``fields`` define the component attributes.
        self.field = field

    @property
    def attnames(self):
        """Attribute names of the component fields, in ``self.field.fields`` order."""
        return [field.attname for field in self.field.fields]

    def __get__(self, instance, cls=None):
        """
        Return the tuple of component attribute values for ``instance``.

        When accessed on the class itself (``instance`` is None), return the
        descriptor, following the standard descriptor convention. This
        avoids an AttributeError raised from ``getattr(None, ...)``.
        """
        if instance is None:
            return self
        return tuple(getattr(instance, attname) for attname in self.attnames)

    def __set__(self, instance, values):
        """
        Assign each element of ``values`` to the matching component attribute.

        ``None`` is treated as "all components None". Raise ValueError if
        ``values`` is not a list or tuple, or if its length differs from the
        number of component fields.
        """
        attnames = self.attnames
        length = len(attnames)
        if values is None:
            values = (None,) * length
        if not isinstance(values, (list, tuple)):
            raise ValueError(f"{self.field.name!r} must be a list or a tuple.")
        if length != len(values):
            raise ValueError(f"{self.field.name!r} must have {length} elements.")
        for attname, value in zip(attnames, values):
            setattr(instance, attname, value)
class CompositePrimaryKey(Field):
    """
    A primary key composed of two or more other fields of the same model.

    The component fields are given by name as positional arguments and are
    resolved lazily through ``model._meta.get_field()``. On a model instance
    the key's value is the tuple of component values, exposed through
    ``CompositeAttribute``.
    """

    descriptor_class = CompositeAttribute

    def __init__(self, *args, **kwargs):
        """
        Validate and store the component field names.

        Raise ValueError unless ``args`` contains at least two unique strings.
        Also raise ValueError if ``kwargs`` requests any of the following:
        a ``default``, a ``db_default``, a ``db_column``, ``editable=True``,
        ``primary_key=False`` or ``blank=False``.
        """
        if (
            not args
            or not all(isinstance(field, str) for field in args)
            or len(set(args)) != len(args)
        ):
            raise ValueError("CompositePrimaryKey args must be unique strings.")
        if len(args) == 1:
            raise ValueError("CompositePrimaryKey must include at least two fields.")
        if kwargs.get("default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a default.")
        if kwargs.get("db_default", NOT_PROVIDED) is not NOT_PROVIDED:
            raise ValueError("CompositePrimaryKey cannot have a database default.")
        if kwargs.get("db_column", None) is not None:
            raise ValueError("CompositePrimaryKey cannot have a db_column.")
        # setdefault() both fills in the only value each option may take
        # (editable=False, primary_key=True, blank=True) and lets us reject an
        # explicitly passed conflicting value.
        if kwargs.setdefault("editable", False):
            raise ValueError("CompositePrimaryKey cannot be editable.")
        if not kwargs.setdefault("primary_key", True):
            raise ValueError("CompositePrimaryKey must be a primary key.")
        if not kwargs.setdefault("blank", True):
            raise ValueError("CompositePrimaryKey must be blank.")
        self.field_names = args
        super().__init__(**kwargs)

    def deconstruct(self):
        """Return the deconstruction, with the component field names as args."""
        # args is always [] so it can be ignored.
        name, path, _, kwargs = super().deconstruct()
        return name, path, self.field_names, kwargs

    @cached_property
    def fields(self):
        """The component Field instances, in ``field_names`` order."""
        meta = self.model._meta
        return tuple(meta.get_field(field_name) for field_name in self.field_names)

    @cached_property
    def columns(self):
        """Database column names of the component fields."""
        return tuple(field.column for field in self.fields)

    def contribute_to_class(self, cls, name, private_only=False):
        """Register as the model's pk and install the tuple descriptor."""
        super().contribute_to_class(cls, name, private_only=private_only)
        cls._meta.pk = self
        setattr(cls, self.attname, self.descriptor_class(self))

    def get_attname_column(self):
        """Return (attname, None); this field has no column of its own."""
        return self.get_attname(), None

    def __iter__(self):
        """Iterate over the component Field instances."""
        return iter(self.fields)

    def __len__(self):
        """Return the number of component fields."""
        return len(self.field_names)

    @cached_property
    def cached_col(self):
        """ColPairs for this model's own table, reused by get_col()."""
        return ColPairs(self.model._meta.db_table, self.fields, self.fields, self)

    def get_col(self, alias, output_field=None):
        """
        Return a ColPairs expression for ``alias``.

        The cached instance is reused when ``alias`` is the model's own table
        and ``output_field`` is unspecified or this field.
        """
        if alias == self.model._meta.db_table and (
            output_field is None or output_field == self
        ):
            return self.cached_col
        return ColPairs(alias, self.fields, self.fields, output_field)

    def get_pk_value_on_save(self, instance):
        """
        Return the tuple of component values to use when saving ``instance``.

        A component whose current value is None falls back to that
        component's own ``get_pk_value_on_save()``.
        """
        values = []
        for field in self.fields:
            value = field.value_from_object(instance)
            if value is None:
                value = field.get_pk_value_on_save(instance)
            values.append(value)
        return tuple(values)

    def _check_field_name(self):
        """System check: the field must be declared under the name 'pk'."""
        if self.name == "pk":
            return []
        return [
            checks.Error(
                "'CompositePrimaryKey' must be named 'pk'.",
                obj=self,
                id="fields.E013",
            )
        ]

    def value_to_string(self, obj):
        """
        Serialize the key of ``obj`` as a JSON array of component strings.

        Each element is produced by the corresponding component field's own
        ``value_to_string()``.
        """
        values = []
        for field, value in zip(self.fields, self.value_from_object(obj)):
            # The component's value_to_string() takes an object rather than a
            # bare value, so expose the value under the component's attname on
            # a lightweight holder (kept distinct from the ``obj`` argument).
            holder = AttributeSetter(field.attname, value)
            values.append(field.value_to_string(holder))
        return json.dumps(values, ensure_ascii=False)

    def to_python(self, value):
        """
        Convert ``value`` to a list of component Python values if it is a string.

        A string is assumed to be the JSON array produced by
        ``value_to_string()``. Each element is converted by the matching
        component's ``to_python()``. ``zip(strict=True)`` raises ValueError if
        the element count does not match the number of components. Non-string
        values are returned unchanged.
        """
        if isinstance(value, str):
            # Assume we're deserializing.
            vals = json.loads(value)
            value = [
                field.to_python(val)
                for field, val in zip(self.fields, vals, strict=True)
            ]
        return value
# Register the tuple-aware lookups on CompositePrimaryKey, in a fixed order.
for _tuple_lookup in (
    TupleExact,
    TupleGreaterThan,
    TupleGreaterThanOrEqual,
    TupleLessThan,
    TupleLessThanOrEqual,
    TupleIn,
    TupleIsNull,
):
    CompositePrimaryKey.register_lookup(_tuple_lookup)
# Don't leave the loop variable behind as a module-level name.
del _tuple_lookup
def unnest(fields):
    """
    Flatten ``fields``, expanding every CompositePrimaryKey into its components.

    Non-composite items are kept as-is. Relative order is preserved and a new
    list is always returned.
    """
    return [
        leaf
        for field in fields
        for leaf in (
            field.fields if isinstance(field, CompositePrimaryKey) else (field,)
        )
    ]