Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 14 additions & 20 deletions examples/convert_xgboost.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,14 @@
"""
Before use this file, install xgboost first
"""
from __future__ import print_function
import sys
import xgboost as xgb
"""
import os
import struct
from ctypes import cdll
from ctypes import c_float, c_uint, c_char_p, c_bool
import sys

import xgboost as xgb

LIB_PATH = "./libgbdt.so"

def convert(input_model, objective, output_file):
def convert(input_model, output_file):
model = xgb.Booster()
model.load_model(input_model)
tmp_file = output_file + ".gbdt_rs.mid"
Expand All @@ -27,11 +24,11 @@ def convert(input_model, objective, output_file):
base_score = struct.unpack('f',f.read(4))[0]
except Exception as e:
print("error: ", e)
return 1
return

if os.path.exists(tmp_file):
print("Intermediate file %s exists. Please remove this file or change your output file path" % tmp_file)
return 1
return

# dump json
model.dump_model(tmp_file, dump_format="json")
Expand All @@ -46,19 +43,16 @@ def convert(input_model, objective, output_file):
except Exception as e:
print("error: ", e)
os.remove(tmp_file)
return 1
return

os.remove(tmp_file)
return 0


if __name__ == "__main__":
if len(sys.argv) != 4:
print("usage: python script input_model_path objective output_file_path")
print("supported booster: gbtree")
print("supported objective: 'reg:linear', 'binary:logistic', 'reg:logistic'," + \
"'binary:logitraw', 'multi:softmax', 'multi:softprob', 'rank:pairwise'")
if len(sys.argv) != 3:
print("usage: python convert_xgboost.py <input_xgboost_model_path> <output_gbdt_model_path>")
print("--supported booster: gbtree")
print("--supported objective: 'reg:linear', 'binary:logistic', 'reg:logistic', "
"'binary:logitraw', 'multi:softmax', 'multi:softprob', 'rank:pairwise'")
exit(1)
convert(sys.argv[1], sys.argv[2], sys.argv[3])


convert(sys.argv[1], sys.argv[2])