diff --git a/databuilder/databuilder/extractor/neo4j_extractor.py b/databuilder/databuilder/extractor/neo4j_extractor.py index 8d0cfb186e..629155df3b 100644 --- a/databuilder/databuilder/extractor/neo4j_extractor.py +++ b/databuilder/databuilder/extractor/neo4j_extractor.py @@ -35,10 +35,14 @@ class Neo4jExtractor(Extractor): """NEO4J_ENCRYPTED is a boolean indicating whether to use SSL/TLS when connecting.""" NEO4J_VALIDATE_SSL = 'neo4j_validate_ssl' """NEO4J_VALIDATE_SSL is a boolean indicating whether to validate the server's SSL/TLS cert against system CAs.""" + NEO4J_USE_IMPLICIT_TRANSACTIONS = 'neo4j_use_implicit_transactions' + """NEO4J_USE_IMPLICIT_TRANSACTIONS is a boolean indicating whether to use implicit or explicit transactions. This + is only needed when implicit transactions are required, such as for CALL {} IN TRANSACTIONS queries.""" DEFAULT_CONFIG = ConfigFactory.from_dict({ NEO4J_MAX_CONN_LIFE_TIME_SEC: 50, NEO4J_DATABASE_NAME: neo4j.DEFAULT_DATABASE, + NEO4J_USE_IMPLICIT_TRANSACTIONS: False, }) def init(self, conf: ConfigTree) -> None: @@ -50,6 +54,7 @@ def init(self, conf: ConfigTree) -> None: self.graph_url = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) self.cypher_query = self.conf.get_string(Neo4jExtractor.CYPHER_QUERY_CONFIG_KEY) self.db_name = self.conf.get_string(Neo4jExtractor.NEO4J_DATABASE_NAME) + self.use_implicit_transactions = self.conf.get(Neo4jExtractor.NEO4J_USE_IMPLICIT_TRANSACTIONS) uri = self.conf.get_string(Neo4jExtractor.GRAPH_URL_CONFIG_KEY) driver_args = { @@ -107,10 +112,15 @@ def _get_extract_iter(self) -> Iterator[Any]: Execute {cypher_query} and yield result one at a time """ with self.driver.session( - database=self.db_name + database=self.db_name, + default_access_mode=neo4j.READ_ACCESS ) as session: if not hasattr(self, 'results'): - self.results = session.read_transaction(self._execute_query) + if not self.use_implicit_transactions: + self.results = session.read_transaction(self._execute_query) + else: + LOGGER.info('Executing query in implicit transaction %s', self.cypher_query) + self.results = session.run(self.cypher_query).data() for result in self.results: if hasattr(self, 'model_class'): diff --git a/databuilder/setup.py b/databuilder/setup.py index d9ba34d820..9af728465f 100644 --- a/databuilder/setup.py +++ b/databuilder/setup.py @@ -5,7 +5,7 @@ from setuptools import find_packages, setup -__version__ = '7.5.0' +__version__ = '7.5.1' requirements_path = os.path.join(os.path.dirname(os.path.realpath(__file__)), 'requirements.txt')