How to use Prefetch() with filter() to reduce the number of SELECT queries when iterating over 3 or more models?

I have Country, State and City models which are chained by foreign keys as shown below:

from django.db import models

class Country(models.Model):
    name = models.CharField(max_length=20)

class State(models.Model):
    country = models.ForeignKey(Country, on_delete=models.CASCADE)
    name = models.CharField(max_length=20)

class City(models.Model):
    state = models.ForeignKey(State, on_delete=models.CASCADE)
    name = models.CharField(max_length=20)

Then, when I iterate over the Country, State and City models with prefetch_related() and all() as shown below:

for country_obj in Country.objects.prefetch_related("state_set__city_set").all():  # Here: prefetch_related() and all()
    for state_obj in country_obj.state_set.all(): # Here
        for city_obj in state_obj.city_set.all(): # Here
            print(country_obj, state_obj, city_obj)

3 SELECT queries are run, as shown below. (I use PostgreSQL, and below is its query log; see this answer for how to enable and disable query logging in PostgreSQL.)

[Screenshot: PostgreSQL query log with the 3 SELECT queries]

But when I iterate with filter() instead of all(), as shown below:

for country_obj in Country.objects.prefetch_related("state_set__city_set").filter():  # Here
    for state_obj in country_obj.state_set.filter(): # Here
        for city_obj in state_obj.city_set.filter(): # Here
            print(country_obj, state_obj, city_obj)

8 SELECT queries are run instead of 3, as shown below:

[Screenshot: PostgreSQL query log with the 8 SELECT queries]

So I tried using Prefetch() with filter() to reduce the 8 SELECT queries to 3, as shown below:

from django.db.models import Prefetch

for country_obj in Country.objects.prefetch_related(
    Prefetch('state_set', # Here
        queryset=State.objects.filter(),
        to_attr='state_obj'
    ),
    Prefetch('city_set', # Here
        queryset=City.objects.filter(),
        to_attr='city_obj'
    ),
).filter():
    print(country_obj, country_obj.state_obj, country_obj.city_obj)

But the following error occurs:

AttributeError: Cannot find 'city_set' on Country object, 'city_set' is an invalid parameter to prefetch_related()

So how can I use Prefetch() with filter() to reduce the 8 SELECT queries to 3?

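Try nesting the Prefetch objects. Every lookup passed to Country.objects.prefetch_related() is resolved from Country, and Country has no city_set (that reverse accessor belongs to State), which is why the error occurs. Put the city_set Prefetch on the State queryset instead: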

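from django.db.models import Prefetch

countries = Country.objects.filter().prefetch_related(
    Prefetch(
        "state_set",
        # city_set is looked up on State objects, so its Prefetch goes here
        queryset=State.objects.filter().prefetch_related(
            Prefetch("city_set", queryset=City.objects.filter()),
        ),
    ),
)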

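Then iterate with all() in every loop, since the filtering is already applied by the querysets passed to Prefetch(). On a related manager, all() returns the prefetched results, while filter() builds a new queryset that ignores them and runs another SELECT query: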

for country_obj in countries:
    for state_obj in country_obj.state_set.all():
        for city_obj in state_obj.city_set.all():
            print(country_obj, state_obj, city_obj)