diff --git a/hub/graphql/types/geojson.py b/hub/graphql/types/geojson.py index 3a03cd33b..a892afa8e 100644 --- a/hub/graphql/types/geojson.py +++ b/hub/graphql/types/geojson.py @@ -45,11 +45,15 @@ class PointGeometry: # lng, lat coordinates: List[float] + @classmethod + def from_geodjango(cls, point: Point) -> "PointGeometry": + return PointGeometry(coordinates=[point.x, point.y]) + @strawberry.type class PointFeature(Feature): geometry: PointGeometry - properties: JSON + properties: Optional[JSON] = strawberry.field(default_factory=dict) @classmethod def from_geodjango( @@ -57,7 +61,7 @@ def from_geodjango( ) -> "PointFeature": return PointFeature( id=str(id), - geometry=PointGeometry(coordinates=point), + geometry=PointGeometry.from_geodjango(point), properties=properties, ) @@ -70,11 +74,15 @@ class PolygonGeometry: type: GeoJSONTypes.Polygon = GeoJSONTypes.Polygon coordinates: List[List[List[float]]] + @classmethod + def from_geodjango(cls, polygon: Polygon) -> "PolygonGeometry": + return cls(coordinates=polygon.coords) + @strawberry.type class PolygonFeature(Feature): geometry: PolygonGeometry - properties: JSON + properties: Optional[JSON] = strawberry.field(default_factory=dict) @classmethod def from_geodjango( @@ -82,27 +90,25 @@ def from_geodjango( ) -> "PolygonFeature": return PolygonFeature( id=str(id), - geometry=PolygonGeometry(coordinates=polygon), + geometry=PolygonGeometry.from_geodjango(polygon), properties=properties, ) -# - - @strawberry.type class MultiPolygonGeometry: type: GeoJSONTypes.MultiPolygon = GeoJSONTypes.MultiPolygon coordinates: JSON - def __init__(self, coordinates: MultiPolygon): - self.coordinates = coordinates.json + @classmethod + def from_geodjango(cls, multipolygon: MultiPolygon) -> "MultiPolygonGeometry": + return cls(coordinates=multipolygon.coords) @strawberry.type class MultiPolygonFeature(Feature): geometry: MultiPolygonGeometry - properties: JSON + properties: Optional[JSON] = strawberry.field(default_factory=dict) @classmethod def from_geodjango( @@ -110,6 +116,6 @@ def from_geodjango( ) -> "MultiPolygonFeature": return MultiPolygonFeature( id=str(id), - geometry=MultiPolygonGeometry(coordinates=multipolygon.json), + geometry=MultiPolygonGeometry.from_geodjango(multipolygon), properties=properties, ) diff --git a/hub/graphql/types/model_types.py b/hub/graphql/types/model_types.py index ba44f2afe..1acac45fa 100644 --- a/hub/graphql/types/model_types.py +++ b/hub/graphql/types/model_types.py @@ -1,3 +1,4 @@ +import json import logging import urllib.parse from datetime import datetime @@ -31,6 +32,7 @@ ) from hub.graphql.context import HubDataLoaderContext from hub.graphql.dataloaders import ( + FieldDataLoaderFactory, FieldReturningListDataLoaderFactory, ReverseFKWithFiltersDataLoaderFactory, filterable_dataloader_resolver, @@ -40,7 +42,12 @@ from hub.graphql.types.electoral_commission import ElectoralCommissionPostcodeLookup from hub.graphql.types.geojson import MultiPolygonFeature, PointFeature from hub.graphql.types.postcodes import PostcodesIOResult -from hub.graphql.utils import attr_field, dict_key_field, fn_field +from hub.graphql.utils import ( + attr_field, + dict_key_field, + django_model_instance_to_strawberry_type, + fn_field, +) from hub.management.commands.import_mps import party_shades from utils.geo_reference import ( AnalyticalAreaType, @@ -514,7 +521,7 @@ class ConstituencyElectionResult: @strawberry.type class ConstituencyElectionStats: - json: strawberry.Private[dict] + json: strawberry.Private[dict] = None date: str result: str @@ -564,6 +571,31 @@ def second_party_result(self, info: Info) -> PartyResult: ) +@strawberry.type +class AreaFeatureAreaProperties: + name: str = dict_key_field() + gss: str = dict_key_field() + area_type_name: str = dict_key_field() + area_type_code: str = dict_key_field() + + +@strawberry.type +class AreaFeatureProperties: + area: AreaFeatureAreaProperties = dict_key_field() + data: Optional[JSON] = dict_key_field() + generic_data: Optional["GenericData"] = dict_key_field() + + +@strawberry.type +class AreaPolygonFeature(MultiPolygonFeature): + properties: AreaFeatureProperties = strawberry.field(default_factory=dict) + + +@strawberry.type +class AreaPointFeature(PointFeature): + properties: AreaFeatureProperties = strawberry.field(default_factory=dict) + + @strawberry_django.type(models.Area, filters=AreaFilter) class Area: id: auto @@ -572,10 +604,17 @@ class Area: gss: auto name: auto area_type: "AreaType" = strawberry_django_dataloaders.fields.auto_dataloader_field() - geometry: auto + geometry: JSON = strawberry_django.field( + resolver=lambda root: ( + json.loads(root.geometry) + if isinstance(root.geometry, str) + else root.geometry + ) + ) overlaps: auto # So that we can pass in properties to the geojson Feature objects - extra_geojson_properties: strawberry.Private[object] + geojson_feature_properties: strawberry.Private = None + geojson_feature_genericdata: strawberry.Private["GenericData"] = None people: List[Person] = filterable_dataloader_resolver( filter_type=Optional[PersonFilter], field_name="person", @@ -631,31 +670,52 @@ async def last_election(self, info: Info) -> Optional[ConstituencyElectionResult return cer @strawberry_django.field - def polygon( - self, info: Info, with_parent_data: bool = False - ) -> Optional[MultiPolygonFeature]: + def polygon(self, info: Info) -> Optional[AreaPolygonFeature]: props = { - "name": self.name, - "gss": self.gss, - "id": self.gss, - "area_type": self.area_type, + "area": { + "name": self.name, + "gss": self.gss, + "area_type_name": self.area_type.name, + "area_type_code": self.area_type.code, + }, + "data": ( + self.geojson_feature_properties + if hasattr(self, "geojson_feature_properties") + else {} + ), + "generic_data": ( + self.geojson_feature_genericdata + if hasattr(self, "geojson_feature_genericdata") + else None + ), } - if with_parent_data and hasattr(self, "extra_geojson_properties"): - props["extra_geojson_properties"] = self.extra_geojson_properties - return MultiPolygonFeature.from_geodjango( + return AreaPolygonFeature.from_geodjango( multipolygon=self.polygon, id=self.gss, properties=props ) @strawberry_django.field - def point( - self, info: Info, with_parent_data: bool = False - ) -> Optional[PointFeature]: - props = {"name": self.name, "gss": self.gss} - if with_parent_data and hasattr(self, "extra_geojson_properties"): - props["extra_geojson_properties"] = self.extra_geojson_properties - - return PointFeature.from_geodjango( + def point(self, info: Info) -> Optional[AreaPointFeature]: + props = { + "area": { + "name": self.name, + "gss": self.gss, + "area_type_name": self.area_type.name, + "area_type_code": self.area_type.code, + }, + "data": ( + self.geojson_feature_properties + if hasattr(self, "geojson_feature_properties") + else {} + ), + "generic_data": ( + self.geojson_feature_genericdata + if hasattr(self, "geojson_feature_genericdata") + else None + ), + } + + return AreaPointFeature.from_geodjango( point=self.point, id=self.gss, properties=props ) @@ -701,6 +761,20 @@ class GroupedDataCount: area_data: Optional[strawberry.Private[Area]] = None is_percentage: bool = False + @strawberry_django.field + async def area(self, info: Info) -> Optional[Area]: + if self.area_data: + area = await self.area_data + elif self.gss: + area_loader = FieldDataLoaderFactory.get_loader_class( + models.Area, field="gss", select_related=["area_type"] + ) + area = await area_loader(context=info.context).load(self.gss) + if area: + graphql_area: Area = django_model_instance_to_strawberry_type(area, Area) + graphql_area.geojson_feature_properties = self.row + return graphql_area + @strawberry_django.type(models.GenericData, filters=CommonDataFilter) class GroupedData: @@ -748,7 +822,6 @@ class GenericData(CommonData): public_url: auto description: auto image: auto - area: Optional[Area] postcode: auto remote_url: str = fn_field() @@ -758,23 +831,41 @@ def postcode_data(self) -> Optional[PostcodesIOResult]: return benedict(self.postcode_data) @strawberry_django.field - def areas(self, info: Info) -> Optional[Area]: - if self.point is None: - return None - - # TODO: data loader for this - # Convert to list to make deeper async resolvers work - return list(models.Area.objects.filter(polygon__contains=self.point)) + async def area( + self, info: Info, type: Optional[AnalyticalAreaType] = None + ) -> Optional[Area]: + area = None + if type is None: + if self.area_id is not None: + area_by_id_loader = FieldDataLoaderFactory.get_loader_class( + models.Area, field="id", select_related=["area_type"] + ) + area = await area_by_id_loader(context=info.context).load(self.area_id) + elif self.postcode_data is not None: + gss = self.postcode_data["codes"].get(type.value, None) + if gss is None: + return None + area_loader = FieldDataLoaderFactory.get_loader_class( + models.Area, field="gss", select_related=["area_type"] + ) + area = await area_loader(context=info.context).load(gss) + if area: + graphql_area = django_model_instance_to_strawberry_type(area, Area) + graphql_area.geojson_feature_genericdata = self + return graphql_area @strawberry_django.field - def area_from_point(self, area_type: str, info: Info) -> Optional[Area]: - if self.point is None: - return None + async def areas(self, info: Info) -> List[Area]: + if self.postcode_data is None: + return [] - # TODO: data loader for this - return models.Area.objects.filter( - polygon__contains=self.point, area_type__code=area_type - ).first() + area_loader = FieldDataLoaderFactory.get_loader_class( + models.Area, field="gss", select_related=["area_type"] + ) + areas = await area_loader(context=info.context).load_many( + self.postcode_data["codes"].values() + ) + return [a for a in areas if a is not None] @strawberry.type @@ -809,7 +900,10 @@ def imported_data_count_by_area( postcode_io_key=analytical_area_type.value, layer_ids=layer_ids, ) - return [GroupedDataCount(**datum) for datum in data] + return [ + GroupedDataCount(**datum, area_data=datum.get("area", None)) + for datum in data + ] @strawberry_django.field def imported_data_count_of_areas( @@ -856,7 +950,7 @@ def imported_data_count_for_area( ) if len(res) == 0: return None - return GroupedDataCount(**res[0]) + return GroupedDataCount(**res[0], area_data=res[0].get("area", None)) @strawberry.type @@ -1538,9 +1632,9 @@ def generic_data_by_external_data_source( "can_display_details" ): raise ValueError(f"User {user} does not have permission to view points") - return models.GenericData.objects.filter( - data_type__data_set__external_data_source=external_data_source - ) + qs = external_data_source.get_import_data() + + return list(qs) def generic_data_from_source_about_area( @@ -1564,7 +1658,7 @@ def generic_data_from_source_about_area( stats.filter_generic_data_using_gss_code(gss, mode) ) - return qs + return list(qs) def statistics( @@ -1594,6 +1688,10 @@ def statistics_for_choropleth( map_bounds: Optional[stats.MapBounds] = None, ) -> List[GroupedDataCount]: choropleth_config = choropleth_config or stats.ChoroplethConfig() + stats_config = stats_config or stats.StatisticsConfig() + + if not stats_config.group_by_area: + raise ValueError("An area type must be specified for a choropleth") user = get_current_user(info) for source in stats_config.source_ids: @@ -1603,12 +1701,7 @@ def statistics_for_choropleth( fields_requested_by_resolver = [f.name for f in info.selected_fields[0].selections] # Start with fields requested by resolver - choropleth_statistics_columns = ["label", "gss"] - return_columns = [ - field - for field in fields_requested_by_resolver - if field in choropleth_statistics_columns - ] + return_columns = [] if "count" in fields_requested_by_resolver: # (Count will default to the count of records automatically.) diff --git a/hub/graphql/utils.py b/hub/graphql/utils.py index 19facee9e..8b2748dc8 100644 --- a/hub/graphql/utils.py +++ b/hub/graphql/utils.py @@ -1,3 +1,5 @@ +from typing import cast + import strawberry import strawberry_django from strawberry.types.info import Info @@ -40,3 +42,7 @@ def graphql_type_to_dict(value, delete_null_keys=False): lambda x: x if (x is not strawberry.UNSET) else None, delete_null_keys=delete_null_keys, ) + + +def django_model_instance_to_strawberry_type(instance, graphql_type: strawberry.type): + return cast(graphql_type, instance) diff --git a/hub/management/commands/populate_external_data_source_types.py b/hub/management/commands/populate_external_data_source_types.py index ed15dcbf4..23f02ef85 100644 --- a/hub/management/commands/populate_external_data_source_types.py +++ b/hub/management/commands/populate_external_data_source_types.py @@ -39,9 +39,7 @@ def handle(self, id, *args, **options): f"Processing source {i + 1} of {source_count}: {source} ({source.id})" ) source: ExternalDataSource - qs: list[GenericData] = GenericData.objects.filter( - data_type__data_set__external_data_source_id=source.id - ).order_by("id") + qs: list[GenericData] = source.get_import_data().order_by("id") source_column_types = {} data_count = qs.count() diff --git a/hub/models.py b/hub/models.py index d155d2ed2..4b529acd1 100644 --- a/hub/models.py +++ b/hub/models.py @@ -861,6 +861,25 @@ def save(self, *args, **kwargs): super().save(*args, **kwargs) + def to_dict(self): + return { + "id": self.id, + "postcode": self.postcode, + "first_name": self.first_name, + "last_name": self.last_name, + "full_name": self.full_name, + "email": self.email, + "phone": self.phone, + "start_time": self.start_time, + "end_time": self.end_time, + "public_url": self.public_url, + "social_url": self.social_url, + "address": self.address, + "title": self.title, + "description": self.description, + "json": self.json, + } + class Area(models.Model): mapit_id = models.CharField(max_length=30) @@ -2748,10 +2767,10 @@ async def deferred_import_all( ) priority_enum = None try: - match member_count: - case ( - _ - ) if member_count < settings.SUPER_QUICK_IMPORT_ROW_COUNT_THRESHOLD: + match len(members): + case _ if len( + members + ) < settings.SUPER_QUICK_IMPORT_ROW_COUNT_THRESHOLD: priority_enum = ProcrastinateQueuePriority.SUPER_QUICK case ( _ diff --git a/nextjs/src/__generated__/graphql.ts b/nextjs/src/__generated__/graphql.ts index a137878fa..183555eda 100644 --- a/nextjs/src/__generated__/graphql.ts +++ b/nextjs/src/__generated__/graphql.ts @@ -1493,6 +1493,7 @@ export type GroupedData = { export type GroupedDataCount = { __typename?: 'GroupedDataCount'; + area?: Maybe; areaData?: Maybe; category?: Maybe; columns?: Maybe>; @@ -2931,6 +2932,7 @@ export type StatisticsConfig = { queryId?: InputMaybe; returnColumns?: InputMaybe>; sourceIds?: InputMaybe>; + summaryCalculations?: InputMaybe; }; export type StrFilterLookup = { diff --git a/nextjs/src/__generated__/zodSchema.ts b/nextjs/src/__generated__/zodSchema.ts index 193ae542e..f92126f1c 100644 --- a/nextjs/src/__generated__/zodSchema.ts +++ b/nextjs/src/__generated__/zodSchema.ts @@ -510,7 +510,8 @@ export function StatisticsConfigSchema(): z.ZodObject