Skip to content

Commit d0ba53e

Browse files
authored
Dev (#302)
* formatting * use cache * skip test_dcnt_r01 * remove separator * Add numpy as np
1 parent 2b47d8e commit d0ba53e

File tree

3 files changed

+89
-41
lines changed

3 files changed

+89
-41
lines changed

goatools/godag/go_tasks.py

Lines changed: 59 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,20 @@
11
"""item-DAG tasks."""
22

3-
__copyright__ = "Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
3+
__copyright__ = (
4+
"Copyright (C) 2010-present, DV Klopfenstein, H Tang, All rights reserved."
5+
)
46
__author__ = "DV Klopfenstein"
57

6-
from goatools.godag.consts import RELATIONSHIP_SET
8+
from ..godag.consts import RELATIONSHIP_SET
79

810

9-
# ------------------------------------------------------------------------------------
1011
def get_go2parents(go2obj, relationships):
1112
"""Get set of parents GO IDs, including parents through user-specified relationships"""
12-
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
13+
if (
14+
go2obj
15+
and not hasattr(next(iter(go2obj.values())), "relationship")
16+
or not relationships
17+
):
1318
return get_go2parents_isa(go2obj)
1419
go2parents = {}
1520
for goid_main, goterm in go2obj.items():
@@ -21,10 +26,14 @@ def get_go2parents(go2obj, relationships):
2126
go2parents[goid_main] = parents_goids
2227
return go2parents
2328

24-
# ------------------------------------------------------------------------------------
29+
2530
def get_go2children(go2obj, relationships):
2631
"""Get set of children GO IDs, including children through user-specified relationships"""
27-
if go2obj and not hasattr(next(iter(go2obj.values())), 'relationship') or not relationships:
32+
if (
33+
go2obj
34+
and not hasattr(next(iter(go2obj.values())), "relationship")
35+
or not relationships
36+
):
2837
return get_go2children_isa(go2obj)
2938
go2children = {}
3039
for goid_main, goterm in go2obj.items():
@@ -36,7 +45,7 @@ def get_go2children(go2obj, relationships):
3645
go2children[goid_main] = children_goids
3746
return go2children
3847

39-
# ------------------------------------------------------------------------------------
48+
4049
def get_go2parents_isa(go2obj):
4150
"""Get set of immediate parents GO IDs"""
4251
go2parents = {}
@@ -46,7 +55,7 @@ def get_go2parents_isa(go2obj):
4655
go2parents[goid_main] = parents_goids
4756
return go2parents
4857

49-
# ------------------------------------------------------------------------------------
58+
5059
def get_go2children_isa(go2obj):
5160
"""Get set of immediate children GO IDs"""
5261
go2children = {}
@@ -56,84 +65,96 @@ def get_go2children_isa(go2obj):
5665
go2children[goid_main] = children_goids
5766
return go2children
5867

59-
# ------------------------------------------------------------------------------------
68+
6069
def get_go2ancestors(terms, relationships, prt=None):
    """Return the ID-to-ancestors mapping for ``terms``.

    The traversal is chosen from ``relationships``:
      * falsy -> ``get_id2parents`` (is_a only)
      * equal to ``RELATIONSHIP_SET`` or ``True`` -> ``get_id2upper``
      * anything else -> ``get_id2upperselect`` with those relationships

    When ``prt`` is not None, one line describing the chosen traversal is
    written to it before the mapping is computed.
    """
    # Decide which relationship names (if any) are followed.
    if not relationships:
        rel_names = None
    elif relationships == RELATIONSHIP_SET or relationships is True:
        rel_names = RELATIONSHIP_SET
    else:
        rel_names = relationships

    # Optional one-line report; relationship names are only sorted when reporting.
    if prt is not None:
        if rel_names is None:
            prt.write("up: is_a\n")
        else:
            prt.write("up: is_a and {Rs}\n".format(Rs=" ".join(sorted(rel_names))))

    if rel_names is None:
        return get_id2parents(terms)
    if rel_names is RELATIONSHIP_SET:
        return get_id2upper(terms)
    return get_id2upperselect(terms, relationships)
7584

85+
7686
def get_go2descendants(terms, relationships, prt=None):
    """Return the ID-to-descendants mapping for ``terms``.

    The traversal is chosen from ``relationships``:
      * falsy -> ``get_id2children`` (is_a only)
      * equal to ``RELATIONSHIP_SET`` or ``True`` -> ``get_id2lower``
      * anything else -> ``get_id2lowerselect`` with those relationships

    When ``prt`` is not None, one line describing the chosen traversal is
    written to it before the mapping is computed.
    """
    # Decide which relationship names (if any) are followed.
    if not relationships:
        rel_names = None
    elif relationships == RELATIONSHIP_SET or relationships is True:
        rel_names = RELATIONSHIP_SET
    else:
        rel_names = relationships

    # Optional one-line report; relationship names are only sorted when reporting.
    if prt is not None:
        if rel_names is None:
            prt.write("down: is_a\n")
        else:
            prt.write("down: is_a and {Rs}\n".format(Rs=" ".join(sorted(rel_names))))

    if rel_names is None:
        return get_id2children(terms)
    if rel_names is RELATIONSHIP_SET:
        return get_id2lower(terms)
    return get_id2lowerselect(terms, relationships)
91101

92-
# ------------------------------------------------------------------------------------
102+
93103
def get_go2depth(goobjs, relationships):
    """Return a dict mapping each object's ``item_id`` to its depth.

    With falsy ``relationships`` the objects' own ``depth`` attribute is used;
    otherwise the work is delegated to ``get_go2reldepth``.
    """
    if relationships:
        # Imported here, as in the original, rather than at module level.
        from goatools.godag.reldepth import get_go2reldepth

        return get_go2reldepth(goobjs, relationships)
    return dict((term.item_id, term.depth) for term in goobjs)
99110

100-
# ------------------------------------------------------------------------------------
111+
101112
def get_id2parents(objs):
    """Map item IDs to all parent IDs up the hierarchy.

    Every object in ``objs`` is passed to ``_get_id2parents``, which fills a
    shared dict. Entries whose parent collection is empty are dropped from
    the returned dict.
    """
    collected = {}
    for term in objs:
        _get_id2parents(collected, term.item_id, term)
    # Keep only IDs that actually have parents.
    return dict((key, ids) for key, ids in collected.items() if ids)
118+
107119

108120
def get_id2children(objs):
    """Map item IDs to all child IDs down the hierarchy.

    Every object in ``objs`` is passed to ``_get_id2children``, which fills a
    shared dict. Entries whose child collection is empty are dropped from
    the returned dict.
    """
    collected = {}
    for term in objs:
        _get_id2children(collected, term.item_id, term)
    # Keep only IDs that actually have children.
    return dict((key, ids) for key, ids in collected.items() if ids)
126+
114127

115128
def get_id2upper(objs):
    """Map item IDs to all ancestor IDs (parents plus IDs up all relationships).

    Every object in ``objs`` is passed to ``_get_id2upper``, which fills a
    shared dict. Entries whose ancestor collection is empty are dropped from
    the returned dict.
    """
    collected = {}
    for term in objs:
        _get_id2upper(collected, term.item_id, term)
    # Keep only IDs that actually have something above them.
    return dict((key, ids) for key, ids in collected.items() if ids)
134+
121135

122136
def get_id2lower(objs):
    """Map item IDs to all descendant IDs (children plus IDs down all relationships).

    A dict of results and a set of already-visited item IDs are shared with
    ``_get_id2lower`` across all objects. An object whose ID is already in
    the visited set is not traversed again. Entries whose descendant
    collection is empty are dropped from the returned dict.
    """
    collected = {}
    visited = set()
    for term in objs:
        if term.item_id not in visited:
            _get_id2lower(collected, term.item_id, term, visited)
    # Keep only IDs that actually have something below them.
    return dict((key, ids) for key, ids in collected.items() if ids)
146+
128147

129148
def get_id2upperselect(objs, relationship_set):
    """Return the ``id2upperselect`` mapping built by ``IdToUpperSelect``.

    Covers all parents plus IDs reached up the selected relationships.
    """
    selector = IdToUpperSelect(objs, relationship_set)
    return selector.id2upperselect
132151

152+
133153
def get_id2lowerselect(objs, relationship_set):
    """Return the ``id2lowerselect`` mapping built by ``IdToLowerSelect``.

    Covers all children plus IDs reached down the selected relationships.
    """
    selector = IdToLowerSelect(objs, relationship_set)
    return selector.id2lowerselect
136156

157+
137158
def get_relationship_targets(item_ids, relationships, id2rec):
138159
"""Get item ID set of item IDs in a relationship target set"""
139160
# Requirements to use this function:
@@ -148,7 +169,7 @@ def get_relationship_targets(item_ids, relationships, id2rec):
148169
reltgt_objs_all.update(reltgt_objs_cur)
149170
return reltgt_objs_all
150171

151-
# ------------------------------------------------------------------------------------
172+
152173
# pylint: disable=too-few-public-methods
153174
class IdToUpperSelect:
154175
"""Get all ancestor IDs, including all parents and IDs up selected relationships"""
@@ -178,6 +199,7 @@ def _get_id2upperselect(self, item_id, item_obj):
178199
id2upperselect[item_id] = parent_ids
179200
return parent_ids
180201

202+
181203
class IdToLowerSelect:
182204
"""Get all descendant IDs, including all children and IDs down selected relationships"""
183205

@@ -206,7 +228,6 @@ def _get_id2lowerselect(self, item_id, item_obj):
206228
id2lowerselect[item_id] = child_ids
207229
return child_ids
208230

209-
# ------------------------------------------------------------------------------------
210231

211232
def _get_id2parents(id2parents, item_id, item_obj):
212233
"""Add the parent item IDs for one item object and their parents."""
@@ -220,6 +241,7 @@ def _get_id2parents(id2parents, item_id, item_obj):
220241
id2parents[item_id] = parent_ids
221242
return parent_ids
222243

244+
223245
def _get_id2children(id2children, item_id, item_obj):
224246
"""Add the child item IDs for one item object and their children."""
225247
if item_id in id2children:
@@ -232,6 +254,7 @@ def _get_id2children(id2children, item_id, item_obj):
232254
id2children[item_id] = child_ids
233255
return child_ids
234256

257+
235258
def _get_id2upper(id2upper, item_id, item_obj):
236259
"""Add the parent item IDs for one item object and their upper."""
237260
if item_id in id2upper:
@@ -244,19 +267,23 @@ def _get_id2upper(id2upper, item_id, item_obj):
244267
id2upper[item_id] = upper_ids
245268
return upper_ids
246269

247-
def _get_id2lower(id2lower, item_id, item_obj, cache: set):
    """Add the lower item IDs for one item object and the objects below them.

    Args:
        id2lower: dict of item ID -> set of all lower IDs; acts as the memo
            of *finished* items and is filled in place.
        item_id: ID of ``item_obj``.
        item_obj: object providing ``get_goterms_lower()``, whose elements
            each expose ``item_id``.
        cache: set of item IDs whose traversal has *started*; updated in
            place. It stops infinite recursion when the input is not a DAG.

    Returns:
        The set of all IDs below ``item_id`` (also stored in ``id2lower``).
    """
    if item_id in id2lower:
        return id2lower[item_id]
    lower_ids = set()
    cache.add(item_id)
    for lower_obj in item_obj.get_goterms_lower():
        lower_id = lower_obj.item_id
        lower_ids.add(lower_id)
        if lower_id in id2lower:
            # Child already finished via another parent: reuse its complete
            # descendant set. Skipping here would lose those descendants for
            # every parent after the first one in a diamond-shaped graph.
            lower_ids |= id2lower[lower_id]
        elif lower_id not in cache:
            lower_ids |= _get_id2lower(id2lower, lower_id, lower_obj, cache)
        # else: child is started but not finished, i.e. we are inside a
        # cycle; do not recurse so the traversal terminates.
    id2lower[item_id] = lower_ids
    return lower_ids
258285

259-
# ------------------------------------------------------------------------------------
286+
260287
class CurNHigher:
261288
"""Fill id2obj with item IDs in relationships."""
262289

goatools/nt_utils.py

Lines changed: 27 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,11 +7,13 @@
77
import datetime
88
import collections as cx
99

10+
1011
def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
1112
"""Return a new dict of namedtuples by combining "dicts" of namedtuples or objects."""
1213
assert len(ids) == len(set(ids)), "NOT ALL IDs ARE UNIQUE: {IDs}".format(IDs=ids)
1314
assert len(flds) == len(set(flds)), "DUPLICATE FIELDS: {IDs}".format(
14-
IDs=cx.Counter(flds).most_common())
15+
IDs=cx.Counter(flds).most_common()
16+
)
1517
usr_id_nt = []
1618
# 1. Instantiate namedtuple object
1719
ntobj = cx.namedtuple("Nt", " ".join(flds))
@@ -23,6 +25,7 @@ def get_dict_w_id2nts(ids, id2nts, flds, dflt_null=""):
2325
usr_id_nt.append((item_id, ntobj._make(vals)))
2426
return cx.OrderedDict(usr_id_nt)
2527

28+
2629
def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
2730
"""Return a new list of namedtuples by combining "dicts" of namedtuples or objects."""
2831
combined_nt_list = []
@@ -36,48 +39,61 @@ def get_list_w_id2nts(ids, id2nts, flds, dflt_null=""):
3639
combined_nt_list.append(ntobj._make(vals))
3740
return combined_nt_list
3841

42+
3943
def combine_nt_lists(lists, flds, dflt_null=""):
    """Zip several lists of namedtuples/objects into one list of namedtuples.

    All lists must have the same length (checked with an assertion). Each
    output namedtuple ("Nt") has the fields ``flds``; its values come from
    ``_combine_nt_vals``, which receives ``dflt_null`` as the default.
    """
    # All input lists must line up element-for-element.
    sizes = list(map(len, lists))
    assert len(set(sizes)) == 1, "LIST LENGTHS MUST BE EQUAL: {Ls}".format(
        Ls=" ".join(map(str, sizes))
    )
    # One namedtuple type shared by every combined row.
    nt_cls = cx.namedtuple("Nt", " ".join(flds))
    return [
        nt_cls._make(_combine_nt_vals(row, flds, dflt_null))
        for row in zip(*lists)
    ]
5360

61+
5462
def wr_py_nts(fout_py, nts, docstring=None, varname="nts"):
5563
"""Save namedtuples into a Python module."""
5664
if nts:
57-
with open(fout_py, 'w') as prt:
65+
with open(fout_py, "w") as prt:
5866
prt.write('"""{DOCSTRING}"""\n\n'.format(DOCSTRING=docstring))
5967
prt.write("# Created: {DATE}\n".format(DATE=str(datetime.date.today())))
6068
prt_nts(prt, nts, varname)
61-
sys.stdout.write(" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py))
69+
sys.stdout.write(
70+
" {N:7,} items WROTE: {PY}\n".format(N=len(nts), PY=fout_py)
71+
)
6272

63-
def prt_nts(prt, nts, varname, spc=' '):
73+
74+
def prt_nts(prt, nts, varname, spc=" "):
6475
"""Print namedtuples into a Python module."""
6576
first_nt = nts[0]
6677
nt_name = type(first_nt).__name__
6778
prt.write("import collections as cx\n\n")
79+
prt.write("import numpy as np\n\n")
6880
prt.write("NT_FIELDS = [\n")
6981
for fld in first_nt._fields:
7082
prt.write('{SPC}"{F}",\n'.format(SPC=spc, F=fld))
7183
prt.write("]\n\n")
72-
prt.write('{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
73-
NtName=nt_name))
84+
prt.write(
85+
'{NtName} = cx.namedtuple("{NtName}", " ".join(NT_FIELDS))\n\n'.format(
86+
NtName=nt_name
87+
)
88+
)
7489
prt.write("# {N:,} items\n".format(N=len(nts)))
7590
prt.write("# pylint: disable=line-too-long\n")
7691
prt.write("{VARNAME} = [\n".format(VARNAME=varname))
7792
for ntup in nts:
7893
prt.write("{SPC}{NT},\n".format(SPC=spc, NT=ntup))
7994
prt.write("]\n")
8095

96+
8197
def get_unique_fields(fld_lists):
8298
"""Get unique namedtuple fields, despite potential duplicates in lists of fields."""
8399
flds = []
@@ -93,6 +109,7 @@ def get_unique_fields(fld_lists):
93109
assert len(flds) == len(fld_set)
94110
return flds
95111

112+
96113
# -- Internal methods ----------------------------------------------------------------
97114
def _combine_nt_vals(lst0_lstn, flds, dflt_null):
98115
"""Given a list of lists of nts, return a single namedtuple."""
@@ -110,4 +127,5 @@ def _combine_nt_vals(lst0_lstn, flds, dflt_null):
110127
vals.append(dflt_null)
111128
return vals
112129

130+
113131
# Copyright (C) 2016-2018, DV Klopfenstein, H Tang. All rights reserved.

tests/test_dcnt_r01.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,8 @@
55
import sys
66
import timeit
77
import numpy as np
8+
import pytest
9+
810
from numpy.random import shuffle
911
from scipy import stats
1012

@@ -14,6 +16,7 @@
1416
from goatools.obo_parser import GODag
1517

1618

19+
@pytest.mark.skip(reason="Latest obo (`releases/2024-06-10`) is not DAG")
1720
def test_go_pools():
1821
"""Print a comparison of GO terms from different species in two different comparisons."""
1922
objr = _Run()

0 commit comments

Comments
 (0)