Skip to content

Commit

Permalink
v.db.join: handle existing columns properly (OSGeo#3765)
Browse files Browse the repository at this point in the history
* handle existing columns properly

* refactoring and removal of Python2 compatible code

* add a basic testsuite
  • Loading branch information
ninsbl authored Jun 5, 2024
1 parent 64c8b4a commit fcb6199
Show file tree
Hide file tree
Showing 2 changed files with 177 additions and 85 deletions.
84 changes: 84 additions & 0 deletions scripts/v.db.join/testsuite/test_v_db_join.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""
TEST: test_v_db_join.py
AUTHOR(S): Stefan Blumentrath
PURPOSE: Test for v.db.join
COPYRIGHT: (C) 2024 Stefan Blumentrath, and by the GRASS Development Team
This program is free software under the GNU General Public
License (>=v2). Read the file COPYING that comes with GRASS
for details.
"""

from grass.gunittest.case import TestCase
from grass.gunittest.main import test
from grass.gunittest.gmodules import SimpleModule


class TestVDbJoin(TestCase):
    """Test the v.db.join script.

    A copy of the ``firestations`` vector map is joined with two helper
    tables that are created once for the whole test class: one carrying
    several differently typed columns and one carrying a single extra column.
    """

    @classmethod
    def setUpClass(cls):
        """Copy the vector map and create the tables used for joining."""
        # The INSERT has to target the table created just above and must
        # provide one value for each of the five declared columns.
        firestation_sql = """CREATE TABLE firestation_test_table (
        CITY text,
        some_number int,
        some_text text,
        some_double double precision,
        some_float real
        );
        INSERT INTO firestation_test_table VALUES
        ('Cary', 1, 'short', 1.1233445366756784345, 1.1),
        ('Apex', 2, 'longer', -111.1220390953406936354, -111.1),
        ('Garner', 3, 'short', 4.20529509802443234245, 4.2),
        ('Raleigh', 4, 'even longer than before', 32.913873948295837592, 32.9);
        """
        firestation_existing_sql = """CREATE TABLE firestation_test_table_update (
        CITY text,
        others int
        );
        INSERT INTO firestation_test_table_update VALUES
        ('Cary', 1),
        ('Apex', 2),
        ('Garner', 3),
        ('Raleigh', 4);
        """
        cls.runModule("g.copy", vector=["firestations", "test_firestations"])
        cls.runModule("db.execute", sql=firestation_sql)
        cls.runModule("db.execute", sql=firestation_existing_sql)

    @classmethod
    def tearDownClass(cls):
        """Remove the copied vector map and the created tables."""
        cls.runModule("g.remove", type="vector", name="test_firestations", flags="f")
        cls.runModule("db.execute", sql="DROP TABLE firestation_test_table;")
        cls.runModule("db.execute", sql="DROP TABLE firestation_test_table_update;")

    def test_join_firestations_table(self):
        """Join firestations table with new different columns"""
        module = SimpleModule(
            "v.db.join",
            map="test_firestations",
            column="CITY",
            other_table="firestation_test_table",
            other_column="CITY",
        )
        self.assertModule(module)

    def test_join_firestations_table_existing(self):
        """Join firestations table with only existing columns"""
        module = SimpleModule(
            "v.db.join",
            map="test_firestations",
            column="CITY",
            other_table="firestation_test_table_update",
            other_column="CITY",
        )
        self.assertModule(module)


if __name__ == "__main__":
    # Executed directly: hand control over to the gunittest runner.
    test()
178 changes: 93 additions & 85 deletions scripts/v.db.join/v.db.join.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,80 +64,87 @@
# %end

import atexit
import os
import sys
import grass.script as grass

from pathlib import Path

import grass.script as gs
from grass.exceptions import CalledModuleError

# Paths of temporary files created during the run; cleanup() removes them.
rm_files = []


def cleanup():
    """Remove the temporary files registered in ``rm_files``.

    Removal is best effort: a file that no longer exists is ignored and an
    operating-system error while deleting is reported as a warning.
    """
    for file_path in rm_files:
        try:
            # Path() accepts both str and Path entries
            Path(file_path).unlink(missing_ok=True)
        except OSError as e:
            gs.warning(
                _("Unable to remove file {file}: {message}").format(
                    file=file_path, message=e
                )
            )


def main():
    """Join columns of another table to the attribute table of a vector map.

    The parsed module ``options`` select the vector map, its layer and join
    column as well as the other table, its join column and optional lists of
    columns to use (``subset_columns``) or to leave out (``exclude_columns``).
    Selected columns that do not yet exist in the attribute table of the map
    are added with ``v.db.addcolumn``; afterwards all selected columns are
    filled with one ``UPDATE`` statement per column run through ``db.execute``.

    Column names are compared case-insensitively.

    Returns:
        0 on success. Problems are reported with ``gs.fatal``; if the
        database connection of the layer cannot be read, the script exits
        with status 1.
    """
    # Use the fully qualified name (including the mapset), so the modules
    # called below do not repeat messages about the map being found in more
    # mapsets.
    vector_map = gs.find_file(options["map"], element="vector")["fullname"]
    layer = options["layer"]
    column = options["column"]
    otable = options["other_table"]
    ocolumn = options["other_column"]
    scolumns = None
    if options["subset_columns"]:
        scolumns = options["subset_columns"].split(",")
    ecolumns = None
    if options["exclude_columns"]:
        ecolumns = options["exclude_columns"].split(",")

    try:
        f = gs.vector_layer_db(vector_map, layer)
    except CalledModuleError:
        sys.exit(1)

    maptable = f["table"]
    database = f["database"]
    driver = f["driver"]

    if driver == "dbf":
        gs.fatal(_("JOIN is not supported for tables stored in DBF format"))

    if not maptable:
        gs.fatal(
            _("There is no table connected to this map. Unable to join any column.")
        )

    # This is used for testing presence (and potential name conflict) with
    # the newly added columns, but the test needs to be case-insensitive since
    # it is SQL, so we lowercase the names here and in the tests below.
    # An alternative is quoting identifiers (as in e.g. #3634)
    all_cols_tt = [
        name.lower() for name in gs.vector_columns(vector_map, int(layer))
    ]

    # check if column is in map table
    if column.lower() not in all_cols_tt:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=column, table=maptable
            )
        )

    # describe other table: lowercased column name -> rest of the description
    # (the type, followed by the length when the description provides one)
    all_cols_ot = {
        col_desc[0].lower(): col_desc[1:]
        for col_desc in gs.db_describe(otable, driver=driver, database=database)["cols"]
    }

    # check if ocolumn is on other table
    if ocolumn.lower() not in all_cols_ot:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=ocolumn, table=otable
            )
        )

    # determine columns subset from other table
    if not scolumns:
        # select all columns from other table; work on a copy so that the
        # removals below do not alter the description used for lookups
        cols_to_update = dict(all_cols_ot)
    else:
        cols_to_update = {}
        # check if scolumns exists in the other table
        for scol in scolumns:
            scol_key = scol.strip().lower()
            if scol_key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=scol, table=otable
                    )
                )
            else:
                cols_to_update[scol_key] = all_cols_ot[scol_key]

    # skip the vector column which is used for join
    cols_to_update.pop(column.lower(), None)

    # exclude columns from other table
    if ecolumns:
        for ecol in ecolumns:
            ecol_key = ecol.strip().lower()
            if ecol_key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=ecol, table=otable
                    )
                )
            else:
                # the column may be absent already, e.g. when it is not part
                # of the requested subset or is the join column
                cols_to_update.pop(ecol_key, None)

    cols_to_add = []
    for col_name, col_desc in cols_to_update.items():
        col_type = f"{col_desc[0]}"
        # Sqlite 3 does not support the precision number any more
        use_len = len(col_desc) > 1 and driver != "sqlite"
        # MySQL - expect format DOUBLE PRECISION(M,D), see #2792
        if driver == "mysql" and col_desc[0] == "DOUBLE PRECISION":
            use_len = False

        if use_len:
            col_type = f"{col_desc[0]}({col_desc[1]})"

        # add only the new column to the table
        if col_name not in all_cols_tt:
            cols_to_add.append(f"{col_name} {col_type}")

    if cols_to_add:
        try:
            gs.run_command(
                "v.db.addcolumn",
                map=vector_map,
                columns=",".join(cols_to_add),
                layer=layer,
            )
        except CalledModuleError:
            gs.fatal(
                _("Error creating columns <{}>").format(
                    ", ".join([col.split(" ")[0] for col in cols_to_add])
                )
            )

    # one UPDATE per selected column, wrapped into a single transaction
    update_statements = [
        f"UPDATE {maptable} SET {col} = (SELECT {col} FROM "
        f"{otable} WHERE "
        f"{otable}.{ocolumn}={maptable}.{column});\n"
        for col in cols_to_update
    ]
    update_str = "BEGIN TRANSACTION\n" + "".join(update_statements) + "END TRANSACTION"
    gs.debug(update_str, 1)
    gs.verbose(
        _("Updating columns {columns} of vector map {map_name}...").format(
            columns=", ".join(cols_to_update), map_name=vector_map
        )
    )
    sql_file = Path(gs.tempfile())
    # registered before writing, so cleanup() removes the file in any case
    rm_files.append(sql_file)
    sql_file.write_text(update_str, encoding="UTF8")

    try:
        gs.run_command(
            "db.execute",
            input=str(sql_file),
            database=database,
            driver=driver,
        )
    except CalledModuleError:
        gs.fatal(_("Error filling columns {}").format(", ".join(cols_to_update)))

    # write cmd history
    gs.vector_history(vector_map)

    return 0


if __name__ == "__main__":
    # Parse the command line once; main() reads the module-level options.
    options, flags = gs.parser()
    # Make sure temporary files are removed however the script terminates.
    atexit.register(cleanup)
    sys.exit(main())

0 comments on commit fcb6199

Please sign in to comment.