Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 67 additions & 18 deletions wrap.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,13 +55,15 @@
mpicc = 'mpicc' # Default name for the MPI compiler
includes = [] # Default set of directories to inlucde when parsing mpi.h
pmpi_init_binding = "pmpi_init_" # Default binding for pmpi_init
pmpi_init_thread_binding = "pmpi_init_thread_" # Default binding for pmpi_init
output_fortran_wrappers = False # Don't print fortran wrappers by default
output_guards = False # Don't print reentry guards by default
skip_headers = False # Skip header information and defines (for non-C output)
dump_prototypes = False # Just exit and dump MPI protos if false.

# Possible legal bindings for the fortran version of PMPI_Init()
pmpi_init_bindings = ["PMPI_INIT", "pmpi_init", "pmpi_init_", "pmpi_init__"]
pmpi_init_thread_bindings = ["PMPI_INIT_THREAD", "pmpi_init_thread", "pmpi_init_thread_", "pmpi_init_thread__"]

# Possible function return types to consider, used for declaration parser.
# In general, all MPI calls we care about return int. We include double
Expand Down Expand Up @@ -120,12 +122,20 @@
#pragma weak PMPI_INIT
#pragma weak pmpi_init_
#pragma weak pmpi_init__
#pragma weak pmpi_init_thread
#pragma weak PMPI_INIT_THREAD
#pragma weak pmpi_init_thread_
#pragma weak pmpi_init_thread__
#endif /* PIC */

_EXTERN_C_ void pmpi_init(MPI_Fint *ierr);
_EXTERN_C_ void PMPI_INIT(MPI_Fint *ierr);
_EXTERN_C_ void pmpi_init_(MPI_Fint *ierr);
_EXTERN_C_ void pmpi_init__(MPI_Fint *ierr);
_EXTERN_C_ void pmpi_init_thread(MPI_Fint *required, MPI_Fint *provided, MPI_Fint *ierr);
_EXTERN_C_ void PMPI_INIT_THREAD(MPI_Fint *required, MPI_Fint *provided, MPI_Fint *ierr);
_EXTERN_C_ void pmpi_init_thread_(MPI_Fint *required, MPI_Fint *provided, MPI_Fint *ierr);
_EXTERN_C_ void pmpi_init_thread__(MPI_Fint *required, MPI_Fint *provided, MPI_Fint *ierr);

'''

Expand Down Expand Up @@ -466,6 +476,7 @@ def getArgName(self, index):
def fortranFormals(self):
formals = map(Param.fortranFormal, self.argsNoEllipsis())
if self.name == "MPI_Init": formals = [] # Special case for init: no args in fortran
if self.name == "MPI_Init_thread": del formals[0:2] # Special case for init

ierr = []
if self.returnsErrorCode(): ierr = ["MPI_Fint *ierr"]
Expand All @@ -474,6 +485,7 @@ def fortranFormals(self):
def fortranArgNames(self):
names = self.argNames()
if self.name == "MPI_Init": names = []
if self.name == "MPI_Init_thread": del names[0:2]

ierr = []
if self.returnsErrorCode(): ierr = ["ierr"]
Expand Down Expand Up @@ -703,26 +715,30 @@ def write_fortran_wrappers(out, decl, return_val):
out.write(" { \n")

call = FortranDelegation(decl, return_val)

if decl.name == "MPI_Init":

start_arg_at = 0
if decl.name == "MPI_Init" or decl.name == "MPI_Init_thread":
# Use out.write() here so it comes at very beginning of wrapper function
out.write(" int argc = 0;\n");
out.write(" char ** argv = NULL;\n");
call.addActual("&argc");
call.addActual("&argv");
call.write(out)
out.write(" *ierr = %s;\n" % return_val)
out.write("}\n\n")

# Write out various bindings that delegate to the main fortran wrapper
write_fortran_binding(out, decl, delegate_name, "MPI_INIT", ["fortran_init = 1;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init", ["fortran_init = 2;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init_", ["fortran_init = 3;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init__", ["fortran_init = 4;"])
return
if decl.name == "MPI_Init":
call.write(out)
out.write(" *ierr = %s;\n" % return_val)
out.write("}\n\n")
write_fortran_binding(out, decl, delegate_name, "MPI_INIT", ["fortran_init = 1;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init", ["fortran_init = 2;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init_", ["fortran_init = 3;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init__", ["fortran_init = 4;"])
return
else:
start_arg_at = 2

# This look processes the rest of the call for all other routines.
for arg in decl.args:
for arg in decl.args[start_arg_at:]:
if arg.name == "...": # skip ellipsis
continue

Expand Down Expand Up @@ -790,12 +806,18 @@ def write_fortran_wrappers(out, decl, return_val):
else:
out.write(" return %s;\n" % return_val)
out.write("}\n\n")

# Write out various bindings that delegate to the main fortran wrapper
write_fortran_binding(out, decl, delegate_name, decl.name.upper())
write_fortran_binding(out, decl, delegate_name, decl.name.lower())
write_fortran_binding(out, decl, delegate_name, decl.name.lower() + "_")
write_fortran_binding(out, decl, delegate_name, decl.name.lower() + "__")

if decl.name == "MPI_Init_thread":
write_fortran_binding(out, decl, delegate_name, "MPI_INIT_THREAD", ["fortran_init = 1;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init_thread", ["fortran_init = 2;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init_thread_", ["fortran_init = 3;"])
write_fortran_binding(out, decl, delegate_name, "mpi_init_thread__", ["fortran_init = 4;"])
else:
# Write out various bindings that delegate to the main fortran wrapper
write_fortran_binding(out, decl, delegate_name, decl.name.upper())
write_fortran_binding(out, decl, delegate_name, decl.name.lower())
write_fortran_binding(out, decl, delegate_name, decl.name.lower() + "_")
write_fortran_binding(out, decl, delegate_name, decl.name.lower() + "__")


################################################################################
Expand Down Expand Up @@ -928,7 +950,7 @@ def callfn(out, scope, args, children):
out.write(" if (!PMPI_INIT && !pmpi_init && !pmpi_init_ && !pmpi_init__) {\n")
out.write(" fprintf(stderr, \"ERROR: Couldn't find fortran pmpi_init function. Link against static library instead.\\n\");\n")
out.write(" exit(1);\n")
out.write(" }")
out.write(" }\n")
out.write(" switch (fortran_init) {\n")
out.write(" case 1: PMPI_INIT(&%s); break;\n" % return_val)
out.write(" case 2: pmpi_init(&%s); break;\n" % return_val)
Expand All @@ -951,6 +973,32 @@ def write_fortran_init_flag():
output.write("static int fortran_init = 0;\n")
once(write_fortran_init_flag)

elif fn_name == "MPI_Init_thread" and output_fortran_wrappers:
def callfn(out, scope, args, children):
out.write(" if (fortran_init) {\n")
out.write("#ifdef PIC\n")
out.write(" if (!PMPI_INIT_THREAD && !pmpi_init_thread && !pmpi_init_thread_ && !pmpi_init_thread__) {\n")
out.write(" fprintf(stderr, \"ERROR: Couldn't find fortran pmpi_init_thread function. Link against static library instead.\\n\");\n")
out.write(" exit(1);\n")
out.write(" }\n")
out.write(" switch (fortran_init) {\n")
out.write(" case 2: PMPI_INIT_THREAD(&%s,%s,&%s); break;\n" % (fn.args[2].name,fn.args[3].name,return_val))
out.write(" case 2: pmpi_init_thread(&%s,%s,&%s); break;\n" % (fn.args[2].name,fn.args[3].name,return_val))
out.write(" case 3: pmpi_init_thread_(&%s,%s,&%s); break;\n" % (fn.args[2].name,fn.args[3].name,return_val))
out.write(" case 4: pmpi_init_thread__(&%s,%s,&%s); break;\n" % (fn.args[2].name,fn.args[3].name,return_val))
out.write(" default:\n")
out.write(" fprintf(stderr, \"NO SUITABLE FORTRAN MPI_INIT BINDING\\n\");\n")
out.write(" break;\n")
out.write(" }\n")
out.write("#else /* !PIC */\n")
out.write(" %s(&%s,%s,&%s);\n" % (pmpi_init_thread_binding, fn.args[2].name,fn.args[3].name,return_val))
out.write("#endif /* !PIC */\n")
out.write(" } else {\n")
out.write(" %s\n" % c_call)
out.write(" }\n")

fn_scope["callfn"] = callfn

else:
fn_scope["callfn"] = c_call

Expand Down Expand Up @@ -1262,6 +1310,7 @@ def usage():
usage()
else:
pmpi_init_binding = arg
pmpi_init_thread_binding = pmpi_init_thread_bindings[pmpi_init_bindings.index(arg)]

if len(args) < 1 and not dump_prototypes:
usage()
Expand Down