Skip to content

Commit 25b337b

Browse files
authored
Merge pull request #19819 from github/redsun82/rust-regenerate-models
Rust: adapt model generation to new format
2 parents a9169dc + e8c3a2b commit 25b337b

File tree

120 files changed

+13526
-13562
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

120 files changed

+13526
-13562
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@ repos:
2020
rev: 25.1.0
2121
hooks:
2222
- id: black
23-
files: ^(misc/codegen/.*|misc/scripts/models-as-data/bulk_generate_mad)\.py$
23+
files: ^(misc/codegen/.*|misc/scripts/models-as-data/.*)\.py$
2424

2525
- repo: local
2626
hooks:

misc/scripts/models-as-data/bulk_generate_mad.py

Lines changed: 50 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
Note: This file must be formatted using the Black Python formatter.
66
"""
77

8-
import os.path
8+
import pathlib
99
import subprocess
1010
import sys
1111
from typing import Required, TypedDict, List, Callable, Optional
@@ -41,7 +41,7 @@ def missing_module(module_name: str) -> None:
4141
.decode("utf-8")
4242
.strip()
4343
)
44-
build_dir = os.path.join(gitroot, "mad-generation-build")
44+
build_dir = pathlib.Path(gitroot, "mad-generation-build")
4545

4646

4747
# A project to generate models for
@@ -86,10 +86,10 @@ def clone_project(project: Project) -> str:
8686
git_tag = project.get("git-tag")
8787

8888
# Determine target directory
89-
target_dir = os.path.join(build_dir, name)
89+
target_dir = build_dir / name
9090

9191
# Clone only if directory doesn't already exist
92-
if not os.path.exists(target_dir):
92+
if not target_dir.exists():
9393
if git_tag:
9494
print(f"Cloning {name} from {repo_url} at tag {git_tag}")
9595
else:
@@ -191,10 +191,10 @@ def build_database(
191191
name = project["name"]
192192

193193
# Create database directory path
194-
database_dir = os.path.join(build_dir, f"{name}-db")
194+
database_dir = build_dir / f"{name}-db"
195195

196196
# Only build the database if it doesn't already exist
197-
if not os.path.exists(database_dir):
197+
if not database_dir.exists():
198198
print(f"Building CodeQL database for {name}...")
199199
extractor_options = [option for x in extractor_options for option in ("-O", x)]
200200
try:
@@ -236,13 +236,16 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
236236
language = config["language"]
237237

238238
generator = mad.Generator(language)
239-
# Note: The argument parser converts with-sinks to with_sinks, etc.
240-
generator.generateSinks = should_generate_sinks(project)
241-
generator.generateSources = should_generate_sources(project)
242-
generator.generateSummaries = should_generate_summaries(project)
243-
generator.setenvironment(database=database_dir, folder=name)
239+
generator.with_sinks = should_generate_sinks(project)
240+
generator.with_sources = should_generate_sources(project)
241+
generator.with_summaries = should_generate_summaries(project)
244242
generator.threads = args.codeql_threads
245243
generator.ram = args.codeql_ram
244+
if config.get("single-file", False):
245+
generator.single_file = name
246+
else:
247+
generator.folder = name
248+
generator.setenvironment(database=database_dir)
246249
generator.run()
247250

248251

@@ -313,20 +316,14 @@ def download_artifact(url: str, artifact_name: str, pat: str) -> str:
313316
if response.status_code != 200:
314317
print(f"Failed to download file. Status code: {response.status_code}")
315318
sys.exit(1)
316-
target_zip = os.path.join(build_dir, zipName)
319+
target_zip = build_dir / zipName
317320
with open(target_zip, "wb") as file:
318321
for chunk in response.iter_content(chunk_size=8192):
319322
file.write(chunk)
320323
print(f"Download complete: {target_zip}")
321324
return target_zip
322325

323326

324-
def remove_extension(filename: str) -> str:
325-
while "." in filename:
326-
filename, _ = os.path.splitext(filename)
327-
return filename
328-
329-
330327
def pretty_name_from_artifact_name(artifact_name: str) -> str:
    """Return the project name embedded in an artifact name.

    The project name is the second ``___``-separated field of
    ``artifact_name``.

    Raises:
        IndexError: if ``artifact_name`` contains no ``___`` separator.
    """
    fields = artifact_name.split("___")
    if len(fields) < 2:
        # Same exception type the bare subscript would raise, but the message
        # now says which artifact name was malformed.
        raise IndexError(
            f"artifact name {artifact_name!r} has no '___'-separated project name"
        )
    return fields[1]
332329

@@ -348,7 +345,7 @@ def download_dca_databases(
348345
"""
349346
print("\n=== Finding projects ===")
350347
project_map = {project["name"]: project for project in projects}
351-
analyzed_databases = {}
348+
analyzed_databases = {n: None for n in project_map}
352349
for experiment_name in experiment_names:
353350
response = get_json_from_github(
354351
f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
@@ -361,17 +358,24 @@ def download_dca_databases(
361358
artifact_name = analyzed_database["artifact_name"]
362359
pretty_name = pretty_name_from_artifact_name(artifact_name)
363360

364-
if not pretty_name in project_map:
361+
if not pretty_name in analyzed_databases:
365362
print(f"Skipping {pretty_name} as it is not in the list of projects")
366363
continue
367364

368-
if pretty_name in analyzed_databases:
365+
if analyzed_databases[pretty_name] is not None:
369366
print(
370367
f"Skipping previous database {analyzed_databases[pretty_name]['artifact_name']} for {pretty_name}"
371368
)
372369

373370
analyzed_databases[pretty_name] = analyzed_database
374371

372+
not_found = [name for name, db in analyzed_databases.items() if db is None]
373+
if not_found:
374+
print(
375+
f"ERROR: The following projects were not found in the DCA experiments: {', '.join(not_found)}"
376+
)
377+
sys.exit(1)
378+
375379
def download_and_decompress(analyzed_database: dict) -> str:
376380
artifact_name = analyzed_database["artifact_name"]
377381
repository = analyzed_database["repository"]
@@ -393,19 +397,17 @@ def download_and_decompress(analyzed_database: dict) -> str:
393397
# The database is in a zip file, which contains a tar.gz file with the DB
394398
# First we open the zip file
395399
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
396-
artifact_unzipped_location = os.path.join(build_dir, artifact_name)
400+
artifact_unzipped_location = build_dir / artifact_name
397401
# clean up any remnants of previous runs
398402
shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
399403
# And then we extract it to build_dir/artifact_name
400404
zip_ref.extractall(artifact_unzipped_location)
401405
# And then we extract the language tar.gz file inside it
402-
artifact_tar_location = os.path.join(
403-
artifact_unzipped_location, f"{language}.tar.gz"
404-
)
406+
artifact_tar_location = artifact_unzipped_location / f"{language}.tar.gz"
405407
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
406408
# And we just untar it to the same directory as the zip file
407409
tar_ref.extractall(artifact_unzipped_location)
408-
ret = os.path.join(artifact_unzipped_location, language)
410+
ret = artifact_unzipped_location / language
409411
print(f"Decompression complete: {ret}")
410412
return ret
411413

@@ -425,8 +427,16 @@ def download_and_decompress(analyzed_database: dict) -> str:
425427
return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
426428

427429

428-
def get_mad_destination_for_project(config, name: str) -> str:
429-
return os.path.join(config["destination"], name)
430+
def clean_up_mad_destination_for_project(config, name: str) -> None:
    """Delete previously generated MaD output for project ``name``.

    When ``config`` has a truthy ``single-file`` entry the target is the file
    ``<destination>/<name>.model.yml``; otherwise it is the directory
    ``<destination>/<name>``. Nothing happens if the target does not exist.
    """
    destination = pathlib.Path(config["destination"])
    if config.get("single-file", False):
        # Append the extension instead of calling with_suffix(): with_suffix()
        # replaces whatever follows the last dot, so a project named
        # "foo.bar" would resolve to "foo.model.yml", leaving the stale file
        # in place (or deleting another project's file).
        target = destination / f"{name}.model.yml"
        if target.exists():
            print(f"Deleting existing MaD file at {target}")
            target.unlink()
        return

    target = destination / name
    if target.exists():
        print(f"Deleting existing MaD directory at {target}")
        # Best-effort removal: failures are deliberately ignored.
        shutil.rmtree(target, ignore_errors=True)
430440

431441

432442
def get_strategy(config) -> str:
@@ -448,8 +458,7 @@ def main(config, args) -> None:
448458
language = config["language"]
449459

450460
# Create build directory if it doesn't exist
451-
if not os.path.exists(build_dir):
452-
os.makedirs(build_dir)
461+
build_dir.mkdir(parents=True, exist_ok=True)
453462

454463
database_results = []
455464
match get_strategy(config):
@@ -469,7 +478,7 @@ def main(config, args) -> None:
469478
if args.pat is None:
470479
print("ERROR: --pat argument is required for DCA strategy")
471480
sys.exit(1)
472-
if not os.path.exists(args.pat):
481+
if not args.pat.exists():
473482
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
474483
sys.exit(1)
475484
with open(args.pat, "r") as f:
@@ -493,12 +502,9 @@ def main(config, args) -> None:
493502
)
494503
sys.exit(1)
495504

496-
# Delete the MaD directory for each project
497-
for project, database_dir in database_results:
498-
mad_dir = get_mad_destination_for_project(config, project["name"])
499-
if os.path.exists(mad_dir):
500-
print(f"Deleting existing MaD directory at {mad_dir}")
501-
subprocess.check_call(["rm", "-rf", mad_dir])
505+
# clean up existing MaD data for the projects
506+
for project, _ in database_results:
507+
clean_up_mad_destination_for_project(config, project["name"])
502508

503509
for project, database_dir in database_results:
504510
if database_dir is not None:
@@ -508,7 +514,10 @@ def main(config, args) -> None:
508514
if __name__ == "__main__":
509515
parser = argparse.ArgumentParser()
510516
parser.add_argument(
511-
"--config", type=str, help="Path to the configuration file.", required=True
517+
"--config",
518+
type=pathlib.Path,
519+
help="Path to the configuration file.",
520+
required=True,
512521
)
513522
parser.add_argument(
514523
"--dca",
@@ -519,13 +528,13 @@ def main(config, args) -> None:
519528
)
520529
parser.add_argument(
521530
"--pat",
522-
type=str,
531+
type=pathlib.Path,
523532
help="Path to a file containing the PAT token required to grab DCA databases (the same as the one you use for DCA)",
524533
)
525534
parser.add_argument(
526535
"--codeql-ram",
527536
type=int,
528-
help="What `--ram` value to pass to `codeql` while generating models (by default the flag is not passed)",
537+
help="What `--ram` value to pass to `codeql` while generating models (by default 2048 MB per thread)",
529538
default=None,
530539
)
531540
parser.add_argument(
@@ -538,7 +547,7 @@ def main(config, args) -> None:
538547

539548
# Load config file
540549
config = {}
541-
if not os.path.exists(args.config):
550+
if not args.config.exists():
542551
print(f"ERROR: Config file '{args.config}' does not exist.")
543552
sys.exit(1)
544553
try:

misc/scripts/models-as-data/convert_extensions.py

Lines changed: 41 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -7,65 +7,86 @@
77
import sys
88
import tempfile
99

10+
1011
def quote_if_needed(v):
    """Render one column value for a generated model row.

    Strings are wrapped in double quotes, with embedded backslashes and
    double quotes escaped so the result remains a valid double-quoted scalar.
    Any other value (the bool columns) is rendered with ``str()``.
    """
    # string columns
    if isinstance(v, str):
        # Escape the backslash first so the escapes added for quotes are not
        # themselves doubled.
        escaped = v.replace("\\", "\\\\").replace('"', '\\"')
        return '"' + escaped + '"'
    # bool column
    return str(v)
1617

18+
1719
def parseData(data):
    """Group query result rows into two dictionaries of formatted row text.

    Each row is rendered as `` - [col, col, ...]`` followed by a newline,
    with every column passed through ``quote_if_needed``. The rendered text
    is handed to ``helpers.insert_update`` under the row's first column.
    Rows whose last column (the provenance) ends with ``"generated"`` go to
    the second dictionary, all others to the first.

    Returns a two-element list: ``[other_rows, generated_rows]``.
    """
    manual, generated = {}, {}
    for row in data:
        # Choose the destination from the provenance before rendering.
        bucket = generated if row[-1].endswith("generated") else manual
        rendered = ", ".join(quote_if_needed(column) for column in row)
        helpers.insert_update(bucket, row[0], " - [" + rendered + "]\n")

    return [manual, generated]
2628

29+
2730
class Converter:
2831
def __init__(self, language, dbDir):
    """Record the target language and database, and derive working paths.

    The repository root is obtained from ``git rev-parse --show-toplevel``;
    the extension directory is ``<root>/<language>/ql/lib/ext/``. A fresh
    temporary directory is created as the work directory.
    """
    self.language = language
    self.dbDir = dbDir

    # The git output is bytes with a trailing newline; decode and trim it.
    toplevel = subprocess.check_output(["git", "rev-parse", "--show-toplevel"])
    self.codeQlRoot = toplevel.decode("utf-8").strip()

    self.extDir = os.path.join(self.codeQlRoot, f"{self.language}/ql/lib/ext/")
    self.dirname = "modelconverter"
    self.modelFileExtension = ".model.yml"
    self.workDir = tempfile.mkdtemp()
3643

37-
3844
def runQuery(self, query):
    """Run ``query`` with ``codeql query run`` and return its data.

    The query file is looked up under
    ``<codeQlRoot>/<language>/ql/src/utils/<dirname>/``. The BQRS output is
    written to ``out.bqrs`` in the work directory, and the return value is
    whatever ``helpers.readData`` produces for that file.
    """
    print("########## Querying: ", query)
    utils_dir = f"{self.language}/ql/src/utils/{self.dirname}"
    query_path = os.path.join(self.codeQlRoot, utils_dir, query)
    bqrs_path = os.path.join(self.workDir, "out.bqrs")

    command = ["codeql", "query", "run", query_path]
    command += ["--database", self.dbDir]
    command += ["--output", bqrs_path]
    helpers.run_cmd(command, "Failed to generate " + query)

    return helpers.readData(self.workDir, bqrs_path)
4565

46-
4766
def asAddsTo(self, rows, predicate):
    """Wrap each entry of ``rows`` in the ``helpers.addsToTemplate`` text.

    ``rows`` is a two-element sequence of dictionaries. The result has the
    same shape and keys, with every value replaced by the template formatted
    with the ``codeql/<language>-all`` pack name, ``predicate`` and the
    original value.
    """
    pack = f"codeql/{self.language}-all"
    extensions = []
    for index in range(2):
        group = rows[index]
        extensions.append(
            {
                key: helpers.addsToTemplate.format(pack, predicate, group[key])
                for key in group
            }
        )
    return extensions
5575

5676
def getAddsTo(self, query, predicate):
    """Run ``query``, parse its rows, and wrap them for ``predicate``."""
    parsed = parseData(self.runQuery(query))
    return self.asAddsTo(parsed, predicate)
6080

61-
6281
def makeContent(self):
    """Run the four extraction queries and merge their results.

    Returns a two-element list. Element ``i`` is ``helpers.merge`` applied
    to element ``i`` of the sources, sinks, summaries and neutrals results,
    in that argument order.
    """
    # The queries run in this order: summaries, sources, sinks, neutrals.
    summaries = self.getAddsTo("ExtractSummaries.ql", helpers.summaryModelPredicate)
    sources = self.getAddsTo("ExtractSources.ql", helpers.sourceModelPredicate)
    sinks = self.getAddsTo("ExtractSinks.ql", helpers.sinkModelPredicate)
    neutrals = self.getAddsTo("ExtractNeutrals.ql", helpers.neutralModelPredicate)

    merge_order = (sources, sinks, summaries, neutrals)
    return [
        helpers.merge(*(result[index] for result in merge_order))
        for index in range(2)
    ]
6990

7091
def save(self, extensions):
7192
# Create directory if it doesn't exist
@@ -77,9 +98,11 @@ def save(self, extensions):
7798
for entry in extensions[0]:
7899
with open(self.extDir + "/" + entry + self.modelFileExtension, "w") as f:
79100
f.write(extensionTemplate.format(extensions[0][entry]))
80-
101+
81102
for entry in extensions[1]:
82-
with open(self.extDir + "/generated/" + entry + self.modelFileExtension, "w") as f:
103+
with open(
104+
self.extDir + "/generated/" + entry + self.modelFileExtension, "w"
105+
) as f:
83106
f.write(extensionTemplate.format(extensions[1][entry]))
84107

85108
def run(self):

0 commit comments

Comments
 (0)