diff --git a/examples/convert_xgboost.py b/examples/convert_xgboost.py index caa21bb..7d39b50 100644 --- a/examples/convert_xgboost.py +++ b/examples/convert_xgboost.py @@ -1,17 +1,14 @@ """ Before use this file, install xgboost first -""" -from __future__ import print_function -import sys -import xgboost as xgb +""" import os import struct -from ctypes import cdll -from ctypes import c_float, c_uint, c_char_p, c_bool +import sys + +import xgboost as xgb -LIB_PATH = "./libgbdt.so" -def convert(input_model, objective, output_file): +def convert(input_model, output_file): model = xgb.Booster() model.load_model(input_model) tmp_file = output_file + ".gbdt_rs.mid" @@ -27,11 +24,11 @@ def convert(input_model, objective, output_file): base_score = struct.unpack('f',f.read(4))[0] except Exception as e: print("error: ", e) - return 1 + return if os.path.exists(tmp_file): print("Intermediate file %s exists. Please remove this file or change your output file path" % tmp_file) - return 1 + return # dump json model.dump_model(tmp_file, dump_format="json") @@ -46,19 +43,16 @@ def convert(input_model, objective, output_file): except Exception as e: print("error: ", e) os.remove(tmp_file) - return 1 + return os.remove(tmp_file) - return 0 if __name__ == "__main__": - if len(sys.argv) != 4: - print("usage: python script input_model_path objective output_file_path") - print("supported booster: gbtree") - print("supported objective: 'reg:linear', 'binary:logistic', 'reg:logistic'," + \ - "'binary:logitraw', 'multi:softmax', 'multi:softprob', 'rank:pairwise'") + if len(sys.argv) != 3: + print("usage: python convert_xgboost.py ") + print("--supported booster: gbtree") + print("--supported objective: 'reg:linear', 'binary:logistic', 'reg:logistic', " + "'binary:logitraw', 'multi:softmax', 'multi:softprob', 'rank:pairwise'") exit(1) - convert(sys.argv[1], sys.argv[2], sys.argv[3]) - - + convert(sys.argv[1], sys.argv[2])