Skip to content

Commit fcb6199

Browse files
authored
v.db.join: handle existing columns properly (OSGeo#3765)
* handle existing columns properly * refactoring and removal of Python2 compatible code * add a basic testsuite
1 parent 64c8b4a commit fcb6199

File tree

2 files changed

+177
-85
lines changed

2 files changed

+177
-85
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
"""
2+
TEST: test_v_db_join.py
3+
4+
AUTHOR(S): Stefan Blumentrath
5+
6+
PURPOSE: Test for v.db.join
7+
8+
COPYRIGHT: (C) 2024 Stefan Blumentrath, and by the GRASS Development Team
9+
10+
This program is free software under the GNU General Public
11+
License (>=v2). Read the file COPYING that comes with GRASS
12+
for details.
13+
"""
14+
15+
from grass.gunittest.case import TestCase
16+
from grass.gunittest.main import test
17+
from grass.gunittest.gmodules import SimpleModule
18+
19+
20+
class TestVDbJoin(TestCase):
    """Test v.db.join script"""

    # Name of the working copy of the firestations vector map.
    vector_name = "test_firestations"

    @classmethod
    def setUpClass(cls):
        """Copy the vector map and create the two tables joined to it."""
        # Each SQL statement is sent in its own db.execute call so that a
        # failing statement is reported instead of being skipped.
        # NOTE(review): the previously malformed INSERT went unnoticed, which
        # suggests only the first statement of a multi-statement string was
        # executed - confirm against db.execute.
        setup_statements = (
            # Table with columns that do not exist in the vector map yet
            """CREATE TABLE firestation_test_table (
            CITY text,
            some_number int,
            some_text text,
            some_double double precision,
            some_float real
            );""",
            # One value per column; the target is the table created above
            """INSERT INTO firestation_test_table VALUES
            ('Cary', 1, 'short', 1.1233445366756784345, 1.12),
            ('Apex', 2, 'longer', -111.1220390953406936354, -111.12),
            ('Garner', 3, 'short', 4.20529509802443234245, 4.21),
            ('Raleigh', 4, 'even longer than before', 32.913873948295837592, 32.91);
            """,
            # Table used for the second join test
            """CREATE TABLE firestation_test_table_update (
            CITY text,
            others int
            );""",
            # NOTE(review): 'Raleigh' replaces the misspelled 'Relaigh' so the
            # rows can match on the join key - confirm the spelling used in
            # the firestations map.
            """INSERT INTO firestation_test_table_update VALUES
            ('Cary', 1),
            ('Apex', 2),
            ('Garner', 3),
            ('Raleigh', 4);
            """,
        )
        cls.runModule("g.copy", vector=["firestations", cls.vector_name])
        for statement in setup_statements:
            cls.runModule("db.execute", sql=statement)

    @classmethod
    def tearDownClass(cls):
        """Remove copied vector data and created tables"""
        cls.runModule("g.remove", type="vector", name=cls.vector_name, flags="f")
        cls.runModule("db.execute", sql="DROP TABLE firestation_test_table;")
        cls.runModule("db.execute", sql="DROP TABLE firestation_test_table_update;")

    def test_join_firestations_table(self):
        """Join firestations table with new different columns"""
        module = SimpleModule(
            "v.db.join",
            map=self.vector_name,
            column="CITY",
            other_table="firestation_test_table",
            other_column="CITY",
        )
        self.assertModule(module)

    def test_join_firestations_table_existing(self):
        """Join firestations table with the firestation_test_table_update table"""
        module = SimpleModule(
            "v.db.join",
            map=self.vector_name,
            column="CITY",
            other_table="firestation_test_table_update",
            other_column="CITY",
        )
        self.assertModule(module)
81+
82+
83+
# Run the tests of this file when it is executed directly.
if __name__ == "__main__":
    test()

scripts/v.db.join/v.db.join.py

Lines changed: 93 additions & 85 deletions
Original file line numberDiff line numberDiff line change
@@ -64,80 +64,87 @@
6464
# %end
6565

6666
import atexit
67-
import os
6867
import sys
69-
import grass.script as grass
68+
69+
from pathlib import Path
70+
71+
import grass.script as gs
7072
from grass.exceptions import CalledModuleError
7173

7274
rm_files = []
7375

7476

7577
def cleanup():
    """Remove the temporary files registered in ``rm_files``.

    Registered with ``atexit`` in the script entry point. A file that no
    longer exists is skipped silently; any other removal problem is reported
    as a warning so that the remaining files are still processed.
    """
    for file_path in rm_files:
        try:
            # Path() accepts both Path objects and plain string paths
            Path(file_path).unlink(missing_ok=True)
        except OSError as e:
            gs.warning(
                _("Unable to remove file {file}: {message}").format(
                    file=file_path, message=e
                )
            )
8687

8788

8889
def main():
    """Join columns of another table to the attribute table of a vector map.

    Reads the module parameters from the global ``options`` dictionary, adds
    the selected columns that are missing in the attribute table of the map
    and fills all selected columns with SQL UPDATE statements that are
    written to a temporary file and run with db.execute.

    Column names are compared case-insensitively (lowercased), since they
    are SQL identifiers.

    Returns:
        0 on success. Problems are reported through ``gs.fatal``; a failing
        lookup of the database connection exits with status 1.
    """
    # Use the fully qualified name (including mapset), so we avoid multiple
    # messages about the map being found in more mapsets.
    vector_map = gs.find_file(options["map"], element="vector")["fullname"]
    if not vector_map:
        gs.fatal(_("Vector map <{}> not found").format(options["map"]))

    layer = options["layer"]
    column = options["column"]
    otable = options["other_table"]
    ocolumn = options["other_column"]
    scolumns = None
    if options["subset_columns"]:
        scolumns = options["subset_columns"].split(",")
    ecolumns = None
    if options["exclude_columns"]:
        ecolumns = options["exclude_columns"].split(",")

    try:
        f = gs.vector_layer_db(vector_map, layer)
    except CalledModuleError:
        sys.exit(1)

    maptable = f["table"]
    database = f["database"]
    driver = f["driver"]

    if driver == "dbf":
        gs.fatal(_("JOIN is not supported for tables stored in DBF format"))

    if not maptable:
        gs.fatal(
            _("There is no table connected to this map. Unable to join any column.")
        )

    # This is used for testing presence (and potential name conflict) with
    # the newly added columns, but the test needs to be case-insensitive since
    # it is SQL, so we lowercase the names here and in the tests below.
    # An alternative is quoting identifiers (as in e.g. #3634)
    all_cols_tt = {name.lower() for name in gs.vector_columns(vector_map, int(layer))}

    # check if column is in map table
    if column.lower() not in all_cols_tt:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=column, table=maptable
            )
        )

    # describe other table: lowercased column name -> [type, length]
    all_cols_ot = {
        col_desc[0].lower(): col_desc[1:]
        for col_desc in gs.db_describe(otable, driver=driver, database=database)["cols"]
    }

    # check if ocolumn is on other table
    if ocolumn.lower() not in all_cols_ot:
        gs.fatal(
            _("Column <{column}> not found in table <{table}>").format(
                column=ocolumn, table=otable
            )
        )

    # determine columns subset from other table
    if not scolumns:
        # select all columns from other table; copy, so that removing entries
        # below does not alter the description of the other table
        cols_to_update = dict(all_cols_ot)
    else:
        cols_to_update = {}
        # check if scolumns exists in the other table
        for scol in scolumns:
            scol_key = scol.lower()
            if scol_key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=scol, table=otable
                    )
                )
            else:
                cols_to_update[scol_key] = all_cols_ot[scol_key]

    # skip the vector column which is used for join
    cols_to_update.pop(column.lower(), None)

    # exclude columns from other table
    if ecolumns:
        for ecol in ecolumns:
            ecol_key = ecol.lower()
            if ecol_key not in all_cols_ot:
                gs.warning(
                    _("Column <{column}> not found in table <{table}>").format(
                        column=ecol, table=otable
                    )
                )
            else:
                # the column may already be absent (not part of the subset or
                # used for the join), which is not an error
                cols_to_update.pop(ecol_key, None)

    # add only the new columns to the table
    cols_to_add = []
    for col_name, col_desc in cols_to_update.items():
        if col_name in all_cols_tt:
            continue
        # col_desc is [type, length]
        # NOTE(review): the length is applied again for drivers other than
        # SQLite, as before the refactoring - confirm with those drivers.
        use_len = (
            len(col_desc) > 1
            # Sqlite 3 does not support the precision number any more
            and driver != "sqlite"
            # MySQL - expect format DOUBLE PRECISION(M,D), see #2792
            and not (driver == "mysql" and col_desc[0] == "DOUBLE PRECISION")
        )
        if use_len:
            col_type = f"{col_desc[0]}({col_desc[1]})"
        else:
            col_type = f"{col_desc[0]}"
        cols_to_add.append(f"{col_name} {col_type}")

    if cols_to_add:
        try:
            gs.run_command(
                "v.db.addcolumn",
                map=vector_map,
                columns=",".join(cols_to_add),
                layer=layer,
            )
        except CalledModuleError:
            gs.fatal(
                _("Error creating columns <{}>").format(
                    ", ".join([col.split(" ")[0] for col in cols_to_add])
                )
            )

    # one UPDATE statement per selected column (new and existing ones)
    update_statements = [
        f"UPDATE {maptable} SET {col} = (SELECT {col} FROM "
        f"{otable} WHERE "
        f"{otable}.{ocolumn}={maptable}.{column});\n"
        for col in cols_to_update
    ]
    update_str = "BEGIN TRANSACTION\n" + "".join(update_statements) + "END TRANSACTION"
    updated_cols_str = ", ".join(cols_to_update)

    gs.debug(update_str, 1)
    gs.verbose(
        _("Updating columns {columns} of vector map {map_name}...").format(
            columns=updated_cols_str, map_name=vector_map
        )
    )
    sql_file = Path(gs.tempfile())
    # registered for removal by cleanup()
    rm_files.append(sql_file)
    sql_file.write_text(update_str, encoding="UTF8")

    try:
        gs.run_command(
            "db.execute",
            input=str(sql_file),
            database=database,
            driver=driver,
        )
    except CalledModuleError:
        gs.fatal(_("Error filling columns {}").format(updated_cols_str))

    # write cmd history
    gs.vector_history(vector_map)

    return 0
246254

247255

248256
if __name__ == "__main__":
    # Parse the command line into the module-level option and flag
    # dictionaries read by main()
    options, flags = gs.parser()
    # Make sure temporary files are removed when the script exits
    atexit.register(cleanup)
    sys.exit(main())

0 commit comments

Comments
 (0)