Skip to content

Commit

Permalink
feat: Neo4j 4.x support (#1942)
Browse files Browse the repository at this point in the history
* added more configuration to support newer neo4j

Signed-off-by: Allison Suarez Miranda <[email protected]>

* added condition to all drivers and added db option

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fix driver

Signed-off-by: Allison Suarez Miranda <[email protected]>

* implemented feedback

Signed-off-by: Allison Suarez Miranda <[email protected]>

* implemented feedback

Signed-off-by: Allison Suarez Miranda <[email protected]>

* Implemented more feedback

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fix imports

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fixed unit tests

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fixing neo4j config

Signed-off-by: Allison Suarez Miranda <[email protected]>

* property patch

Signed-off-by: Allison Suarez Miranda <[email protected]>

* struggling with how to patch the driver creatioon method

* fixed patching

Signed-off-by: Allison Suarez Miranda <[email protected]>

* removed unused imports

Signed-off-by: Allison Suarez Miranda <[email protected]>

* typiong fix

Signed-off-by: Allison Suarez Miranda <[email protected]>

* missing an any in tuple

Signed-off-by: Allison Suarez Miranda <[email protected]>

* sort imports

Signed-off-by: Allison Suarez Miranda <[email protected]>

* check URI scheme security and set default trust and encrypted values accordingly

Signed-off-by: Allison Suarez Miranda <[email protected]>

* self.conf needed in neo4j extractor

Signed-off-by: Allison Suarez Miranda <[email protected]>

* updating unit tests

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fix driver

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fixed uri in neo4j search data extractor test

Signed-off-by: Allison Suarez Miranda <[email protected]>

* fix improts and lint

Signed-off-by: Allison Suarez Miranda <[email protected]>
  • Loading branch information
allisonsuarez authored Jul 27, 2022
1 parent c3dab36 commit e97b74d
Show file tree
Hide file tree
Showing 8 changed files with 170 additions and 83 deletions.
63 changes: 43 additions & 20 deletions databuilder/databuilder/extractor/neo4j_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@

import neo4j
from neo4j import GraphDatabase
from neo4j.api import (
SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri,
)
from pyhocon import ConfigFactory, ConfigTree

from databuilder.extractor.base_extractor import Extractor
Expand All @@ -25,15 +28,19 @@ class Neo4jExtractor(Extractor):
MODEL_CLASS_CONFIG_KEY = 'model_class'
NEO4J_AUTH_USER = 'neo4j_auth_user'
NEO4J_AUTH_PW = 'neo4j_auth_pw'
# in Neo4j (v4.0+), we can create and use more than one active database at the same time
NEO4J_DATABASE_NAME = 'neo4j_database'
NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec'
NEO4J_ENCRYPTED = 'neo4j_encrypted'
"""NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting."""
NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
"""NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs."""
NEO4J_DRIVER = 'neo4j_driver'

DEFAULT_CONFIG = ConfigFactory.from_dict({NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
NEO4J_ENCRYPTED: True,
NEO4J_VALIDATE_SSL: False})
DEFAULT_CONFIG = ConfigFactory.from_dict({
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE
})

def init(self, conf: ConfigTree) -> None:
"""
Expand All @@ -43,8 +50,36 @@ def init(self, conf: ConfigTree) -> None:
self.conf = conf.with_fallback(Neo4jExtractor.DEFAULT_CONFIG)
self.graph_url = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
self.cypher_query = conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY)
self.driver = self._get_driver()

self.db_name = self.conf.get_string(Neo4jExtractor.NEO4J_DATABASE_NAME)
driver = conf.get(Neo4jExtractor.NEO4J_DRIVER, None)
if driver:
self.driver = driver
else:
uri = conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': self.conf.get_int(Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER),
conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)),
}

# if URI scheme not secure set `trust`` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(Neo4jExtractor.NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(Neo4jExtractor.NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self.driver = GraphDatabase.driver(**driver_args)
self._extract_iter: Union[None, Iterator] = None

model_class = conf.get(Neo4jExtractor.MODEL_CLASS_CONFIG_KEY, None)
Expand All @@ -62,20 +97,6 @@ def close(self) -> None:
except Exception as e:
LOGGER.error("Exception encountered while closing the graph driver", e)

def _get_driver(self) -> Any:
"""
Create a Neo4j connection to Database
"""
trust = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if self.conf.get_bool(Neo4jExtractor.NEO4J_VALIDATE_SSL) \
else neo4j.TRUST_ALL_CERTIFICATES
return GraphDatabase.driver(uri=self.graph_url,
max_connection_lifetime=self.conf.get_int(
Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC),
auth=(self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_USER),
self.conf.get_string(Neo4jExtractor.NEO4J_AUTH_PW)),
encrypted=self.conf.get_bool(Neo4jExtractor.NEO4J_ENCRYPTED),
trust=trust)

def _execute_query(self, tx: Any) -> Any:
"""
Create an iterator to execute sql.
Expand All @@ -88,7 +109,9 @@ def _get_extract_iter(self) -> Iterator[Any]:
"""
Execute {cypher_query} and yield result one at a time
"""
with self.driver.session() as session:
with self.driver.session(
database=self.db_name
) as session:
if not hasattr(self, 'results'):
self.results = session.read_transaction(self._execute_query)

Expand Down
55 changes: 42 additions & 13 deletions databuilder/databuilder/publisher/neo4j_csv_publisher.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@
import pandas
from jinja2 import Template
from neo4j import GraphDatabase, Transaction
from neo4j.api import (
SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri,
)
from neo4j.exceptions import Neo4jError, TransientError
from pyhocon import ConfigFactory, ConfigTree

Expand Down Expand Up @@ -51,13 +54,17 @@

NEO4J_USER = 'neo4j_user'
NEO4J_PASSWORD = 'neo4j_password'
# in Neo4j (v4.0+), we can create and use more than one active database at the same time
NEO4J_DATABASE_NAME = 'neo4j_database'

NEO4J_DRIVER = 'neo4j_driver'

# NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting
NEO4J_ENCRYPTED = 'neo4j_encrypted'
# NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS
# cert against system CAs
NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'


# This will be used to provide unique tag to the node and relationship
JOB_PUBLISH_TAG = 'job_publish_tag'

Expand Down Expand Up @@ -109,8 +116,7 @@
NEO4J_PROGRESS_REPORT_FREQUENCY: 500,
NEO4J_RELATIONSHIP_CREATION_CONFIRM: False,
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
NEO4J_ENCRYPTED: True,
NEO4J_VALIDATE_SSL: False,
NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE,
ADDITIONAL_FIELDS: {},
ADD_PUBLISHER_METADATA: True,
RELATION_PREPROCESSOR: NoopRelationPreprocessor()})
Expand Down Expand Up @@ -148,16 +154,39 @@ def init(self, conf: ConfigTree) -> None:
self._relation_files = self._list_files(conf, RELATION_FILES_DIR)
self._relation_files_iter = iter(self._relation_files)

trust = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if conf.get_bool(NEO4J_VALIDATE_SSL) \
else neo4j.TRUST_ALL_CERTIFICATES
self._driver = \
GraphDatabase.driver(uri=conf.get_string(NEO4J_END_POINT_KEY),
max_connection_lifetime=conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
auth=(conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
encrypted=conf.get_bool(NEO4J_ENCRYPTED),
trust=trust)
driver = conf.get(NEO4J_DRIVER, None)
if driver:
self._driver = driver
else:
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if URI scheme not secure set `trust`` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)

self._db_name = conf.get_string(NEO4J_DATABASE_NAME)
self._session = self._driver.session(database=self._db_name)

self._transaction_size = conf.get_int(NEO4J_TRANSACTION_SIZE)
self._session = self._driver.session()
self._confirm_rel_created = conf.get_bool(NEO4J_RELATIONSHIP_CREATION_CONFIRM)

# config is list of node label.
Expand Down Expand Up @@ -488,7 +517,7 @@ def _try_create_index(self, label: str) -> None:
""").render(LABEL=label)

LOGGER.info(f'Trying to create index for label {label} if not exist: {stmt}')
with self._driver.session() as session:
with self._driver.session(self._db_name) as session:
try:
session.run(stmt)
except Neo4jError as e:
Expand Down
50 changes: 38 additions & 12 deletions databuilder/databuilder/task/neo4j_staleness_removal_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,9 @@

import neo4j
from neo4j import GraphDatabase
from neo4j.api import (
SECURITY_TYPE_SECURE, SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, parse_neo4j_uri,
)
from pyhocon import ConfigFactory, ConfigTree

from databuilder import Scoped
Expand All @@ -21,11 +24,13 @@
NEO4J_MAX_CONN_LIFE_TIME_SEC = 'neo4j_max_conn_life_time_sec'
NEO4J_USER = 'neo4j_user'
NEO4J_PASSWORD = 'neo4j_password'
# in Neo4j (v4.0+), we can create and use more than one active database at the same time
NEO4J_DATABASE_NAME = 'neo4j_database'
NEO4J_DRIVER = 'neo4j_driver'
NEO4J_ENCRYPTED = 'neo4j_encrypted'
"""NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting."""
NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl'
"""NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs."""

TARGET_NODES = "target_nodes"
TARGET_RELATIONS = "target_relations"
BATCH_SIZE = "batch_size"
Expand All @@ -41,8 +46,7 @@

DEFAULT_CONFIG = ConfigFactory.from_dict({BATCH_SIZE: 100,
NEO4J_MAX_CONN_LIFE_TIME_SEC: 50,
NEO4J_ENCRYPTED: True,
NEO4J_VALIDATE_SSL: False,
NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE,
STALENESS_MAX_PCT: 5,
TARGET_NODES: [],
TARGET_RELATIONS: [],
Expand Down Expand Up @@ -127,14 +131,36 @@ def init(self, conf: ConfigTree) -> None:
else:
self.marker = conf.get_string(JOB_PUBLISH_TAG)

trust = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if conf.get_bool(NEO4J_VALIDATE_SSL) \
else neo4j.TRUST_ALL_CERTIFICATES
self._driver = \
GraphDatabase.driver(uri=conf.get_string(NEO4J_END_POINT_KEY),
max_connection_lifetime=conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
auth=(conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
encrypted=conf.get_bool(NEO4J_ENCRYPTED),
trust=trust)
driver = conf.get(NEO4J_DRIVER, None)
if driver:
self._driver = driver
else:
uri = conf.get_string(NEO4J_END_POINT_KEY)
driver_args = {
'uri': uri,
'max_connection_lifetime': conf.get_int(NEO4J_MAX_CONN_LIFE_TIME_SEC),
'auth': (conf.get_string(NEO4J_USER), conf.get_string(NEO4J_PASSWORD)),
}

# if URI scheme not secure set `trust`` and `encrypted` to default values
# https://neo4j.com/docs/api/python-driver/current/api.html#uri
_, security_type, _ = parse_neo4j_uri(uri=uri)
if security_type not in [SECURITY_TYPE_SELF_SIGNED_CERTIFICATE, SECURITY_TYPE_SECURE]:
default_security_conf = {'trust': neo4j.TRUST_ALL_CERTIFICATES, 'encrypted': True}
driver_args.update(default_security_conf)

# if NEO4J_VALIDATE_SSL or NEO4J_ENCRYPTED are set in config pass them to the driver
validate_ssl_conf = conf.get(NEO4J_VALIDATE_SSL, None)
encrypted_conf = conf.get(NEO4J_ENCRYPTED, None)
if validate_ssl_conf is not None:
driver_args['trust'] = neo4j.TRUST_SYSTEM_CA_SIGNED_CERTIFICATES if validate_ssl_conf \
else neo4j.TRUST_ALL_CERTIFICATES
if encrypted_conf is not None:
driver_args['encrypted'] = encrypted_conf

self._driver = GraphDatabase.driver(**driver_args)

self.db_name = conf.get(NEO4J_DATABASE_NAME)

def run(self) -> None:
"""
Expand Down Expand Up @@ -304,7 +330,7 @@ def _execute_cypher_query(self,

start = time.time()
try:
with self._driver.session() as session:
with self._driver.session(database=self.db_name) as session:
result = session.run(statement, **param_dict)
return [record for record in result]

Expand Down
2 changes: 1 addition & 1 deletion databuilder/setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from setuptools import find_packages, setup

__version__ = '7.0.0'
__version__ = '7.1.0'

requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)),
'requirements.txt')
Expand Down
17 changes: 10 additions & 7 deletions databuilder/tests/unit/extractor/test_neo4j_extractor.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
from typing import Any

from mock import patch
from neo4j import GraphDatabase
from pyhocon import ConfigFactory

from databuilder import Scoped
Expand All @@ -16,10 +17,11 @@ class TestNeo4jExtractor(unittest.TestCase):

def setUp(self) -> None:
config_dict = {
f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'TEST_GRAPH_URL',
f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'bolt://example.com:7687',
f'extractor.neo4j.{Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY}': 'TEST_QUERY',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'TEST_USER',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'TEST_PW'
f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'TEST_PW',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50,
}

self.conf = ConfigFactory.from_dict(config_dict)
Expand All @@ -28,7 +30,7 @@ def text_extraction_with_empty_query_result(self: Any) -> None:
"""
Test Extraction with empty results from query
"""
with patch.object(Neo4jExtractor, '_get_driver'):
with patch.object(GraphDatabase, 'driver'):
extractor = Neo4jExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))
Expand All @@ -41,7 +43,7 @@ def test_extraction_with_single_query_result(self: Any) -> None:
"""
Test Extraction with single result from query
"""
with patch.object(Neo4jExtractor, '_get_driver'):
with patch.object(GraphDatabase, 'driver'):
extractor = Neo4jExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))
Expand All @@ -58,7 +60,7 @@ def test_extraction_with_multiple_query_result(self: Any) -> None:
"""
Test Extraction with multiple result from query
"""
with patch.object(Neo4jExtractor, '_get_driver'):
with patch.object(GraphDatabase, 'driver'):
extractor = Neo4jExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))
Expand All @@ -83,17 +85,18 @@ def test_extraction_with_model_class(self: Any) -> None:
Test Extraction using model class
"""
config_dict = {
f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'TEST_GRAPH_URL',
f'extractor.neo4j.{Neo4jExtractor.GRAPH_URL_CONFIG_KEY}': 'bolt://example.com:7687',
f'extractor.neo4j.{Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY}': 'TEST_QUERY',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_USER}': 'TEST_USER',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_AUTH_PW}': 'TEST_PW',
f'extractor.neo4j.{Neo4jExtractor.NEO4J_MAX_CONN_LIFE_TIME_SEC}': 50,
f'extractor.neo4j.{Neo4jExtractor.MODEL_CLASS_CONFIG_KEY}':
'databuilder.models.table_elasticsearch_document.TableESDocument'
}

self.conf = ConfigFactory.from_dict(config_dict)

with patch.object(Neo4jExtractor, '_get_driver'):
with patch.object(GraphDatabase, 'driver'):
extractor = Neo4jExtractor()
extractor.init(Scoped.get_scoped_conf(conf=self.conf,
scope=extractor.get_scope()))
Expand Down
Loading

0 comments on commit e97b74d

Please sign in to comment.