Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
152 changes: 141 additions & 11 deletions run_tapenade_blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,11 @@
import subprocess
import sys
from pathlib import Path
from shutil import rmtree as shrm

# File suffixes that is_fortran() accepts; membership is tested against
# Path.suffix as-is, so the match is case-sensitive.
FORTRAN_EXTS = {".f", ".for", ".f90", ".F", ".F90"}
# Number of leading entries skipped when reading Tapenade in-out trace lines
# in --genlib mode: applied both to the variable listing (and subtracted from
# the bracketed indices) and to the R/W/RW flag strings.
TAPENADE_USELESS_ZONES = 3
# Name of the scratch directory created (relative to the current working
# directory) in --genlib mode and removed again once the run is finished.
GENLIBTMP = "TMPGENLIB"

def is_fortran(p: Path) -> bool:
    """Tell whether ``p`` has one of the suffixes listed in ``FORTRAN_EXTS``.

    Only the final suffix of the path is examined, and it is compared
    exactly as written (no case folding).
    """
    extension = p.suffix
    return extension in FORTRAN_EXTS
Expand Down Expand Up @@ -643,10 +646,11 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
complex_vars = set()
integer_vars = set()
char_vars = set()
array_vars = set()

# Find the argument declaration section
lines = content.split('\n')
in_args_section = False
in_args_section = False # This variable is used nowhere

Comment on lines +653 to 654
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@sriharikrishna Should we clean the code a bit?

for i, line in enumerate(lines):
line_stripped = line.strip()
Expand All @@ -661,6 +665,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):

# Also look for the actual declaration lines (not in comments)
if line_stripped and not line_stripped.startswith('*') and not line_stripped.startswith('C '):
# Look for ARRAY variables
is_array = ('(' in line_stripped) and (')' in line_stripped)

# Parse variable declarations
if line_stripped.startswith('REAL') or line_stripped.startswith('DOUBLE PRECISION') or line_stripped.startswith('FLOAT'):
Expand All @@ -677,6 +683,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
var = re.sub(r'\*.*$', '', var)
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
real_vars.add(var)
if is_array:
array_vars.add(var)

elif line_stripped.startswith('INTEGER'):
int_decl = re.search(r'INTEGER\s+(.+)', line_stripped, re.IGNORECASE)
Expand All @@ -689,6 +697,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
var = re.sub(r'\*.*$', '', var)
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
integer_vars.add(var)
if is_array:
array_vars.add(var)

elif line_stripped.startswith('CHARACTER'):
char_decl = re.search(r'CHARACTER\s+(.+)', line_stripped, re.IGNORECASE)
Expand All @@ -701,6 +711,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
var = re.sub(r'\*.*$', '', var)
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
char_vars.add(var)
if is_array:
array_vars.add(var)

elif line_stripped.startswith('COMPLEX'):
# Extract variable names from COMPLEX declaration
Expand All @@ -716,6 +728,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
var = re.sub(r'\*.*$', '', var)
if var and re.match(r'^[A-Za-z][A-Za-z0-9_]*$', var):
complex_vars.add(var) # Add complex variables to complex_vars
if is_array:
array_vars.add(var)

# For FUNCTIONs with explicit return types, add function name to appropriate variable set
if func_type == 'FUNCTION':
Expand Down Expand Up @@ -847,7 +861,8 @@ def parse_fortran_function(file_path: Path, suppress_warnings=False):
'real_vars': real_vars,
'complex_vars': complex_vars,
'integer_vars': integer_vars,
'char_vars': char_vars
'char_vars': char_vars,
'array_vars': array_vars
}

return func_name, valid_inputs, valid_outputs, inout_vars, func_type, params, warnings, param_types, has_sufficient_docs
Expand Down Expand Up @@ -8018,6 +8033,7 @@ def main():
help="AD modes to generate: d (forward scalar), dv (forward vector), b (reverse scalar), bv (reverse vector), all (all modes). Default: all")
ap.add_argument("--nbdirsmax", type=int, default=4, help="Maximum number of derivative directions for vector mode (default: 4)")
ap.add_argument("--flat", action="store_true", help="Use flat directory structure (all files in function directory, single DIFFSIZES.inc)")
ap.add_argument("--genlib", default=None, required=False, help="Generate Tapenade external library")
ap.add_argument("--extra", nargs=argparse.REMAINDER, help="Extra args passed to Tapenade after -d/-r", default=[])
args = ap.parse_args()

Expand Down Expand Up @@ -8072,10 +8088,10 @@ def main():
modes = {"d", "dv", "b", "bv"}

# Determine which specific modes to run
run_d = "d" in modes
run_dv = "dv" in modes
run_b = "b" in modes
run_bv = "bv" in modes
run_d = "d" in modes or args.genlib
run_dv = not args.genlib and "dv" in modes
run_b = not args.genlib and "b" in modes
run_bv = not args.genlib and "bv" in modes

# List of non-differentiable functions to skip entirely
# See SKIPPED_FUNCTIONS.md for detailed documentation on why each is skipped
Expand Down Expand Up @@ -8144,8 +8160,76 @@ def run_task(task):
# Create output directory structure
flat_mode = args.flat
mode_dirs = {}

if flat_mode:

if (args.genlib):
# When generating the general lib useful to Tapenade, we will save everything in a tmp file
# and only the lib in a local folder used to concatenate everything afterwards.
tmp_dir = Path(GENLIBTMP).resolve()
tmp_dir.mkdir(parents=True, exist_ok=True)
func_out_dir = tmp_dir
genlib_dir = out_dir
genlib_dir.mkdir(parents=True, exist_ok=True)
mode_dirs['d'] = tmp_dir

def convert_tap_result2genlib_format(l: str):
    """Turn one bracketed flag line of the in-out trace into a list of "0"/"1".

    The text following the first "[" of ``l`` is scanned after dropping its
    first ``TAPENADE_USELESS_ZONES`` characters.  Scanning stops at the first
    "]".  Characters enclosed in parentheses (and the parentheses themselves)
    are ignored.  Every remaining character yields "0" when it is "." and
    "1" otherwise.

    Raises IndexError when ``l`` contains no "[".
    """
    payload = l.split("[")[1]
    flags = []
    inside_parens = False
    for ch in payload[TAPENADE_USELESS_ZONES:]:  # leading zones are not wanted
        if ch == "]":
            # End of the bracketed section, even if a "(" is still open.
            break
        if inside_parens:
            # Swallow everything up to and including the closing ")".
            inside_parens = ch != ")"
            continue
        if ch == "(":
            inside_parens = True
            continue
        flags.append("0" if ch == "." else "1")

    return flags

def parse_tap_trace4inout(fname):
    """Extract the in-out summary of a unit from a Tapenade trace log.

    The file is expected to contain (the caller produces it by running
    Tapenade with ``-traceinout``):

    1. a line starting with " ===================== IN-OUT ANALYSIS OF UNIT ";
    2. immediately after it, a space-separated listing of ``[idx]name``
       entries, of which the first ``TAPENADE_USELESS_ZONES`` are skipped;
    3. later, a line starting with "terminateFGForUnit Unit", followed by
       four flag lines which, per the original author's notes, look like
           N  [111111..11(1)111111]   NotReadNotWritten (discarded)
           R  [...1111111(1)11111.]   ReadNotWritten
           W  [..1.......(1).....1]   NotReadThenWritten
           RW [..........(1).....1]   ReadThenWritten

    Parameters:
        fname: path of the log file to read.

    Returns:
        A tuple ``(read_not_written, not_read_then_written,
        read_then_written, var2idx_mapping)``.  The first three are the
        lists produced by ``convert_tap_result2genlib_format`` for the R, W
        and RW lines; the last maps each listed variable name to its
        bracketed index minus ``TAPENADE_USELESS_ZONES``.

    Raises:
        ValueError: if either marker line is missing from the file.  The
            previous implementation looped forever in that case, because
            ``readline()`` keeps returning "" once the end of the file is
            reached.
    """
    def read_until(stream, marker, line):
        # Starting from ``line``, read forward until a line begins with
        # ``marker``.  Only freshly read lines are tested for end of file:
        # readline() returns "" exclusively at EOF (a blank line is "\n").
        while not line.startswith(marker):
            line = stream.readline()
            if not line:
                raise ValueError(
                    f"{fname}: marker {marker!r} not found in Tapenade trace"
                )
        return line

    with open(fname, "r") as f:
        read_until(f, " ===================== IN-OUT ANALYSIS OF UNIT ", "")

        # The line right after the header lists the variables as "[idx]name".
        var2idx_mapping = dict()
        listing = f.readline().strip()
        for v in listing.split(" ")[TAPENADE_USELESS_ZONES:]:  # leading entries are useless
            not_quite_id, var_name = v.split("]")
            # Drop the leading "[" and shift so kept variables start at 0.
            var2idx_mapping[var_name] = int(not_quite_id[1:]) - TAPENADE_USELESS_ZONES

        # Move to the end of the analysis phase, where the summary is printed.
        # As in the original code, the (stripped) listing line is tested first.
        read_until(f, "terminateFGForUnit Unit", listing)

        # Discard the NotReadNotWritten line, then convert R, W and RW in order.
        f.readline()
        read_not_written = convert_tap_result2genlib_format(f.readline())
        not_read_then_written = convert_tap_result2genlib_format(f.readline())
        read_then_written = convert_tap_result2genlib_format(f.readline())

    return read_not_written, not_read_then_written, read_then_written, var2idx_mapping


elif flat_mode:
# Flat mode with organized subdirectories: src/, test/, include/
src_dir = out_dir / 'src'
test_dir = out_dir / 'test'
Expand Down Expand Up @@ -8188,7 +8272,7 @@ def run_task(task):
mode_dirs['bv'].mkdir(parents=True, exist_ok=True)

# Update log path to be in the function subdirectory
func_log_path = func_out_dir / (src.stem + ".tapenade.log")
# func_log_path = func_out_dir / (src.stem + ".tapenade.log") # ISNT THIS COMPLETELY USELESS??
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Verify need for this variable


# Find dependency files
called_functions = parse_function_calls(src)
Expand Down Expand Up @@ -8240,11 +8324,14 @@ def run_task(task):
for dep_file in main_file_removed:
cmd.append(str(dep_file))
cmd.extend(list(args.extra))
if (args.genlib):
cmd = cmd + ["-traceinout", src.stem]

try:
with open(mode_log_path, "w") as logf:
logf.write(f"Mode: FORWARD (scalar)\n")
# Format command for logging (properly quoted for shell copy-paste)
print("CMD:", cmd)
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
logf.write(f"Command: {cmd_str}\n")
logf.write(f"Function: {func_name}\n")
Expand Down Expand Up @@ -8285,6 +8372,44 @@ def run_task(task):
pass
print(f" ERROR: Exception during forward mode execution: {e}", file=sys.stderr)
return_codes["forward"] = 999

if (args.genlib) : # Everything went well, and we are trying to generate the external lib
read_not_written, not_read_then_written, read_then_written, var2idx = parse_tap_trace4inout(mode_log_path)
if func_type == 'FUNCTION':
param_for_genlib = all_params + [src.stem]
else:
param_for_genlib = all_params
param_2_tap_reordering = [var2idx[p.lower()] for p in param_for_genlib]
with open("DiffBlasGenLib", "a") as f:
f.write(("function " if func_type == 'FUNCTION' else "subroutine ") + src.stem + ":\n")
indent = " "
f.write(indent + "external:\n")
shape = "(" + ", ".join(["param " + str(i) for i in range(1,len(all_params)+1)] + ["result" if func_type == 'FUNCTION' else ""]) + ")" ## TODO: Need to add ', return' in case of a function,. dpeending on whether it is within the all params or not
f.write(indent + "shape: " + shape + "\n")
types = []
for p in param_for_genlib:
current_type = ""
if p.upper() in param_types['real_vars'] or p.lower() in param_types['real_vars']:
current_type = "metavar float" # We should probably be more precise in order to handle mixed precision things
# Namely, adapt to
# modifiedType(modifiers(ident double), float() for double / REAL*8
# float() for single precision
elif p.upper() in param_types['complex_vars'] or p.lower() in param_types['complex_vars']:
current_type = "metavar complex"
# Similar to the real variables, we should be able to be more precise in terms of precision of the complex variable
elif p.upper() in param_types['integer_vars'] or p.lower() in param_types['integer_vars']:
current_type = "metavar integer"
elif p.upper() in param_types['char_vars'] or p.lower() in param_types['char_vars']:
current_type = "character()"
if p.upper() in param_types['array_vars'] or p.lower() in param_types['array_vars']:
current_type = "arrayType(" + current_type + ", dimColons())"
types.append(current_type)
types = "(" + ", ".join(types) + ")"
f.write(indent + "type: " + types + "\n")
f.write(indent + "ReadNotWritten: (" + ", ".join([read_not_written[i] for i in param_2_tap_reordering]) + ")\n")
f.write(indent + "NotReadThenWritten: (" + ", ".join([not_read_then_written[i] for i in param_2_tap_reordering]) + ")\n")
f.write(indent + "ReadThenWritten: (" + ", ".join([read_then_written[i] for i in param_2_tap_reordering]) + ")\n")
f.write("\n")

# Run scalar reverse mode (b)
if run_b:
Expand Down Expand Up @@ -8383,6 +8508,7 @@ def run_task(task):
try:
with open(mode_log_path, "w") as logf:
logf.write(f"Mode: FORWARD VECTOR\n")
print("CMD:", cmd)
# Format command for logging (properly quoted for shell copy-paste)
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
logf.write(f"Command: {cmd_str}\n")
Expand Down Expand Up @@ -8449,6 +8575,7 @@ def run_task(task):
try:
with open(mode_log_path, "w") as logf:
logf.write(f"Mode: REVERSE VECTOR\n")
print("CMD:", cmd)
# Format command for logging (properly quoted for shell copy-paste)
cmd_str = ' '.join(shlex.quote(str(arg)) for arg in cmd)
logf.write(f"Command: {cmd_str}\n")
Expand Down Expand Up @@ -8677,8 +8804,8 @@ def run_task(task):

# Return the worst return code (non-zero if any mode failed)
final_rc = max(return_codes.values()) if return_codes else 999
return (src, final_rc)

return (src, final_rc)
# Serial or parallel execution
results = []
if args.jobs <= 1:
Expand Down Expand Up @@ -8774,6 +8901,9 @@ def run_task(task):
if "reverse" in args.mode or args.mode == "both":
print(" make vector-reverse # Build vector reverse mode only")
print(" ./test_<function>_vector_forward # Run vector forward mode test")

if args.genlib:
shrm(Path(GENLIBTMP))

def generate_top_level_makefile(out_dir, flat_mode=False):
"""Generate the top-level Makefile for building all subdirectories or flat makefiles"""
Expand Down
Loading