
Commit 4b28db0: linting

cmutel committed Sep 19, 2024
1 parent 2ce8a7a

Showing 35 changed files with 292 additions and 688 deletions.
113 changes: 35 additions & 78 deletions bw2data/backends/base.py
@@ -16,18 +16,6 @@
from tqdm import tqdm

from bw2data import config, databases, geomapping
from bw2data.configuration import labels
from bw2data.data_store import ProcessedDataStore
from bw2data.errors import (
DuplicateNode,
InvalidExchange,
UnknownObject,
UntypedExchange,
WrongDatabase,
)
from bw2data.query import Query
from bw2data.search import IndexManager, Searcher
from bw2data.utils import as_uncertainty_dict, get_geocollection, get_node
from bw2data.backends import sqlite3_lci_db
from bw2data.backends.proxies import Activity
from bw2data.backends.schema import ActivityDataset, ExchangeDataset, get_id
@@ -44,6 +32,18 @@
get_csv_data_dict,
retupleize_geo_strings,
)
from bw2data.configuration import labels
from bw2data.data_store import ProcessedDataStore
from bw2data.errors import (
DuplicateNode,
InvalidExchange,
UnknownObject,
UntypedExchange,
WrongDatabase,
)
from bw2data.query import Query
from bw2data.search import IndexManager, Searcher
from bw2data.utils import as_uncertainty_dict, get_geocollection, get_node

_VALID_KEYS = {"location", "name", "product", "type"}

@@ -277,9 +277,7 @@ def find_graph_dependents(self):
"""

def extend(seeds):
return set.union(
seeds, set.union(*[set(databases[obj]["depends"]) for obj in seeds])
)
return set.union(seeds, set.union(*[set(databases[obj]["depends"]) for obj in seeds]))

seed, extended = {self.name}, extend({self.name})
while extended != seed:
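
For context, the `extend` helper above computes one step of a transitive-dependency closure. A minimal standalone sketch of the full fixed-point loop, using a hypothetical `depends` mapping in place of the real `databases` metadata store:

    # Standalone sketch; `depends` is a hypothetical stand-in for `databases`.
    depends = {"a": ["b"], "b": ["c"], "c": []}

    def find_graph_dependents(name):
        def extend(seeds):
            # Union the seeds with every database they declare as a dependency
            return set.union(seeds, set.union(*[set(depends[obj]) for obj in seeds]))

        seed, extended = {name}, extend({name})
        while extended != seed:  # iterate to a fixed point
            seed, extended = extended, extend(extended)
        return extended

    print(find_graph_dependents("a"))  # {'a', 'b', 'c'}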
@@ -368,12 +366,7 @@ def relabel_exchanges(obj, new_name):
e["input"] = (new_name, e["input"][1])
return obj

return dict(
[
((new_name, k[1]), relabel_exchanges(v, new_name))
for k, v in data.items()
]
)
return dict([((new_name, k[1]), relabel_exchanges(v, new_name)) for k, v in data.items()])

def rename(self, name):
"""Rename a database. Modifies exchanges to link to new name. Deregisters old database.
@@ -439,10 +432,7 @@ def _set_filters(self, filters):
if not filters:
self._filters = {}
else:
print(
"Filters will affect all database queries"
" until unset (`.filters = None`)"
)
print("Filters will affect all database queries" " until unset (`.filters = None`)")
assert isinstance(filters, dict), "Filter must be a dictionary"
for key in filters:
assert key in _VALID_KEYS, "Filter key {} is invalid".format(key)
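
A hedged sketch of setting and clearing these filters; the database name is hypothetical, and valid keys are `location`, `name`, `product`, and `type` per `_VALID_KEYS` above:

    from bw2data import Database

    db = Database("example")          # hypothetical database name
    db.filters = {"location": "GLO"}  # affects all queries until unset
    db.filters = None                 # clears the filters again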
@@ -469,14 +459,10 @@ def random(self, filters=True, true_random=False):
"""True random requires loading and sorting data in SQLite, and can be resource-intensive."""
try:
if true_random:
return self.node_class(
self._get_queryset(random=True, filters=filters).get()
)
return self.node_class(self._get_queryset(random=True, filters=filters).get())
else:
return self.node_class(
self._get_queryset(filters=filters)
.offset(random.randint(0, len(self)))
.get()
self._get_queryset(filters=filters).offset(random.randint(0, len(self))).get()
)
except DoesNotExist:
warnings.warn("This database is empty")
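
A usage sketch contrasting the two `random` modes (database name hypothetical):

    from bw2data import Database

    db = Database("example")            # hypothetical database name
    node = db.random()                  # cheap: random offset into the queryset
    node = db.random(true_random=True)  # SQLite ORDER BY RANDOM(); resource-intensive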
@@ -569,9 +555,7 @@ def _efficient_write_many_data(
self.delete(keep_params=True, warn=False, vacuum=False)
exchanges, activities = [], []

for key, ds in tqdm_wrapper(
data.items(), getattr(config, "is_test", False)
):
for key, ds in tqdm_wrapper(data.items(), getattr(config, "is_test", False)):
exchanges, activities = self._efficient_write_dataset(
key, ds, exchanges, activities, check_typos
)
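
`_efficient_write_many_data` is internal; the public entry point is `Database.write`. A minimal sketch with a hypothetical one-node database:

    from bw2data import Database

    db = Database("example")  # hypothetical database name
    db.write({
        ("example", "node1"): {
            "name": "node 1",
            "unit": "kilogram",
            "type": "process",
            "exchanges": [],
        }
    })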
@@ -677,12 +661,12 @@ def new_activity(self, code, **kwargs):

def new_node(self, code: str = None, **kwargs):
obj = self.node_class()
if 'database' in kwargs:
if kwargs['database'] != self.name:
if "database" in kwargs:
if kwargs["database"] != self.name:
raise ValueError(
f"Creating a new node in database `{self.name}`, but gave database label `{kwargs['database']}`"
)
kwargs.pop('database')
kwargs.pop("database")
obj["database"] = self.name

if code is None:
@@ -702,22 +686,17 @@ def new_node(self, code: str = None, **kwargs):

if (
ActivityDataset.select()
.where(
(ActivityDataset.database == self.name)
& (ActivityDataset.code == obj["code"])
)
.where((ActivityDataset.database == self.name) & (ActivityDataset.code == obj["code"]))
.count()
):
raise DuplicateNode("Node with this database / code combo already exists")
if (
"id" in kwargs
and ActivityDataset.select()
.where(ActivityDataset.id == int(kwargs["id"]))
.count()
and ActivityDataset.select().where(ActivityDataset.id == int(kwargs["id"])).count()
):
raise DuplicateNode("Node with this id already exists")

if 'location' not in kwargs:
if "location" not in kwargs:
obj["location"] = config.global_location
obj.update(kwargs)
return obj
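
A usage sketch for `new_node` (database name and field values hypothetical):

    from bw2data import Database

    db = Database("example")  # hypothetical database name
    node = db.new_node(code="steel", name="steel production", unit="kilogram")
    node.save()  # duplicate database/code combinations raise DuplicateNode above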
@@ -751,9 +730,7 @@ def delete(self, keep_params=False, warn=True, vacuum=True):
vacuum_needed = len(self) > 500 and vacuum

ActivityDataset.delete().where(ActivityDataset.database == self.name).execute()
ExchangeDataset.delete().where(
ExchangeDataset.output_database == self.name
).execute()
ExchangeDataset.delete().where(ExchangeDataset.output_database == self.name).execute()
IndexManager(self.filename).delete_database()

if not keep_params:
@@ -771,15 +748,9 @@
.tuples()
}
)
ParameterizedExchange.delete().where(
ParameterizedExchange.group << groups
).execute()
ActivityParameter.delete().where(
ActivityParameter.database == self.name
).execute()
DatabaseParameter.delete().where(
DatabaseParameter.database == self.name
).execute()
ParameterizedExchange.delete().where(ParameterizedExchange.group << groups).execute()
ActivityParameter.delete().where(ActivityParameter.database == self.name).execute()
DatabaseParameter.delete().where(DatabaseParameter.database == self.name).execute()

if vacuum_needed:
sqlite3_lci_db.vacuum()
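
A usage sketch for `delete` (database name hypothetical):

    from bw2data import Database

    db = Database("example")  # hypothetical database name
    db.delete(warn=False)     # drops nodes, exchanges, and the search index;
                              # keep_params=True would preserve parameter data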
@@ -812,9 +783,7 @@ def exchange_data_iterator(self, qs_func, dependents, flip=False):
"Exchange between {} and {} is invalid "
"- one of these objects is unknown (i.e. doesn't exist "
"as a process dataset)"
).format(
(input_database, input_code), (output_database, output_code)
)
).format((input_database, input_code), (output_database, output_code))
)
yield {
**as_uncertainty_dict(data),
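
The `as_uncertainty_dict` call above normalizes a raw exchange amount into an uncertainty dictionary; a hedged sketch of the expected behavior:

    from bw2data.utils import as_uncertainty_dict

    as_uncertainty_dict(4.2)  # expected: {"amount": 4.2}
    as_uncertainty_dict({"amount": 4.2, "uncertainty type": 0})  # dicts pass through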
@@ -828,9 +797,7 @@ def _add_inventory_geomapping_to_datapackage(self, dp: Datapackage) -> None:
Separated out to allow for easier use in subclasses."""
# Create geomapping array, from dataset integer ids to locations
inv_mapping_qs = ActivityDataset.select(
ActivityDataset.id, ActivityDataset.location
).where(
inv_mapping_qs = ActivityDataset.select(ActivityDataset.id, ActivityDataset.location).where(
ActivityDataset.database == self.name,
ActivityDataset.type << labels.process_node_types,
)
@@ -840,9 +807,7 @@ def _add_inventory_geomapping_to_datapackage(self, dp: Datapackage) -> None:
dict_iterator=(
{
"row": row[0],
"col": geomapping[
retupleize_geo_strings(row[1]) or config.global_location
],
"col": geomapping[retupleize_geo_strings(row[1]) or config.global_location],
"amount": 1,
}
for row in inv_mapping_qs.tuples()
@@ -915,9 +880,7 @@ def process(self, csv=False):
matrix="technosphere_matrix",
name=clean_datapackage_name(self.name + " technosphere matrix"),
dict_iterator=itertools.chain(
self.exchange_data_iterator(
get_technosphere_negative_qs, dependents, flip=True
),
self.exchange_data_iterator(get_technosphere_negative_qs, dependents, flip=True),
self.exchange_data_iterator(get_technosphere_positive_qs, dependents),
implicit_production,
),
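
A minimal usage sketch for `process` (database name hypothetical):

    from bw2data import Database

    db = Database("example")  # hypothetical, already-written database
    db.process()              # rebuilds the processed datapackage used for calculations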
@@ -1045,14 +1008,10 @@ def nodes_to_dataframe(
# Feels like magic
df = pandas.DataFrame(self)
else:
df = pandas.DataFrame(
[{field: obj.get(field) for field in columns} for obj in self]
)
df = pandas.DataFrame([{field: obj.get(field) for field in columns} for obj in self])
if return_sorted:
sort_columns = ["name", "reference product", "location", "unit"]
df = df.sort_values(
by=[column for column in sort_columns if column in df.columns]
)
df = df.sort_values(by=[column for column in sort_columns if column in df.columns])
return df
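
A usage sketch for `nodes_to_dataframe` (database name hypothetical):

    from bw2data import Database

    db = Database("example")      # hypothetical database name
    df = db.nodes_to_dataframe()  # all columns, sorted per the sort_columns above
    df = db.nodes_to_dataframe(columns=["name", "location"], return_sorted=False)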

def edges_to_dataframe(
@@ -1119,9 +1078,7 @@ def edges_to_dataframe(
"source_location": edge.get("location"),
"source_unit": edge.get("unit"),
"source_categories": (
"::".join(edge["categories"])
if edge.get("categories")
else None
"::".join(edge["categories"]) if edge.get("categories") else None
),
"edge_amount": edge["amount"],
"edge_type": edge["type"],
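A usage sketch for `edges_to_dataframe`, selecting columns named in the construction above (database name hypothetical):

    from bw2data import Database

    df = Database("example").edges_to_dataframe()  # one row per exchange
    print(df[["source_unit", "source_location", "edge_amount", "edge_type"]].head())
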
34 changes: 9 additions & 25 deletions bw2data/backends/iotable/backend.py
@@ -7,9 +7,9 @@
from fsspec.implementations.zip import ZipFileSystem

from bw2data import config, databases, geomapping
from bw2data.configuration import labels
from bw2data.backends import SQLiteBackend
from bw2data.backends.iotable.proxies import IOTableActivity, IOTableExchanges
from bw2data.configuration import labels


class IOTableBackend(SQLiteBackend):
@@ -21,9 +21,7 @@ class IOTableBackend(SQLiteBackend):
node_class = IOTableActivity

def write(self, data, process=False, searchable=True, check_typos=True):
super().write(
data, process=False, searchable=searchable, check_typos=check_typos
)
super().write(data, process=False, searchable=searchable, check_typos=check_typos)

def write_exchanges(self, technosphere, biosphere, dependents):
"""
@@ -51,9 +49,7 @@ def write_exchanges(self, technosphere, biosphere, dependents):
dict_iterator=(
{
"row": obj.id,
"col": geomapping[
obj.get("location", None) or config.global_location
],
"col": geomapping[obj.get("location", None) or config.global_location],
"amount": 1,
}
for obj in self
@@ -79,9 +75,7 @@ def write_exchanges(self, technosphere, biosphere, dependents):
dict_iterator=technosphere,
)
else:
raise Exception(
f"Error: Unsupported technosphere type: {type(technosphere)}"
)
raise Exception(f"Error: Unsupported technosphere type: {type(technosphere)}")

print("Adding biosphere matrix")
# if biosphere is a dictionary, pass its keys & values
@@ -105,9 +99,7 @@ def write_exchanges(self, technosphere, biosphere, dependents):
print("Finalizing serialization")
dp.finalize_serialization()

databases[self.name]["depends"] = sorted(
set(dependents).difference({self.name})
)
databases[self.name]["depends"] = sorted(set(dependents).difference({self.name}))
databases[self.name]["processed"] = datetime.datetime.now().isoformat()
databases.flush()
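
A hedged sketch of the non-datapackage branch of `write_exchanges`, with hypothetical node ids and database names:

    from bw2data.backends.iotable.backend import IOTableBackend

    io = IOTableBackend("us-io")  # hypothetical IO database; nodes already written
    technosphere = [{"row": 1, "col": 1, "amount": 1.0, "flip": False}]
    biosphere = [{"row": 2, "col": 1, "amount": 0.5}]
    io.write_exchanges(technosphere, biosphere, dependents=["biosphere3"])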

@@ -179,24 +171,18 @@ def dict_for_obj(obj, prefix):
dct["source_product"] = obj.get("product")
return dct

return pd.DataFrame(
[dict_for_obj(get(id_), prefix) for id_ in np.unique(ids)]
)
return pd.DataFrame([dict_for_obj(get(id_), prefix) for id_ in np.unique(ids)])

def get_edge_types(exchanges):
arrays = []
for resource in exchanges.resources:
if resource["data"]["matrix"] == "biosphere_matrix":
arrays.append(
np.array(
[labels.biosphere_edge_default]
* len(resource["data"]["array"])
)
np.array([labels.biosphere_edge_default] * len(resource["data"]["array"]))
)
else:
arr = np.array(
[labels.consumption_edge_default]
* len(resource["data"]["array"])
[labels.consumption_edge_default] * len(resource["data"]["array"])
)
arr[resource["flip"]["positive"]] = labels.production_edge_default
arrays.append(arr)
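
A standalone numpy sketch of the flip-to-edge-type mapping above, with hypothetical label strings standing in for the `labels` defaults:

    import numpy as np

    amounts = np.zeros(3)                     # stand-in for resource["data"]["array"]
    flipped = np.array([False, True, False])  # stand-in for resource["flip"]["positive"]
    types = np.array(["technosphere"] * len(amounts), dtype=object)
    types[flipped] = "production"             # flipped entries become production edges
    # types -> ['technosphere', 'production', 'technosphere']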
@@ -212,9 +198,7 @@ def get_edge_types(exchanges):
source_ids = np.hstack(
[resource["indices"]["array"]["row"] for resource in exchanges.resources]
)
edge_amounts = np.hstack(
[resource["data"]["array"] for resource in exchanges.resources]
)
edge_amounts = np.hstack([resource["data"]["array"] for resource in exchanges.resources])
edge_types = get_edge_types(exchanges)

print("Creating metadata dataframes")
[Diff for the remaining 33 changed files not shown.]
