Skip to content

Commit 1682d9b

Browse files
committed
Bulk generator script tweaks
1 parent 4cd3f22 commit 1682d9b

File tree

1 file changed

+84
-66
lines changed

1 file changed

+84
-66
lines changed

misc/scripts/models-as-data/bulk_generate_mad.py

Lines changed: 84 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,15 @@ def missing_module(module_name: str) -> None:
4444
build_dir = pathlib.Path(gitroot, "mad-generation-build")
4545

4646

47+
def database_dir_for_project(name: str) -> pathlib.Path:
48+
return build_dir / f"{name}-db"
49+
50+
51+
def database_for_project_exists(name: str) -> bool:
52+
path = database_dir_for_project(name)
53+
return path.exists()
54+
55+
4756
# A project to generate models for
4857
Project = TypedDict(
4958
"Project",
@@ -127,7 +136,7 @@ def run_in_parallel[T, U](
127136
if not items:
128137
return []
129138
max_workers = min(max_workers, len(items))
130-
results = [None for _ in range(len(items))]
139+
results: List[Optional[U]] = [None for _ in range(len(items))]
131140
with ThreadPoolExecutor(max_workers=max_workers) as executor:
132141
# Start cloning tasks and keep track of them
133142
futures = {
@@ -175,7 +184,7 @@ def clone_projects(projects: List[Project]) -> List[tuple[Project, str]]:
175184

176185
def build_database(
177186
language: str, extractor_options, project: Project, project_dir: str
178-
) -> str | None:
187+
) -> bool:
179188
"""
180189
Build a CodeQL database for a project.
181190
@@ -186,12 +195,12 @@ def build_database(
186195
project_dir: Path to the CodeQL database.
187196
188197
Returns:
189-
The path to the created database directory.
198+
True if the build was successful, False otherwise.
190199
"""
191200
name = project["name"]
192201

193202
# Create database directory path
194-
database_dir = build_dir / f"{name}-db"
203+
database_dir = database_dir_for_project(name)
195204

196205
# Only build the database if it doesn't already exist
197206
if not database_dir.exists():
@@ -214,16 +223,16 @@ def build_database(
214223
print(f"Successfully created database at {database_dir}")
215224
except subprocess.CalledProcessError as e:
216225
print(f"Failed to create database for {name}: {e}")
217-
return None
226+
return False
218227
else:
219228
print(
220229
f"Skipping database creation for {name} as it already exists at {database_dir}"
221230
)
222231

223-
return database_dir
232+
return True
224233

225234

226-
def generate_models(config, args, project: Project, database_dir: str) -> None:
235+
def generate_models(config, args, project: Project) -> None:
227236
"""
228237
Generate models for a project.
229238
@@ -235,6 +244,7 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
235244
name = project["name"]
236245
language = config["language"]
237246

247+
print("\n--- Generating models for project: " + name + " ---")
238248
generator = mad.Generator(language)
239249
generator.with_sinks = should_generate_sinks(project)
240250
generator.with_sources = should_generate_sources(project)
@@ -245,13 +255,13 @@ def generate_models(config, args, project: Project, database_dir: str) -> None:
245255
generator.single_file = name
246256
else:
247257
generator.folder = name
248-
generator.setenvironment(database=database_dir)
258+
generator.setenvironment(database=database_dir_for_project(name))
249259
generator.run()
250260

251261

252262
def build_databases_from_projects(
253263
language: str, extractor_options, projects: List[Project]
254-
) -> List[tuple[Project, str | None]]:
264+
) -> List[tuple[Project, bool]]:
255265
"""
256266
Build databases for all projects in parallel.
257267
@@ -261,7 +271,7 @@ def build_databases_from_projects(
261271
projects: List of projects to build databases for.
262272
263273
Returns:
264-
List of (project_name, database_dir) pairs, where database_dir is None if the build failed.
274+
List of (project_name, success) pairs, where success is False if the build failed.
265275
"""
266276
# Clone projects in parallel
267277
print("=== Cloning projects ===")
@@ -333,19 +343,20 @@ def download_dca_databases(
333343
experiment_names: list[str],
334344
pat: str,
335345
projects: List[Project],
336-
) -> List[tuple[Project, str | None]]:
346+
): # -> List[tuple[Project, bool]]:
337347
"""
338348
Download databases from a DCA experiment.
339349
Args:
340350
experiment_names: The names of the DCA experiments to download databases from.
341351
pat: Personal Access Token for GitHub API authentication.
342352
projects: List of projects to download databases for.
343353
Returns:
344-
List of (project_name, database_dir) pairs, where database_dir is None if the download failed.
354+
List of (project_name, success) pairs, where success is False if the download failed.
345355
"""
346356
print("\n=== Finding projects ===")
347357
project_map = {project["name"]: project for project in projects}
348-
analyzed_databases = {n: None for n in project_map}
358+
359+
analyzed_databases = {}
349360
for experiment_name in experiment_names:
350361
response = get_json_from_github(
351362
f"https://raw.githubusercontent.com/github/codeql-dca-main/data/{experiment_name}/reports/downloads.json",
@@ -358,26 +369,28 @@ def download_dca_databases(
358369
artifact_name = analyzed_database["artifact_name"]
359370
pretty_name = pretty_name_from_artifact_name(artifact_name)
360371

361-
if not pretty_name in analyzed_databases:
372+
if not pretty_name in project_map:
362373
print(f"Skipping {pretty_name} as it is not in the list of projects")
363374
continue
364375

365-
if analyzed_databases[pretty_name] is not None:
376+
if pretty_name in analyzed_databases:
366377
print(
367378
f"Skipping previous database {analyzed_databases[pretty_name]['artifact_name']} for {pretty_name}"
368379
)
369380

370381
analyzed_databases[pretty_name] = analyzed_database
371382

372-
not_found = [name for name, db in analyzed_databases.items() if db is None]
383+
not_found = [name for name in project_map if name not in analyzed_databases]
373384
if not_found:
374385
print(
375386
f"ERROR: The following projects were not found in the DCA experiments: {', '.join(not_found)}"
376387
)
377388
sys.exit(1)
378389

379-
def download_and_decompress(analyzed_database: dict) -> str:
390+
def download_and_decompress(analyzed_database: dict) -> bool:
380391
artifact_name = analyzed_database["artifact_name"]
392+
pretty_name = pretty_name_from_artifact_name(artifact_name)
393+
database_location = database_dir_for_project(pretty_name)
381394
repository = analyzed_database["repository"]
382395
run_id = analyzed_database["run_id"]
383396
print(f"=== Finding artifact: {artifact_name} ===")
@@ -398,33 +411,38 @@ def download_and_decompress(analyzed_database: dict) -> str:
398411
# First we open the zip file
399412
with zipfile.ZipFile(artifact_zip_location, "r") as zip_ref:
400413
artifact_unzipped_location = build_dir / artifact_name
414+
401415
# clean up any remnants of previous runs
402416
shutil.rmtree(artifact_unzipped_location, ignore_errors=True)
417+
shutil.rmtree(database_location, ignore_errors=True)
418+
403419
# And then we extract it to build_dir/artifact_name
404420
zip_ref.extractall(artifact_unzipped_location)
405421
# And then we extract the language tar.gz file inside it
406422
artifact_tar_location = artifact_unzipped_location / f"{language}.tar.gz"
407423
with tarfile.open(artifact_tar_location, "r:gz") as tar_ref:
408424
# And we just untar it to the same directory as the zip file
409425
tar_ref.extractall(artifact_unzipped_location)
410-
ret = artifact_unzipped_location / language
411-
print(f"Decompression complete: {ret}")
412-
return ret
426+
# Move the database to the canonical location
427+
shutil.move(artifact_unzipped_location / language, database_location)
428+
429+
print(f"Decompression complete: {database_location}")
430+
return True
413431

414-
results = run_in_parallel(
432+
run_in_parallel(
415433
download_and_decompress,
416434
list(analyzed_databases.values()),
417435
on_error=lambda db, exc: print(
418-
f"ERROR: Failed to download and decompress {db["artifact_name"]}: {exc}"
436+
f"ERROR: Failed to download and decompress {db['artifact_name']}: {exc}"
419437
),
420438
error_summary=lambda failures: print(
421439
f"ERROR: Failed to download {len(failures)} databases: {', '.join(item[0] for item in failures)}"
422440
),
423441
)
424442

425-
print(f"\n=== Fetched {len(results)} databases ===")
443+
print(f"\n=== Fetched {len(analyzed_databases.values())} databases ===")
426444

427-
return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
445+
# return [(project_map[n], r) for n, r in zip(analyzed_databases, results)]
428446

429447

430448
def clean_up_mad_destination_for_project(config, name: str):
@@ -460,55 +478,50 @@ def main(config, args) -> None:
460478
# Create build directory if it doesn't exist
461479
build_dir.mkdir(parents=True, exist_ok=True)
462480

463-
database_results = []
464-
match get_strategy(config):
465-
case "repo":
466-
extractor_options = config.get("extractor_options", [])
467-
database_results = build_databases_from_projects(
468-
language,
469-
extractor_options,
470-
projects,
471-
)
472-
case "dca":
473-
experiment_names = args.dca
474-
if experiment_names is None:
475-
print("ERROR: --dca argument is required for DCA strategy")
476-
sys.exit(1)
477-
478-
if args.pat is None:
479-
print("ERROR: --pat argument is required for DCA strategy")
480-
sys.exit(1)
481-
if not args.pat.exists():
482-
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
483-
sys.exit(1)
484-
with open(args.pat, "r") as f:
485-
pat = f.read().strip()
486-
database_results = download_dca_databases(
481+
# Check reuse databases flag is given and all databases exist
482+
skip_database_creation = args.reuse_databases and all(
483+
database_for_project_exists(project["name"]) for project in projects
484+
)
485+
486+
if not skip_database_creation:
487+
match get_strategy(config):
488+
case "repo":
489+
extractor_options = config.get("extractor_options", [])
490+
build_databases_from_projects(
487491
language,
488-
experiment_names,
489-
pat,
492+
extractor_options,
490493
projects,
491494
)
492-
493-
# Generate models for all projects
494-
print("\n=== Generating models ===")
495-
496-
failed_builds = [
497-
project["name"] for project, db_dir in database_results if db_dir is None
498-
]
499-
if failed_builds:
500-
print(
501-
f"ERROR: {len(failed_builds)} database builds failed: {', '.join(failed_builds)}"
502-
)
503-
sys.exit(1)
495+
case "dca":
496+
experiment_names = args.dca
497+
if experiment_names is None:
498+
print("ERROR: --dca argument is required for DCA strategy")
499+
sys.exit(1)
500+
501+
if args.pat is None:
502+
print("ERROR: --pat argument is required for DCA strategy")
503+
sys.exit(1)
504+
if not args.pat.exists():
505+
print(f"ERROR: Personal Access Token file '{pat}' does not exist.")
506+
sys.exit(1)
507+
with open(args.pat, "r") as f:
508+
pat = f.read().strip()
509+
download_dca_databases(
510+
language,
511+
experiment_names,
512+
pat,
513+
projects,
514+
)
504515

505516
# clean up existing MaD data for the projects
506-
for project, _ in database_results:
517+
for project in projects:
507518
clean_up_mad_destination_for_project(config, project["name"])
508519

509-
for project, database_dir in database_results:
510-
if database_dir is not None:
511-
generate_models(config, args, project, database_dir)
520+
# Generate models for all projects
521+
print("\n=== Generating models ===")
522+
523+
for project in projects:
524+
generate_models(config, args, project)
512525

513526

514527
if __name__ == "__main__":
@@ -543,6 +556,11 @@ def main(config, args) -> None:
543556
help="What `--threads` value to pass to `codeql` (default %(default)s)",
544557
default=0,
545558
)
559+
parser.add_argument(
560+
"--reuse-databases",
561+
action="store_true",
562+
help="Whether to reuse existing databases instead of rebuilding/redownloading them",
563+
)
546564
args = parser.parse_args()
547565

548566
# Load config file

0 commit comments

Comments
 (0)