diff --git a/src/connection-int.c b/src/connection-int.c index b3df0db..57d3dfa 100644 --- a/src/connection-int.c +++ b/src/connection-int.c @@ -36,7 +36,7 @@ void connection_handle_error(ConnectionObject *conn) { } int connection_run_without_results(ConnectionObject *conn, const char *query) { - int status = mg_session_run(conn->session, query, NULL, NULL, NULL, NULL); + int status = mg_session_run(conn->session, query, NULL, conn->extras, NULL, NULL); if (status != 0) { connection_handle_error(conn); return -1; @@ -87,7 +87,7 @@ int connection_run(ConnectionObject *conn, const char *query, PyObject *params, const mg_list *mg_columns; int status = - mg_session_run(conn->session, query, mg_params, NULL, &mg_columns, NULL); + mg_session_run(conn->session, query, mg_params, conn->extras, &mg_columns, NULL); mg_map_destroy(mg_params); if (status != 0) { diff --git a/src/connection.c b/src/connection.c index 576f54f..f0546e3 100644 --- a/src/connection.c +++ b/src/connection.c @@ -21,6 +21,7 @@ static void connection_dealloc(ConnectionObject *conn) { mg_session_destroy(conn->session); + mg_map_destroy(conn->extras); Py_TYPE(conn)->tp_free(conn); } @@ -37,11 +38,36 @@ static int execute_trust_callback(const char *hostname, const char *ip_address, return !status; } +static mg_map *database_to_extras(const char *database) { + assert(databases); + + mg_map *map = NULL; + + map = mg_map_make_empty(1U); + if (!map) { + PyErr_SetString(PyExc_RuntimeError, "failed to create a mg_map"); + goto cleanup; + } + + mg_string* key = mg_string_make("db"); + mg_value* value = mg_value_make_string(database); + + if (mg_map_insert_unsafe2(map, key, value) != 0) { + mg_string_destroy(key); + abort(); + } + return map; + +cleanup: + mg_map_destroy(map); + return NULL; +} + static int connection_init(ConnectionObject *conn, PyObject *args, PyObject *kwargs) { static char *kwlist[] = {"host", "address", "port", "username", "password", "client_name", "sslmode", "sslcert", - "sslkey", "trust_callback", "lazy", NULL}; + "sslkey", "trust_callback", "lazy", "database", NULL}; const char *host = NULL; const char *address = NULL; @@ -54,11 +80,12 @@ static int connection_init(ConnectionObject *conn, PyObject *args, const char *sslkey = NULL; PyObject *trust_callback = NULL; int lazy = 0; - - if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$ssisssissOp", kwlist, &host, + const char *database = NULL; + + if (!PyArg_ParseTupleAndKeywords(args, kwargs, "|$ssisssissOps", kwlist, &host, &address, &port, &username, &password, &client_name, &sslmode_int, &sslcert, - &sslkey, &trust_callback, &lazy)) { + &sslkey, &trust_callback, &lazy, &database)) { return -1; } @@ -124,12 +151,17 @@ static int connection_init(ConnectionObject *conn, PyObject *args, conn->status = CONN_STATUS_READY; conn->lazy = 0; conn->autocommit = 0; + conn->extras = NULL; if (lazy) { conn->lazy = 1; conn->autocommit = 1; } + if (database) { + conn->extras = database_to_extras(database); + } + return 0; } @@ -180,6 +212,8 @@ static PyObject *connection_close(ConnectionObject *conn, PyObject *args) { // rollback any open transactions. mg_session_destroy(conn->session); conn->session = NULL; + mg_map_destroy(conn->extras); + conn->extras = NULL; conn->status = CONN_STATUS_CLOSED; Py_RETURN_NONE; diff --git a/src/connection.h b/src/connection.h index c543a8d..b633260 100644 --- a/src/connection.h +++ b/src/connection.h @@ -35,6 +35,7 @@ typedef struct ConnectionObject { int status; int autocommit; int lazy; + mg_map *extras; } ConnectionObject; // clang-format on diff --git a/src/mgclientmodule.c b/src/mgclientmodule.c index a750d4e..77e5153 100644 --- a/src/mgclientmodule.c +++ b/src/mgclientmodule.c @@ -197,7 +197,7 @@ static PyObject *mgclient_connect(PyObject *self, PyObject *args, PyDoc_STRVAR(mgclient_connect_doc, "connect(host=None, address=None, port=None, username=None, password=None,\n\ client_name=None, sslmode=mgclient.MG_SSLMODE_DISABLE,\n\ - sslcert=None, sslkey=None, trust_callback=None, lazy=False)\n\ + sslcert=None, sslkey=None, trust_callback=None, lazy=False, database=None)\n\ --\n\ \n\ Makes a new connection to the database server and returns a\n\ @@ -271,7 +271,11 @@ Currently recognized parameters are:\n\ \n\ * :obj:`lazy`\n\ \n\ - If this is set to ``True``, a lazy connection is made. Default is ``False``."); + If this is set to ``True``, a lazy connection is made. Default is ``False``.\n\ +\n\ + * :obj:`database`\n\ +\n\ + If set, all queries executed will target the defined database. Default is ``None``."); // clang-format on static PyMethodDef mgclient_methods[] = { diff --git a/test/test_multi_tenancy.py b/test/test_multi_tenancy.py new file mode 100644 index 0000000..617255a --- /dev/null +++ b/test/test_multi_tenancy.py @@ -0,0 +1,148 @@ +# Copyright (c) 2016-2020 Memgraph Ltd. [https://memgraph.com] +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import mgclient +import pytest +import tempfile + +from common import start_memgraph, MEMGRAPH_PORT + + +def assert_db(cursor, db_name): + cursor.execute("SHOW DATABASE") + assert cursor.fetchall() == [(db_name, )] + +def assert_data(cursor, db): + cursor.execute('MATCH (n:Node) RETURN n.db') + cursor.fetchall() == [(db,)] + + +@pytest.fixture(scope="function") +def memgraph_server(): + # memgraph = start_memgraph() + yield "127.0.0.1", MEMGRAPH_PORT + + # memgraph.kill() + # memgraph.wait() + +# def test_connect_database_fail(memgraph_server): +# host, port = memgraph_server +# # Connected to a non existent database +# conn = mgclient.connect( +# host=host, +# port=port, +# lazy=True, +# database="does not exist") +# cursor = conn.cursor() +# with pytest.raises(mgclient.DatabaseError): +# cursor.execute("MATCH(n) RETURN n;") + +# def test_connect_database(memgraph_server): +# host, port = memgraph_server +# conn = mgclient.connect(host=host, port=port, lazy=True) +# cursor = conn.cursor() + +# #setup +# assert_db(cursor, "memgraph") +# cursor.execute('CREATE (:Node{db:"memgraph"})') +# cursor.fetchall() + +# cursor.execute("CREATE DATABASE db1") +# cursor.fetchall() +# cursor.execute("USE DATABASE db1") +# cursor.fetchall() +# assert_db(cursor, "db1") +# cursor.execute('CREATE (:Node{db:"db1"})') +# cursor.fetchall() + +# cursor.execute("CREATE DATABASE db2") +# cursor.fetchall() +# cursor.execute("USE DATABASE db2") +# cursor.fetchall() +# assert_db(cursor, "db2") +# cursor.execute('CREATE (:Node{db:"db2"})') +# cursor.fetchall() + +# #connection tests +# #default +# conn = mgclient.connect(host=host, port=port, lazy=True) +# cursor = conn.cursor() +# assert_db(cursor, "memgraph") +# assert_data(cursor, "memgraph") + +# #memgraph +# conn = mgclient.connect(host=host, port=port, lazy=True, database="memgraph") +# cursor = conn.cursor() +# assert_db(cursor, "memgraph") +# assert_data(cursor, "memgraph") + +# #db1 +# conn = mgclient.connect(host=host, port=port, lazy=True, database="db1") +# cursor = conn.cursor() +# assert_db(cursor, "db1") +# assert_data(cursor, "db1") + +# #db2 +# conn = mgclient.connect(host=host, port=port, lazy=True, database="db2") +# cursor = conn.cursor() +# assert_db(cursor, "db2") +# assert_data(cursor, "db2") + +def test_connect_database_and_block(memgraph_server): + host, port = memgraph_server + conn = mgclient.connect(host=host, port=port, lazy=True) + cursor = conn.cursor() + + #setup + assert_db(cursor, "memgraph") + cursor.execute("CREATE DATABASE db1") + cursor.fetchall() + cursor.execute("CREATE DATABASE db2") + cursor.fetchall() + + #connection tests + #default <- should allow db switching + conn = mgclient.connect(host=host, port=port, lazy=True) + cursor = conn.cursor() + assert_db(cursor, "memgraph") + cursor.execute("USE DATABASE db1;") + cursor.fetchall() + assert_db(cursor, "db1") + cursor.execute("USE DATABASE db2;") + cursor.fetchall() + assert_db(cursor, "db2") + + #memgraph + conn = mgclient.connect(host=host, port=port, lazy=True, database="memgraph") + cursor = conn.cursor() + assert_db(cursor, "memgraph") + with pytest.raises(mgclient.DatabaseError): + cursor.execute("USE DATABASE db2;") + print(cursor.fetchall()) + + #db1 + conn = mgclient.connect(host=host, port=port, lazy=True, database="db1") + cursor = conn.cursor() + assert_db(cursor, "db1") + with pytest.raises(mgclient.DatabaseError): + cursor.execute("USE DATABASE db2;") + cursor.fetchall() + + #db2 + conn = mgclient.connect(host=host, port=port, lazy=True, database="db2") + cursor = conn.cursor() + assert_db(cursor, "db2") + with pytest.raises(mgclient.DatabaseError): + cursor.execute("USE DATABASE memgraph;") + cursor.fetchall()