diff --git a/run_tapenade_blas.py b/run_tapenade_blas.py index 619ad21..c61829e 100644 --- a/run_tapenade_blas.py +++ b/run_tapenade_blas.py @@ -6,8 +6,11 @@ import subprocess import sys from pathlib import Path +from shutil import rmtree as shrm FORTRAN_EXTS = {".f", ".for", ".f90", ".F", ".F90"} +TAPENADE_USELESS_ZONES = 3 +GENLIBTMP = "TMPGENLIB" def is_fortran(p: Path) -> bool: return p.suffix in FORTRAN_EXTS @@ -643,10 +646,11 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): complex_vars = set() integer_vars = set() char_vars = set() + array_vars = set() # Find the argument declaration section lines = content.split('\n') - in_args_section = False + in_args_section = False # This variable is used nowhere for i, line in enumerate(lines): line_stripped = line.strip() @@ -661,6 +665,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): # Also look for the actual declaration lines (not in comments) if line_stripped and not line_stripped.startswith('*') and not line_stripped.startswith('C '): + # Look for ARRAY variables + is_array = ('(' in line_stripped) and (')' in line_stripped) # Parse variable declarations if line_stripped.startswith('REAL') or line_stripped.startswith('DOUBLE PRECISION') or line_stripped.startswith('FLOAT'): @@ -677,6 +683,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): var = re.sub(r'\*.*$', '', var) if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var): real_vars.add(var) + if is_array: + array_vars.add(var) elif line_stripped.startswith('INTEGER'): int_decl = re.search(r'INTEGER\s+(.+)', line_stripped, re.IGNORECASE) @@ -689,6 +697,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): var = re.sub(r'\*.*$', '', var) if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var): integer_vars.add(var) + if is_array: + array_vars.add(var) elif line_stripped.startswith('CHARACTER'): char_decl = re.search(r'CHARACTER\s+(.+)', line_stripped, re.IGNORECASE) @@ -701,6 +711,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): var = re.sub(r'\*.*$', '', var) if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var): char_vars.add(var) + if is_array: + array_vars.add(var) elif line_stripped.startswith('COMPLEX'): # Extract variable names from COMPLEX declaration @@ -716,6 +728,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): var = re.sub(r'\*.*$', '', var) if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var): complex_vars.add(var) # Add complex variables to complex_vars + if is_array: + array_vars.add(var) # For FUNCTIONs with explicit return types, add function name to appropriate variable set if func_type == 'FUNCTION': @@ -847,7 +861,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False): 'real_vars': real_vars, 'complex_vars': complex_vars, 'integer_vars': integer_vars, - 'char_vars': char_vars + 'char_vars': char_vars, + 'array_vars': array_vars } return func_name, valid_inputs, valid_outputs, inout_vars, func_type, params, warnings, param_types, has_sufficient_docs @@ -8018,6 +8033,7 @@ def main(): help="AD modes to generate: d (forward scalar), dv (forward vector), b (reverse scalar), bv (reverse vector), all (all modes). Default: all") ap.add_argument("--nbdirsmax", type=int, default=4, help="Maximum number of derivative directions for vector mode (default: 4)") ap.add_argument("--flat", action="store_true", help="Use flat directory structure (all files in function directory, single DIFFSIZES.inc)") + ap.add_argument("--genlib", default=None, required=False, help="Generate Tapenade external library") ap.add_argument("--extra", nargs=argparse.REMAINDER, help="Extra args passed to Tapenade after -d/-r", default=[]) args = ap.parse_args() @@ -8072,10 +8088,10 @@ def main(): modes = {"d", "dv", "b", "bv"} # Determine which specific modes to run - run_d = "d" in modes - run_dv = "dv" in modes - run_b = "b" in modes - run_bv = "bv" in modes + run_d = "d" in modes or args.genlib + run_dv = not args.genlib and "dv" in modes + run_b = not args.genlib and "b" in modes + run_bv = not args.genlib and "bv" in modes # List of non-differentiable functions to skip entirely # See SKIPPED_FUNCTIONS.md for detailed documentation on why each is skipped @@ -8144,8 +8160,76 @@ def run_task(task): # Create output directory structure flat_mode = args.flat mode_dirs = {} - - if flat_mode: + + if (args.genlib): + # When generating the general lib useful to Tapenade, we will save everything in a tmp file + # and only the lib in a local folder used to concatenate everything afterwards. + tmp_dir = Path(GENLIBTMP).resolve() + tmp_dir.mkdir(parents=True, exist_ok=True) + func_out_dir = tmp_dir + genlib_dir = out_dir + genlib_dir.mkdir(parents=True, exist_ok=True) + mode_dirs['d'] = tmp_dir + + def convert_tap_result2genlib_format(l: str) : + out = [] + infos = l.split("[")[1] + use_infos = True + for c in infos[TAPENADE_USELESS_ZONES:]: # Don't bother with the first + if(c == "]"): + break + if use_infos: + if(c == "("): + use_infos = False + else: + out = out + [("0" if c == "." else "1" )] + else: + if(c == ")"): + use_infos = True + + return out + + def parse_tap_trace4inout(fname): + with open(fname, "r") as f: + sought_after = " ===================== IN-OUT ANALYSIS OF UNIT " + l = f.readline() + while(not l.startswith(sought_after)): + l = f.readline() + + # Now we read the next one, and start looking at the arguments + var2idx_mapping = dict() + l = f.readline().strip() + for v in l.split(" ")[TAPENADE_USELESS_ZONES:]: # The first variables are useless + not_quite_id, var_name = v.split("]") + idx = int(not_quite_id[1:])-TAPENADE_USELESS_ZONES + var2idx_mapping[var_name] = idx + + # Now that the mapping has been parsed, we move towards the end of the analysis phase, and extract the summary + sought_after = "terminateFGForUnit Unit" + while(not l.startswith(sought_after)): + l = f.readline() + # We have found our signal to read the results + # It is always four lines looking like this + # N [111111..11(1)111111] ---> corresponds to NotReadNotWritten, probably useless + # R [...1111111(1)11111.] ---> corresponds to ReadNotWritten + # W [..1.......(1).....1] ---> corresponds to NotReadThenWritten ==> Need to check what the 1 in third position means + # RW [..........(1).....1] ---> corresponds to ReadThenWritten + l = f.readline() + # Discard the not read not written elements + l = f.readline() + # We deal with the ReadNotWritten information + read_not_written = convert_tap_result2genlib_format(l) + l = f.readline() + # Deal with NotReadThenWritten + not_read_then_written = convert_tap_result2genlib_format(l) + l = f.readline() + # Deal with ReadThenWritten + read_then_written = convert_tap_result2genlib_format(l) + + return read_not_written, not_read_then_written, read_then_written, var2idx_mapping + + + elif flat_mode: # Flat mode with organized subdirectories: src/, test/, include/ src_dir = out_dir / 'src' test_dir = out_dir / 'test' @@ -8188,7 +8272,7 @@ def run_task(task): mode_dirs['bv'].mkdir(parents=True, exist_ok=True) # Update log path to be in the function subdirectory - func_log_path = func_out_dir / (src.stem + ".tapenade.log") + # func_log_path = func_out_dir / (src.stem + ".tapenade.log") # ISNT THIS COMPLETELY USELESS?? # Find dependency files called_functions = parse_function_calls(src) @@ -8240,11 +8324,14 @@ def run_task(task): for dep_file in main_file_removed: cmd.append(str(dep_file)) cmd.extend(list(args.extra)) + if (args.genlib): + cmd = cmd + ["-traceinout", src.stem] try: with open(mode_log_path, "w") as logf: logf.write(f"Mode: FORWARD (scalar)\n") # Format command for logging (properly quoted for shell copy-paste) + print("CMD:", cmd) cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd) logf.write(f"Command: {cmd_str}\n") logf.write(f"Function: {func_name}\n") @@ -8285,6 +8372,44 @@ def run_task(task): pass print(f" ERROR: Exception during forward mode execution: {e}", file=sys.stderr) return_codes["forward"] = 999 + + if (args.genlib) : # Everything went well, and we are trying to generate the external lib + read_not_written, not_read_then_written, read_then_written, var2idx = parse_tap_trace4inout(mode_log_path) + if func_type == 'FUNCTION': + param_for_genlib = all_params + [src.stem] + else: + param_for_genlib = all_params + param_2_tap_reordering = [var2idx[p.lower()] for p in param_for_genlib] + with open("DiffBlasGenLib", "a") as f: + f.write(("function " if func_type == 'FUNCTION' else "subroutine ") + src.stem + ":\n") + indent = " " + f.write(indent + "external:\n") + shape = "(" + ", ".join(["param " + str(i) for i in range(1,len(all_params)+1)] + ["result" if func_type == 'FUNCTION' else ""]) + ")" ## TODO: Need to add ', return' in case of a function,. dpeending on whether it is within the all params or not + f.write(indent + "shape: " + shape + "\n") + types = [] + for p in param_for_genlib: + current_type = "" + if p.upper() in param_types['real_vars'] or p.lower() in param_types['real_vars']: + current_type = "metavar float" # We should probably be more precise in order to handle mixed precision things + # Namely, adapt to + # modifiedType(modifiers(ident double), float() for double / REAL*8 + # float() for single precision + elif p.upper() in param_types['complex_vars'] or p.lower() in param_types['complex_vars']: + current_type = "metavar complex" + # Similar to the real variables, we should be able to be more precise in terms of precision of the complex variable + elif p.upper() in param_types['integer_vars'] or p.lower() in param_types['integer_vars']: + current_type = "metavar integer" + elif p.upper() in param_types['char_vars'] or p.lower() in param_types['char_vars']: + current_type = "character()" + if p.upper() in param_types['array_vars'] or p.lower() in param_types['array_vars']: + current_type = "arrayType(" + current_type + ", dimColons())" + types.append(current_type) + types = "(" + ", ".join(types) + ")" + f.write(indent + "type: " + types + "\n") + f.write(indent + "ReadNotWritten: (" + ", ".join([read_not_written[i] for i in param_2_tap_reordering]) + ")\n") + f.write(indent + "NotReadThenWritten: (" + ", ".join([not_read_then_written[i] for i in param_2_tap_reordering]) + ")\n") + f.write(indent + "ReadThenWritten: (" + ", ".join([read_then_written[i] for i in param_2_tap_reordering]) + ")\n") + f.write("\n") # Run scalar reverse mode (b) if run_b: @@ -8383,6 +8508,7 @@ def run_task(task): try: with open(mode_log_path, "w") as logf: logf.write(f"Mode: FORWARD VECTOR\n") + print("CMD:", cmd) # Format command for logging (properly quoted for shell copy-paste) cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd) logf.write(f"Command: {cmd_str}\n") @@ -8449,6 +8575,7 @@ def run_task(task): try: with open(mode_log_path, "w") as logf: logf.write(f"Mode: REVERSE VECTOR\n") + print("CMD:", cmd) # Format command for logging (properly quoted for shell copy-paste) cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd) logf.write(f"Command: {cmd_str}\n") @@ -8677,8 +8804,8 @@ def run_task(task): # Return the worst return code (non-zero if any mode failed) final_rc = max(return_codes.values()) if return_codes else 999 - return (src, final_rc) - + return (src, final_rc) + # Serial or parallel execution results = [] if args.jobs <= 1: @@ -8774,6 +8901,9 @@ def run_task(task): if "reverse" in args.mode or args.mode == "both": print(" make vector-reverse # Build vector reverse mode only") print(" ./test__vector_forward # Run vector forward mode test") + + if args.genlib: + shrm(Path(GENLIBTMP)) def generate_top_level_makefile(out_dir, flat_mode=False): """Generate the top-level Makefile for building all subdirectories or flat makefiles"""