Skip to content

Commit 0cbfd8c

Browse files
authored
Add fetch_available_statistical_variables (#229)
**Fetch available statistical variables**: - Adds a `fetch_available_statistical_variables` method to the ObservationEndpoint. This method fetches all statvars which have observations for one or more entities.
1 parent a2017fa commit 0cbfd8c

File tree

4 files changed

+132
-0
lines changed

4 files changed

+132
-0
lines changed

datacommons_client/endpoints/observation.py

Lines changed: 23 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from datacommons_client.endpoints.payloads import ObservationRequestPayload
77
from datacommons_client.endpoints.payloads import ObservationSelect
88
from datacommons_client.endpoints.response import ObservationResponse
9+
from datacommons_client.utils.data_processing import group_variables_by_entity
910

1011

1112
class ObservationEndpoint(Endpoint):
@@ -233,3 +234,25 @@ def fetch_observations_by_entity(
233234
filter_facet_domains=filter_facet_domains,
234235
filter_facet_ids=filter_facet_ids,
235236
)
237+
238+
def fetch_available_statistical_variables(
239+
self,
240+
entity_dcids: str | list[str],
241+
) -> dict[str, list[str]]:
242+
"""
243+
Fetches available statistical variables (which have observations) for given entities.
244+
Args:
245+
entity_dcids (str | list[str]): One or more entity DCIDs(s) to fetch variables for.
246+
Returns:
247+
dict[str, list[str]]: A dictionary mapping entity DCIDs to their available statistical variables.
248+
"""
249+
250+
# Fetch observations for the given entity DCIDs. If variable is empty list
251+
# all available variables are retrieved.
252+
data = self.fetch(
253+
entity_dcids=entity_dcids,
254+
select=[ObservationSelect.VARIABLE, ObservationSelect.ENTITY],
255+
variable_dcids=[],
256+
).get_data_by_entity()
257+
258+
return group_variables_by_entity(data=data)

datacommons_client/tests/endpoints/test_observation_endpoint.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
from datacommons_client.endpoints.base import API
44
from datacommons_client.endpoints.observation import ObservationEndpoint
55
from datacommons_client.endpoints.payloads import ObservationDate
6+
from datacommons_client.endpoints.payloads import ObservationSelect
67
from datacommons_client.endpoints.response import ObservationResponse
78

89

@@ -164,3 +165,51 @@ def test_fetch_observations_facets_by_entity_type():
164165
endpoint="observation",
165166
all_pages=True,
166167
next_token=None)
168+
169+
170+
def test_fetch_available_statistical_variables_single_entity():
171+
"""Test fetching variables for a single entity."""
172+
mock_data = {
173+
"var1": ["ent1"],
174+
"var2": ["ent1"],
175+
}
176+
177+
# Mock the fetch method on the ObservationEndpoint instance
178+
endpoint = ObservationEndpoint(api=MagicMock())
179+
endpoint.fetch = MagicMock()
180+
endpoint.fetch.return_value.get_data_by_entity = MagicMock(
181+
return_value=mock_data)
182+
183+
result = endpoint.fetch_available_statistical_variables("ent1")
184+
185+
expected = {
186+
"ent1": ["var1", "var2"],
187+
}
188+
assert result == expected
189+
190+
endpoint.fetch.assert_called_once_with(
191+
entity_dcids="ent1",
192+
select=[ObservationSelect.VARIABLE, ObservationSelect.ENTITY],
193+
variable_dcids=[],
194+
)
195+
196+
197+
def test_fetch_available_statistical_variables_multiple_entities():
198+
"""Test fetching variables for multiple entities."""
199+
mock_data = {
200+
"var1": ["ent1", "ent2"],
201+
"var2": ["ent2"],
202+
}
203+
204+
endpoint = ObservationEndpoint(api=MagicMock())
205+
endpoint.fetch = MagicMock()
206+
endpoint.fetch.return_value.get_data_by_entity = MagicMock(
207+
return_value=mock_data)
208+
209+
result = endpoint.fetch_available_statistical_variables(["ent1", "ent2"])
210+
211+
expected = {
212+
"ent1": ["var1"],
213+
"ent2": ["var1", "var2"],
214+
}
215+
assert result == expected
Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
from datacommons_client.utils.data_processing import group_variables_by_entity
2+
3+
4+
def test_group_variables_by_entity_basic():
5+
"""Test grouping with simple variable-entity mapping."""
6+
input_data = {
7+
"var1": ["ent1", "ent2"],
8+
"var2": ["ent2", "ent3"],
9+
"var3": ["ent1"],
10+
}
11+
expected_output = {
12+
"ent1": ["var1", "var3"],
13+
"ent2": ["var1", "var2"],
14+
"ent3": ["var2"],
15+
}
16+
17+
result = group_variables_by_entity(input_data)
18+
assert result == expected_output
19+
20+
21+
def test_group_variables_by_entity_duplicate_entities():
22+
"""Test grouping when a variable has duplicate entities."""
23+
input_data = {
24+
"var1": ["ent1", "ent1", "ent2"],
25+
}
26+
result = group_variables_by_entity(input_data)
27+
assert result["ent1"].count("var1") == 2 # duplicates are preserved
28+
assert "ent2" in result
29+
assert result["ent2"] == ["var1"]
30+
31+
32+
def test_group_variables_by_entity_preserves_order():
33+
"""Test if the order of variables is preserved in the resulting entity lists."""
34+
input_data = {
35+
"var1": ["ent1"],
36+
"var2": ["ent1"],
37+
"var3": ["ent1"],
38+
}
39+
result = group_variables_by_entity(input_data)
40+
assert result["ent1"] == ["var1", "var2", "var3"]

datacommons_client/utils/data_processing.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -93,3 +93,23 @@ def observations_as_records(data: dict, facets: dict) -> list[dict]:
9393
facet_metadata=facets,
9494
)
9595
]
96+
97+
98+
def group_variables_by_entity(
99+
data: dict[str, list[str]]) -> dict[str, list[str]]:
100+
"""Groups variables by the entities they are associated with.
101+
Takes a dictionary mapping statistical variable DCIDs to a list of entity DCIDs,
102+
and returns a new dictionary mapping each entity DCID to a list of statistical
103+
variables available for that entity.
104+
Args:
105+
data: A dictionary where each key is a variable DCID and the value is a list
106+
of entity DCIDs that have observations for that variable.
107+
Returns:
108+
A dictionary where each key is an entity DCID and the value is a list of
109+
variable DCIDs available for that entity.
110+
"""
111+
result: dict[str, list[str]] = {}
112+
for variable, entities in data.items():
113+
for entity in entities:
114+
result.setdefault(entity, []).append(variable)
115+
return result

0 commit comments

Comments
 (0)