Django: How to Represent and Query Symmetrical Relationships for a Family Tree?
I am building a family tree application in Django where I need to represent and query marriages symmetrically. Each marriage should have only one record, and the relationship should include both partners without duplicating data. Here's the relevant model structure:
class Person(models.Model):
    """A family-tree member who is linked to spouses through ``Marriage``."""

    first_name = models.CharField(max_length=100)
    last_name = models.CharField(max_length=100)

    # Self-referential many-to-many whose rows live in the Marriage table.
    # NOTE(review): Django may ignore ``related_name`` on a symmetrical
    # self-relation -- confirm against the ManyToManyField documentation.
    spouses = models.ManyToManyField(
        "self",
        symmetrical=True,
        through="Marriage",
        related_name="partners",
    )
class Marriage(models.Model):
    """Intermediate row joining two ``Person`` records, with optional dates."""

    # The two partners; each side gets its own reverse accessor on Person.
    person1 = models.ForeignKey(
        Person,
        related_name="marriages_as_person1",
        on_delete=models.CASCADE,
    )
    person2 = models.ForeignKey(
        Person,
        related_name="marriages_as_person2",
        on_delete=models.CASCADE,
    )

    # Both dates may be left empty.
    start_date = models.DateField(blank=True, null=True)
    end_date = models.DateField(blank=True, null=True)
I want to:
- Ensure both partners appear as spouses for each other symmetrically.
- Avoid duplicate entries for the same marriage.
- Efficiently query all spouses of a person.
Here’s the code I’m using to query spouses:
# Create two people and a single marriage row that links them.
p1, p2 = Person.objects.create(), Person.objects.create()
Marriage.objects.create(person1=p1, person2=p2)

# Read the relation from each side; the comments show what comes back.
p1.spouses.all()  # Returns list containing p2
p2.spouses.all()  # Returns empty list
However, I’m facing challenges:
- If `p1`'s spouses are queried, the result should contain `p2`, and if `p2`'s spouses are queried, the result should contain `p1`.
- Currently the two queries are not symmetrical.
Questions:
- Is my model structure correct for representing marriages symmetrically? If not, what improvements should I make?
- How can I efficiently query all spouses of a person in a database-optimized way while ensuring symmetry?
I would change your structure. I don't think `ManyToManyField` is a good fit here; the approach below will probably work better in your case. It also guarantees that only one marriage record can be created for any two persons.
class Person(models.Model):
    """A family-tree member, identified here only by first and last name."""

    # Both name parts are limited to 100 characters.
    first_name = models.CharField(max_length=100)
    last_name = models.CharField(max_length=100)
class Marriage(models.Model):
    """A single marriage record linking two ``Person`` rows.

    Each couple is stored once. ``unique_key`` holds the two partner IDs in
    sorted order, so (A, B) and (B, A) map to the same key and the unique
    constraint rejects the mirrored duplicate.
    """

    person1 = models.ForeignKey(
        Person,
        on_delete=models.CASCADE,
        related_name='+',  # '+' asks Django not to create a reverse accessor
    )
    person2 = models.ForeignKey(
        Person,
        on_delete=models.CASCADE,
        related_name='+',
    )
    # Order-independent identifier of the couple; maintained by save().
    unique_key = models.CharField(max_length=50)
    start_date = models.DateField(null=True, blank=True)
    end_date = models.DateField(null=True, blank=True)

    class Meta:
        constraints = [
            models.UniqueConstraint(
                fields=['unique_key'],
                name='unique_marriage_pair',
            )
        ]

    def save(
            self,
            force_insert=False,
            force_update=False,
            using=None,
            update_fields=None):
        """Keep ``unique_key`` in sync with the partners, then save.

        The arguments are forwarded unchanged to ``Model.save()``. The
        defaults are ``False`` (not ``None``) to mirror Django's own
        signature; newer Django releases type-check ``force_insert``, so
        forwarding ``None`` from a plain ``.save()`` call is not safe.

        The key is recomputed whenever the partner columns are about to be
        written: on every full save, and on partial saves whose
        ``update_fields`` names one of the partner fields. In the latter case
        ``unique_key`` is added to ``update_fields`` so the new key is
        actually persisted.
        """
        partner_fields = {'person1', 'person1_id', 'person2', 'person2_id'}
        if update_fields is not None:
            update_fields = frozenset(update_fields)

        writes_partners = (
            update_fields is None
            or not partner_fields.isdisjoint(update_fields)
        )
        if writes_partners:
            # Sorting makes the key independent of which partner is person1.
            self.unique_key = '_'.join(
                map(str, sorted([self.person1_id, self.person2_id]))
            )
            if update_fields is not None:
                update_fields = update_fields | {'unique_key'}

        return super().save(
            force_insert=force_insert,
            force_update=force_update,
            using=using,
            update_fields=update_fields,
        )
This allows you to find all spouse identifiers for a person in a single query, for example:
def get_pids(self, obj: Person) -> list[int]:
    """Collect the IDs of every person married to ``obj``.

    A marriage matches when ``obj`` is on either side of it, so the lookup
    does not depend on which partner was stored as ``person1``.
    """
    person_id = obj.id
    on_either_side = (
        models.Q(person1_id=person_id) | models.Q(person2_id=person_id)
    )
    id_pairs = Marriage.objects.filter(on_either_side).values_list(
        'person1_id',
        'person2_id',
    )
    # Flatten the (person1_id, person2_id) pairs into one set, then remove
    # the person's own ID so that only the partners remain.
    everyone_involved = set().union(*id_pairs)
    return list(everyone_involved - {person_id})
Alternatively, you can fetch the spouses for a whole list of person identifiers in two queries, like this (this requires PostgreSQL, because it uses `ArrayAgg`):
from collections.abc import Iterable, Iterator
from typing import TypeAlias
from django.contrib.postgres.aggregates import ArrayAgg
Result: TypeAlias = Iterator[tuple[int, list[int]]]


def get_pids_for_few_persons(person_ids: Iterable[int]) -> Result:
    """Yield ``(person_id, spouse_ids)`` for the given people, in two queries.

    One query groups marriages by ``person1_id`` and the other by
    ``person2_id``; the two partial spouse lists are then concatenated per
    person. Only people who appear in at least one marriage are yielded,
    and the pairs come out in no particular order.

    Args:
        person_ids: IDs of the people to look up. Any iterable is accepted,
            including one-shot iterators.

    Yields:
        ``(person_id, spouse_ids)`` tuples.
    """
    # Materialize once: the IDs feed two separate querysets, so a one-shot
    # iterable (e.g. a generator) would already be exhausted for the second.
    wanted_ids = list(person_ids)
    if not wanted_ids:
        # Nothing was requested, so there is nothing to query.
        return

    def get_queryset(search_key: str, aggregate_key: str):
        """Group marriages by ``search_key``, collecting ``aggregate_key``."""
        return (
            Marriage.objects
            .filter(**{search_key + '__in': wanted_ids})
            .values(search_key)
            .annotate(pids=ArrayAgg(aggregate_key))
            .values_list(search_key, 'pids')
        )

    # Spouses found when the person is stored as person1 ...
    queryset1 = get_queryset(
        search_key='person1_id',
        aggregate_key='person2_id',
    )
    # ... and when the person is stored as person2.
    queryset2 = get_queryset(
        search_key='person2_id',
        aggregate_key='person1_id',
    )
    data1, data2 = dict(queryset1), dict(queryset2)
    for person_id in data1.keys() | data2.keys():
        yield person_id, data1.get(person_id, []) + data2.get(person_id, [])
# Example data: p1 is married to both p2 and p3; p4 has no marriages.
p1 = Person.objects.create()
p2 = Person.objects.create()
p3 = Person.objects.create()
p4 = Person.objects.create()
Marriage.objects.create(person1=p1, person2=p2)
Marriage.objects.create(person1=p3, person2=p1)

# The batch helper is a generator that yields pairs in no guaranteed order,
# so sort its output to get a stable list.
result = sorted(get_pids_for_few_persons((p1.id, p2.id, p3.id)))
# Assuming the four people received IDs 1-4 in creation order:
# [(1, [2, 3]), (2, [1]), (3, [1])]