diff --git a/src/cpp.cc b/src/cpp.cc old mode 100644 new mode 100755 index 099438c5f..f53932a0d --- a/src/cpp.cc +++ b/src/cpp.cc @@ -3346,8 +3346,6 @@ void Printer::Cpp::pytorch_makefile(const Options &opts, const AST &ast) { stream << "CXXFLAGS += -DMALLOC_BATCH" << endl; stream << "endif" << endl << endl; } - // pytorch C++ extension doesn't support C++17 yet - stream << "CXXFLAGS := $(CXXFLAGS) | sed 's/c++17/c++14/'" << endl; stream << "compiler_args := $(CXXFLAGS) | sed -e 's/\\s/\", \"/g'" << endl; stream << "LDFLAGS := $(LDFLAGS) | sed 's/-L\\//\\//g'" << endl; @@ -3357,7 +3355,7 @@ void Printer::Cpp::pytorch_makefile(const Options &opts, const AST &ast) { << "sed -e 's/\\s/\", \"-Xlinker\", \"-rpath\", \"-Xlinker\", \"/g'" << endl; stream << "library_dirs := $(LDFLAGS) | sed -e 's/\\s/\", \"/g'" << endl; - stream << "LDLIBS := $(LDLIBS) | sed 's/-l//g'" << endl; + stream << "LDLIBS := $(LDLIBS) | sed -E 's/(^|[[:space:]])-l/\\1/g'" << endl; stream << "libraries := $(LDLIBS) | sed -e 's/\\s/\", \"/g'" << endl; stream << "CPPFLAGS := $(CPPFLAGS) | sed 's/-I\\//\\//g'" << endl; stream << "include_dirs := $(CPPFLAGS) | sed -e 's/\\s/\", \"/g'" @@ -3372,7 +3370,8 @@ void Printer::Cpp::pytorch_makefile(const Options &opts, const AST &ast) { stream << "setup.py:" << endl; stream << "\t@echo 'from setuptools import setup' > $@" << endl; stream << "\t@echo -e 'from torch.utils.cpp_extension import " - << "BuildExtension, CppExtension\\n' >> $@" << endl; + << "BuildExtension, CppExtension, include_paths, library_paths\\n' >> $@" << endl; + stream << "\t@echo 'import torch' >> $@" << endl; stream << "\t@echo -n 'compiler_args = [\"' >> $@" << endl; stream << "\t@echo -n $(compiler_args) >> $@" << endl; stream << "\t@echo -e '\"]\\n' >> $@" << endl; @@ -3399,7 +3398,7 @@ void Printer::Cpp::pytorch_makefile(const Options &opts, const AST &ast) { << "[\"pytorch_interface.cc\", ' >> $@" << endl; stream << "\t@echo '" << out_files << "],' >> $@" << endl; stream << "\t@echo -e ' include_dirs = " - << "include_dirs,' >> $@" << endl; + << "include_dirs + include_paths(),' >> $@" << endl; stream << "\t@echo -e ' extra_compile_args = " << "compiler_args,' >> $@" << endl; stream << "\t@echo -e ' extra_link_args = " @@ -3407,7 +3406,7 @@ void Printer::Cpp::pytorch_makefile(const Options &opts, const AST &ast) { stream << "\t@echo -e ' libraries = " << "libraries,' >> $@" << endl; stream << "\t@echo -e ' library_dirs = " - << "library_dirs)\\n' >> $@" << endl; + << "library_dirs + library_paths())\\n' >> $@" << endl; stream << "\t@echo 'setup(name = \"" << opts.pytorch_module_mame << "\",' >> $@" << endl; stream << "\t@echo ' version = \"0.0.1\",' >> $@" << endl;