Skip to content

Commit 9c71846

Browse files
save
1 parent ad97e14 commit 9c71846

File tree

4 files changed

+296
-285
lines changed

4 files changed

+296
-285
lines changed

tests/data/baseline/ml/fastRP.json.gz

0 Bytes
Binary file not shown.

tests/run.sh

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,3 +2,4 @@ clear
22
python3 test/setup.py &&
33
python3 test/baseline/create_baselines.py &&
44
pytest test/test_centrality.py #test/test_ml.py
5+
echo

tests/test/setup.py

Lines changed: 33 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import re
44
import time
5+
from glob import glob
56

67
import pyTigerGraph as tg
78
from dotenv import load_dotenv
@@ -15,21 +16,26 @@
1516
pattern = re.compile(r'"name":\s*"tg_.*"')
1617

1718

18-
def add_reverse_edge(ds: Datasets):
19-
with open(f"{dataset.tmp_dir}/{ds.name}/create_schema.gsql") as f:
20-
schema: str = f.read()
21-
with open(f"{dataset.tmp_dir}/{ds.name}/create_schema.gsql", "w") as f:
22-
schema = schema.replace(
23-
"ADD DIRECTED EDGE Cite (from Paper, to Paper, time Int, is_train Bool, is_val Bool);",
24-
'ADD DIRECTED EDGE Cite (from Paper, to Paper, time Int, is_train Bool, is_val Bool) WITH REVERSE_EDGE="reverse_Cite";',
25-
)
26-
f.write(schema)
19+
# def add_reverse_edge(ds: Datasets):
20+
# with open(f"{dataset.tmp_dir}/{ds.name}/create_schema.gsql") as f:
21+
# schema: str = f.read()
22+
# with open(f"{dataset.tmp_dir}/{ds.name}/create_schema.gsql", "w") as f:
23+
# schema = schema.replace(
24+
# "ADD DIRECTED EDGE Cite (from Paper, to Paper, time Int, is_train Bool, is_val Bool);",
25+
# 'ADD DIRECTED EDGE Cite (from Paper, to Paper, time Int, is_train Bool, is_val Bool) WITH REVERSE_EDGE="reverse_Cite";',
26+
# )
27+
# f.write(schema)
28+
#
29+
#
30+
def get_query_path(q_name):
31+
pth = glob(f"../algorithms/**/{q_name}.gsql", recursive=True)
32+
return pth[0]
2733

2834

2935
if __name__ == "__main__":
30-
host_name = os.getenv("HOST_NAME")
31-
user_name = os.getenv("USER_NAME")
32-
password = os.getenv("PASS")
36+
host_name = os.environ["HOST_NAME"]
37+
user_name = os.environ["USER_NAME"]
38+
password = os.environ["PASS"]
3339
conn = tg.TigerGraphConnection(
3440
host=host_name,
3541
username=user_name,
@@ -42,25 +48,29 @@ def add_reverse_edge(ds: Datasets):
4248
if res["error"]:
4349
exit(1)
4450
# load the data
45-
dataset = Datasets("Cora")
46-
add_reverse_edge(dataset)
47-
conn.ingestDataset(dataset, getToken=True)
51+
# dataset = Datasets("Cora")
52+
# add_reverse_edge(dataset)
53+
# conn.ingestDataset(dataset, getToken=True)
4854

49-
dataset = Datasets("graph_algorithms_testing")
50-
conn.ingestDataset(dataset, getToken=True)
55+
# dataset = Datasets("graph_algorithms_testing")
56+
# conn.ingestDataset(dataset, getToken=True)
57+
conn.getToken()
5158

5259
conn.graphname = graph_name
5360
# install the queries
54-
feat = conn.gds.featurizer()
61+
feat = conn.gds.featurizer() # type: ignore
5562
installed_queries = util.get_installed_queries(conn)
5663
algos = json.dumps(feat.algo_dict, indent=1)
5764
queries = [
5865
m.split(": ")[1].replace('"', "").strip() for m in pattern.findall(algos)
5966
]
60-
for q in tqdm(queries, desc="installing GDS queries"):
67+
68+
t = tqdm(queries, desc="installing GDS queries")
69+
for q in t:
70+
t.set_postfix({"query": q})
71+
pth = get_query_path(q)
6172
if q not in installed_queries:
62-
print(q)
63-
feat.installAlgorithm(q)
73+
feat.installAlgorithm(q, pth)
6474

65-
for _ in trange(30, desc="Sleeping while data loads"):
66-
time.sleep(1)
75+
# for _ in trange(30, desc="Sleeping while data loads"):
76+
# time.sleep(1)

0 commit comments

Comments
 (0)