diff --git a/gpkit/solution_array.py b/gpkit/solution_array.py index d3f511fe..b115ee4d 100644 --- a/gpkit/solution_array.py +++ b/gpkit/solution_array.py @@ -5,6 +5,7 @@ import warnings as pywarnings import pickle import gzip +import json import pickletools import numpy as np from .nomials import NomialArray @@ -320,6 +321,57 @@ def cast(function, val1, val2): return function(val1, val2) +def diff_retrieval(self, other, sharedvks, showvars=None, *, jsondiff=False, + senssdiff=False, absdiff=False, reldiff=False): + """A helper function for generalized diff method + - retreives svars and ovars, + """ + svars = self["variables"] + ovars = other["variables"] + # get the type of diffs + diff_dict = {} + if jsondiff == False: + if reldiff: + rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1) + for vk in sharedvks} + diff_dict['rel'] = rel_diff + if absdiff: + abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks} + diff_dict['abs'] = abs_diff + if senssdiff: + ssenss = self["sensitivities"]["variables"] + osenss = other["sensitivities"]["variables"] + senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk]) + for vk in sharedvks} + diff_dict['sens'] = senss_delta + else: + if reldiff: + rel_diff = {} + for vk in sharedvks: + val = 100*(cast(np.divide, svars[vk], ovars[vk]) - 1) + if isinstance(val, np.ndarray): + val = val.tolist() + rel_diff[str(vk)] = val + diff_dict['rel'] = rel_diff + if absdiff: + abs_diff = {} + for vk in sharedvks: + val = cast(sub, svars[vk], ovars[vk]) + if isinstance(val, np.ndarray): + val = val.tolist() + abs_diff[str(vk)] = val + diff_dict['abs'] = abs_diff + if senssdiff: + sense_delta = {} + ssenss = self["sensitivities"]["variables"] + osenss = other["sensitivities"]["variables"] + for vk in sharedvks: + val = cast(sub, ssenss[vk], osenss[vk]) + sense_delta[str(vk)] = val + diff_dict['sens'] = senss_delta + return diff_dict + + class SolutionArray(DictOfLists): """A dictionary (of dictionaries) of lists, with convenience methods. @@ -398,8 +450,8 @@ def almost_equal(self, other, reltol=1e-3, sens_abstol=0.01): return True # pylint: disable=too-many-locals, too-many-branches, too-many-statements - def diff(self, other, showvars=None, *, - constraintsdiff=True, senssdiff=False, sensstol=0.1, + def diff(self, other, showvars=None, *, jsondiff=False, filename="solution.json" + ,constraintsdiff=True, senssdiff=False, sensstol=0.1, absdiff=False, abstol=0.1, reldiff=True, reltol=1.0, sortmodelsbysenss=True, **tableargs): """Outputs differences between this solution and another @@ -469,9 +521,18 @@ def diff(self, other, showvars=None, *, lines.append("\n".join(" %s" % key for key in ovks - svks)) lines.append("") sharedvks = svks.intersection(ovks) + + # retrieve diff data + diff_dict = diff_retrieval(self, other, sharedvks, showvars=showvars, + jsondiff=jsondiff, senssdiff=senssdiff, + absdiff=absdiff, reldiff=reldiff) + if jsondiff: + with open(filename, "w") as f: + json.dump(diff_dict, f) + return diff_dict + if reldiff: - rel_diff = {vk: 100*(cast(np.divide, svars[vk], ovars[vk]) - 1) - for vk in sharedvks} + rel_diff = diff_dict['rel'] lines += var_table(rel_diff, "Relative Differences |above %g%%|" % reltol, valfmt="%+.1f%% ", vecfmt="%+6.1f%% ", @@ -480,7 +541,7 @@ def diff(self, other, showvars=None, *, lines.insert(-1, ("The largest is %+g%%." % unrolled_absmax(rel_diff.values()))) if absdiff: - abs_diff = {vk: cast(sub, svars[vk], ovars[vk]) for vk in sharedvks} + abs_diff = diff_dict['abs'] lines += var_table(abs_diff, "Absolute Differences |above %g|" % abstol, valfmt="%+.2g", vecfmt="%+8.2g", @@ -489,10 +550,7 @@ def diff(self, other, showvars=None, *, lines.insert(-1, ("The largest is %+g." % unrolled_absmax(abs_diff.values()))) if senssdiff: - ssenss = self["sensitivities"]["variables"] - osenss = other["sensitivities"]["variables"] - senss_delta = {vk: cast(sub, ssenss[vk], osenss[vk]) - for vk in svks.intersection(ovks)} + senss_delta = diff_dict['sens'] lines += var_table(senss_delta, "Sensitivity Differences |above %g|" % sensstol, valfmt="%+-.2f ", vecfmt="%+-6.2f",