Skip to content

Commit 1112f51

Browse files
authored
feature: support property filters in observations_dataframe (#225)
Feature: specify property filters so that only the relevant observations are fetched. This could substantially reduce the amount of data transferred if users specify a particular `unit`,`measurementMethod`, `observationPeriod`, etc.
1 parent 1018af1 commit 1112f51

File tree

10 files changed

+517
-77
lines changed

10 files changed

+517
-77
lines changed

datacommons_client/client.py

Lines changed: 86 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datacommons_client.endpoints.payloads import ObservationDate
77
from datacommons_client.endpoints.resolve import ResolveEndpoint
88
from datacommons_client.utils.decorators import requires_pandas
9+
from datacommons_client.utils.error_handling import NoDataForPropertyError
910

1011
try:
1112
import pandas as pd
@@ -58,6 +59,58 @@ def __init__(
5859
self.observation = ObservationEndpoint(api=self.api)
5960
self.resolve = ResolveEndpoint(api=self.api)
6061

62+
def _find_filter_facet_ids(
63+
self,
64+
fetch_by: Literal["entity", "entity_type"],
65+
date: ObservationDate | str,
66+
variable_dcids: str | list[str],
67+
entity_dcids: Literal["all"] | list[str] = "all",
68+
entity_type: Optional[str] = None,
69+
parent_entity: Optional[str] = None,
70+
property_filters: Optional[dict[str, str | list[str]]] = None,
71+
) -> list[str] | None:
72+
"""Finds matching facet IDs for property filters.
73+
74+
Args:
75+
fetch_by (Literal["entity", "entity_type"]): Determines whether to fetch by entity or entity type.
76+
variable_dcids (str | list[str]): The variable DCIDs for which to retrieve facet IDs.
77+
entity_dcids (Literal["all"] | list[str], optional): The entity DCIDs, or "all" if filtering by entity type.
78+
entity_type (Optional[str]): The entity type, required if fetching by entity type.
79+
parent_entity (Optional[str]): The parent entity, used when fetching by entity type.
80+
property_filters (Optional[dict[str, str | list[str]]): A dictionary of properties to match facets against.
81+
82+
Returns:
83+
list[str] | None: A list of matching facet IDs, or None if no filters are applied.
84+
"""
85+
86+
if not property_filters:
87+
return None
88+
89+
if fetch_by == "entity":
90+
observations = self.observation.fetch_observations_by_entity(
91+
date=date,
92+
entity_dcids=entity_dcids,
93+
variable_dcids=variable_dcids,
94+
select=["variable", "entity", "facet"],
95+
)
96+
else:
97+
observations = self.observation.fetch_observations_by_entity_type(
98+
date=date,
99+
entity_type=entity_type,
100+
parent_entity=parent_entity,
101+
variable_dcids=variable_dcids,
102+
select=["variable", "entity", "facet"],
103+
)
104+
105+
facet_sets = [
106+
observations.find_matching_facet_id(property_name=p, value=v)
107+
for p, v in property_filters.items()
108+
]
109+
110+
facet_ids = list({facet for facets in facet_sets for facet in facets})
111+
112+
return facet_ids
113+
61114
@requires_pandas
62115
def observations_dataframe(
63116
self,
@@ -66,6 +119,7 @@ def observations_dataframe(
66119
entity_dcids: Literal["all"] | list[str] = "all",
67120
entity_type: Optional[str] = None,
68121
parent_entity: Optional[str] = None,
122+
property_filters: Optional[dict[str, str | list[str]]] = None,
69123
):
70124
"""
71125
Fetches statistical observations and returns them as a Pandas DataFrame.
@@ -74,15 +128,17 @@ def observations_dataframe(
74128
at a particular date (e.g., "population of USA in 2020", "GDP of California in 2010").
75129
76130
Args:
77-
variable_dcids (str | list[str]): One or more variable DCIDs for the observation.
78-
date (ObservationDate | str): The date for which observations are requested. It can be
131+
variable_dcids (str | list[str]): One or more variable DCIDs for the observation.
132+
date (ObservationDate | str): The date for which observations are requested. It can be
79133
a specific date, "all" to retrieve all observations, or "latest" to get the most recent observations.
80-
entity_dcids (Literal["all"] | list[str], optional): The entity DCIDs to retrieve data for.
81-
Defaults to "all". DCIDs must include their type (e.g "country/GTM" for Guatemala).
82-
entity_type (Optional[str], optional): The type of entities to filter by when `entity_dcids="all"`.
83-
Required if `entity_dcids="all"`. Defaults to None.
84-
parent_entity (Optional[str], optional): The parent entity under which the target entities fall.
85-
Used only when `entity_dcids="all"`. Defaults to None.
134+
entity_dcids (Literal["all"] | list[str], optional): The entity DCIDs to retrieve data for.
135+
Defaults to "all". DCIDs must include their type (e.g., "country/GTM" for Guatemala).
136+
entity_type (Optional[str]): The type of entities to filter by when `entity_dcids="all"`.
137+
Required if `entity_dcids="all"`. Defaults to None.
138+
parent_entity (Optional[str]): The parent entity under which the target entities fall.
139+
Used only when `entity_dcids="all"`. Defaults to None.
140+
property_filters (Optional[dict[str, str | list[str]]): An optional dictionary used to filter
141+
the data by using observation properties like `measurementMethod`, `unit`, or `observationPeriod`.
86142
87143
Returns:
88144
pd.DataFrame: A DataFrame containing the requested observations.
@@ -97,14 +153,34 @@ def observations_dataframe(
97153
"Specify 'entity_type' and 'parent_entity' only when 'entity_dcids' is 'all'."
98154
)
99155

156+
# If property filters are provided, fetch the required facet IDs. Otherwise, set to None.
157+
facets = self._find_filter_facet_ids(
158+
fetch_by="entity" if entity_dcids != "all" else "entity_type",
159+
date=date,
160+
variable_dcids=variable_dcids,
161+
entity_dcids=entity_dcids,
162+
entity_type=entity_type,
163+
parent_entity=parent_entity,
164+
property_filters=property_filters,
165+
)
166+
167+
if not facets and property_filters:
168+
raise NoDataForPropertyError
169+
100170
if entity_dcids == "all":
101171
observations = self.observation.fetch_observations_by_entity_type(
102172
date=date,
103173
parent_entity=parent_entity,
104174
entity_type=entity_type,
105-
variable_dcids=variable_dcids)
175+
variable_dcids=variable_dcids,
176+
filter_facet_ids=facets,
177+
)
106178
else:
107179
observations = self.observation.fetch_observations_by_entity(
108-
date=date, entity_dcids=entity_dcids, variable_dcids=variable_dcids)
180+
date=date,
181+
entity_dcids=entity_dcids,
182+
variable_dcids=variable_dcids,
183+
filter_facet_ids=facets,
184+
)
109185

110186
return pd.DataFrame(observations.get_observations_as_records())

datacommons_client/endpoints/observation.py

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def fetch_latest_observations(
6868
entity_dcids: Optional[str | list[str]] = None,
6969
entity_expression: Optional[str] = None,
7070
*,
71+
select: Optional[list[ObservationSelect | str]] = None,
7172
filter_facet_domains: Optional[str | list[str]] = None,
7273
filter_facet_ids: Optional[str | list[str]] = None,
7374
) -> ObservationResponse:
@@ -78,6 +79,8 @@ def fetch_latest_observations(
7879
variable_dcids (str | list[str]): One or more variable IDs for the data.
7980
entity_dcids (Optional[str | list[str]]): One or more entity IDs to filter the data.
8081
entity_expression (Optional[str]): A string expression to filter entities.
82+
select (Optional[list[ObservationSelect | str]]): Fields to include in the response.
83+
If not provided, defaults to ["date", "variable", "entity", "value"].
8184
filter_facet_domains: Optional[str | list[str]: One or more domain names to filter the data.
8285
filter_facet_ids: Optional[str | list[str]: One or more facet IDs to filter the data.
8386
@@ -91,13 +94,15 @@ def fetch_latest_observations(
9194
entity_expression=entity_expression,
9295
filter_facet_domains=filter_facet_domains,
9396
filter_facet_ids=filter_facet_ids,
97+
select=[s for s in ObservationSelect] if not select else select,
9498
)
9599

96100
def fetch_latest_observations_by_entity(
97101
self,
98102
variable_dcids: str | list[str],
99103
entity_dcids: str | list[str],
100104
*,
105+
select: Optional[list[ObservationSelect | str]] = None,
101106
filter_facet_domains: Optional[str | list[str]] = None,
102107
filter_facet_ids: Optional[str | list[str]] = None,
103108
) -> ObservationResponse:
@@ -106,6 +111,8 @@ def fetch_latest_observations_by_entity(
106111
Args:
107112
variable_dcids (str | list[str]): One or more variable IDs for the data.
108113
entity_dcids (str | list[str]): One or more entity IDs to filter the data.
114+
select (Optional[list[ObservationSelect | str]]): Fields to include in the response.
115+
If not provided, defaults to ["date", "variable", "entity", "value"].
109116
filter_facet_domains: Optional[str | list[str]: One or more domain names to filter the data.
110117
filter_facet_ids: Optional[str | list[str]: One or more facet IDs to filter the data.
111118
@@ -116,8 +123,10 @@ def fetch_latest_observations_by_entity(
116123
return self.fetch_latest_observations(
117124
variable_dcids=variable_dcids,
118125
entity_dcids=entity_dcids,
126+
select=[s for s in ObservationSelect] if not select else select,
119127
filter_facet_domains=filter_facet_domains,
120-
filter_facet_ids=filter_facet_ids)
128+
filter_facet_ids=filter_facet_ids,
129+
)
121130

122131
def fetch_observations_by_entity_type(
123132
self,
@@ -126,6 +135,7 @@ def fetch_observations_by_entity_type(
126135
entity_type: str,
127136
variable_dcids: str | list[str],
128137
*,
138+
select: Optional[list[ObservationSelect | str]] = None,
129139
filter_facet_domains: Optional[str | list[str]] = None,
130140
filter_facet_ids: Optional[str | list[str]] = None,
131141
) -> ObservationResponse:
@@ -142,6 +152,8 @@ def fetch_observations_by_entity_type(
142152
For example, "Country" or "Region".
143153
variable_dcids (str | list[str]): The variable(s) to fetch observations for.
144154
This can be a single variable ID or a list of IDs.
155+
select (Optional[list[ObservationSelect | str]]): Fields to include in the response.
156+
If not provided, defaults to ["date", "variable", "entity", "value"].
145157
filter_facet_domains: Optional[str | list[str]: One or more domain names to filter the data.
146158
filter_facet_ids: Optional[str | list[str]: One or more facet IDs to filter the data.
147159
@@ -165,7 +177,7 @@ def fetch_observations_by_entity_type(
165177
return self.fetch(
166178
variable_dcids=variable_dcids,
167179
date=date,
168-
select=[s for s in ObservationSelect],
180+
select=[s for s in ObservationSelect] if not select else select,
169181
entity_expression=
170182
f"{parent_entity}<-containedInPlace+{{typeOf:{entity_type}}}",
171183
filter_facet_domains=filter_facet_domains,
@@ -178,6 +190,7 @@ def fetch_observations_by_entity(
178190
entity_dcids: str | list[str],
179191
variable_dcids: str | list[str],
180192
*,
193+
select: Optional[list[ObservationSelect | str]] = None,
181194
filter_facet_domains: Optional[str | list[str]] = None,
182195
filter_facet_ids: Optional[str | list[str]] = None,
183196
) -> ObservationResponse:
@@ -191,6 +204,8 @@ def fetch_observations_by_entity(
191204
entity_dcids (str | list[str]): One or more entity IDs to filter the data.
192205
variable_dcids (str | list[str]): The variable(s) to fetch observations for.
193206
This can be a single variable ID or a list of IDs.
207+
select (Optional[list[ObservationSelect | str]]): Fields to include in the response.
208+
If not provided, defaults to ["date", "variable", "entity", "value"].
194209
filter_facet_domains: Optional[str | list[str]: One or more domain names to filter the data.
195210
filter_facet_ids: Optional[str | list[str]: One or more facet IDs to filter the data.
196211
@@ -213,7 +228,7 @@ def fetch_observations_by_entity(
213228
return self.fetch(
214229
variable_dcids=variable_dcids,
215230
date=date,
216-
select=[s for s in ObservationSelect],
231+
select=[s for s in ObservationSelect] if not select else select,
217232
entity_dcids=entity_dcids,
218233
filter_facet_domains=filter_facet_domains,
219234
filter_facet_ids=filter_facet_ids,

datacommons_client/endpoints/payloads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,7 @@ class ObservationSelect(str, Enum):
7878
VARIABLE = "variable"
7979
ENTITY = "entity"
8080
VALUE = "value"
81+
FACET = "facet"
8182

8283
@classmethod
8384
def _missing_(cls, value):
@@ -138,7 +139,6 @@ def __post_init__(self):
138139
def normalize(self):
139140
"""
140141
Normalizes the payload for consistent internal representation.
141-
142142
- Converts `variable_dcids`, `entity_dcids`, `filter_facet_domains` and `filter_facet_ids`
143143
to lists if they are passed as strings.
144144
- Normalizes the `date` field to ensure it is in the correct format.

datacommons_client/endpoints/response.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -151,6 +151,79 @@ def get_observations_as_records(self) -> List[Dict[str, Any]]:
151151
return observations_as_records(data=self.get_data_by_entity(),
152152
facets=self.facets)
153153

154+
def get_facets_metadata(self) -> Dict[str, Any]:
155+
"""Extract metadata about StatVars from the response. This data is
156+
structured as a dictionary of StatVars, each containing a dictionary of
157+
facets with their corresponding metadata.
158+
159+
Returns:
160+
Dict[str, Any]: A dictionary of StatVars with their associated metadata,
161+
including earliest and latest observation dates, observation counts,
162+
measurementMethod, observationPeriod, and unit, etc.
163+
"""
164+
# Dictionary to store metadata
165+
metadata = {}
166+
167+
# Extract information from byVariable
168+
data_by_entity = self.get_data_by_entity()
169+
170+
# Extract facet information
171+
facets_info = self.to_dict().get("facets", {})
172+
173+
for dcid, variables in data_by_entity.items():
174+
metadata[dcid] = {}
175+
176+
for entity_id, entity in variables.items():
177+
for facet in entity.get("orderedFacets", []):
178+
facet_metadata = metadata[dcid].setdefault(
179+
facet.facetId,
180+
{
181+
"earliestDate": {},
182+
"latestDate": {},
183+
"obsCount": {},
184+
},
185+
)
186+
187+
facet_metadata["earliestDate"][entity_id] = facet.earliestDate
188+
facet_metadata["latestDate"][entity_id] = facet.latestDate
189+
facet_metadata["obsCount"][entity_id] = facet.obsCount
190+
191+
# Merge additional facet details
192+
facet_metadata.update(facets_info.get(facet.facetId, {}))
193+
194+
return metadata
195+
196+
def find_matching_facet_id(self, property_name: str,
197+
value: str | list[str]) -> list[str]:
198+
"""Finds facet IDs that match a given property and value.
199+
200+
Args:
201+
property_name (str): The property to match.
202+
value (str | list): The value to match. Can be a string, number, or a list of values.
203+
Returns:
204+
list[str]: A list of facet IDs that match the property and value.
205+
"""
206+
# Initialize an empty list to store matching facet IDs
207+
matching_facet_ids = []
208+
209+
# Iterate over the facets metadata to find matching facet IDs
210+
for facet_data in self.get_facets_metadata().values():
211+
212+
# Iterate over each facet and its associated metadata
213+
for facet_id, metadata in facet_data.items():
214+
215+
# Get the value of the specified property from the data
216+
prop_value = metadata.get(property_name)
217+
218+
# Check if the property value matches the specified value
219+
if isinstance(value, list):
220+
if prop_value in value:
221+
matching_facet_ids.append(facet_id)
222+
elif prop_value == value:
223+
matching_facet_ids.append(facet_id)
224+
225+
return matching_facet_ids
226+
154227

155228
@dataclass
156229
class ResolveResponse(SerializableMixin):

datacommons_client/tests/endpoints/test_error_handling.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from datacommons_client.utils.error_handling import DCConnectionError
88
from datacommons_client.utils.error_handling import DCStatusError
99
from datacommons_client.utils.error_handling import InvalidDCInstanceError
10+
from datacommons_client.utils.error_handling import NoDataForPropertyError
1011

1112

1213
def test_data_commons_error_default_message():
@@ -59,6 +60,9 @@ def test_subclass_default_messages():
5960
instance_error = InvalidDCInstanceError()
6061
assert InvalidDCInstanceError.default_message in str(instance_error)
6162

63+
filter_error = NoDataForPropertyError()
64+
assert NoDataForPropertyError.default_message in str(filter_error)
65+
6266

6367
def test_subclass_custom_message():
6468
"""Tests that subclasses use custom messages when provided."""

0 commit comments

Comments
 (0)