diff --git a/biasanalyzer/cohort.py b/biasanalyzer/cohort.py index f9f8865..00166d7 100644 --- a/biasanalyzer/cohort.py +++ b/biasanalyzer/cohort.py @@ -63,7 +63,10 @@ def get_concept_stats(self, concept_type='condition_occurrence', filter_count=0, vocab=vocab, print_concept_hierarchy=print_concept_hierarchy) return (cohort_stats, - ConceptHierarchy.build_concept_hierarchy_from_results(self.cohort_id, cohort_stats[concept_type])) + ConceptHierarchy.build_concept_hierarchy_from_results(self.cohort_id, concept_type, + cohort_stats[concept_type], + filter_count=filter_count, + vocab=vocab)) def __del__(self): @@ -162,7 +165,10 @@ def get_cohorts_concept_stats(self, cohorts: List[int], filter_count=filter_count, vocab=vocab) for c in cohorts] - hierarchies = [ConceptHierarchy.build_concept_hierarchy_from_results(c, c_stats.get(concept_type, [])) + hierarchies = [ConceptHierarchy.build_concept_hierarchy_from_results(c, concept_type, + c_stats.get(concept_type, []), + filter_count=filter_count, + vocab=vocab) for c, c_stats in zip(cohorts, cohort_concept_stats)] return reduce(lambda h1, h2: h1.union(h2), hierarchies).to_dict() diff --git a/biasanalyzer/concept.py b/biasanalyzer/concept.py index a71c044..1c98504 100644 --- a/biasanalyzer/concept.py +++ b/biasanalyzer/concept.py @@ -86,15 +86,20 @@ def _normalize_identifier(identifier: str) -> str: return "+".join(parts) @classmethod - def build_concept_hierarchy_from_results(cls, cohort_id: int, results: List[dict]): + def build_concept_hierarchy_from_results(cls, cohort_id: int, concept_type: str, results: List[dict], + filter_count=0, vocab=None): """ build concept hierarchy tree managed by networkx from list of dicts returned from the concept prevalence SQL - with cache management + with cache management. cohort_id, concept_type, and filter_count are used for caching to uniquely identify + a cached concept hierarchy for the specified cohort_id, concept_type, and filter_count. :param results: list of dicts from prevalence SQL :param cohort_id: cohort id to get concept hierarchy for + :param concept_type: concept_type to get concept hierarchy for + :param filer_count: filter_count to get concept hierarchy for with default value 0 meaning no filtering + :param vocab: vocab to get concept hierarchy for with default value None meaning default vocab will be used :return: ConceptHierarchy object """ - identifer = str(cohort_id) + identifer = f"{cohort_id}-{concept_type}-{filter_count}-{vocab}" if identifer in cls._graph_cache: return cls._graph_cache[identifer] @@ -117,7 +122,7 @@ def build_concept_hierarchy_from_results(cls, cohort_id: int, results: List[dict graph = nx.DiGraph() # add nodes with metadata + metrics for cid, meta in node_metadata.items(): - graph.add_node(cid, **meta, metrics={identifer: metrics_by_concept[cid]}) + graph.add_node(cid, **meta, metrics={str(cohort_id): metrics_by_concept[cid]}) # add parent-child edges for row in results: diff --git a/tests/query_based/test_hierarchical_prevalence.py b/tests/query_based/test_hierarchical_prevalence.py index cba79b1..a12d173 100644 --- a/tests/query_based/test_hierarchical_prevalence.py +++ b/tests/query_based/test_hierarchical_prevalence.py @@ -62,15 +62,24 @@ def test_identifier_normalization_and_cache(): assert ConceptHierarchy._normalize_identifier("1+2+2") == "1+2" # fake minimal results to build hierarchy - results = [ + results1 = [ {"ancestor_concept_id": 1, "descendant_concept_id": 1, "concept_name": "Diabetes", "concept_code": "DIA", "count_in_cohort": 5, "prevalence": 0.5} ] - h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) - h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) - assert h1 is h2 # cache reuse - assert h1.identifier == "1" + results2 = [ + {"ancestor_concept_id": 1, "descendant_concept_id": 1, + "concept_name": "Diabetes2", "concept_code": "DIA", + "count_in_cohort": 15, "prevalence": 0.15} + ] + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results1) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results2) + assert h1 is h2 # cache reuse even though results2 is different from results1 + assert h1.identifier == "1-condition_occurrence-0-None" + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'drug_exposure', results2) + assert not h1 is h2 # cache is not used since drug_exposure concept_name is different than the cached + # condition_occurrence + assert h2.identifier == "1-drug_exposure-0-None" def test_union_and_cache_behavior(): ConceptHierarchy.clear_cache() @@ -85,14 +94,14 @@ def test_union_and_cache_behavior(): "count_in_cohort": 3, "prevalence": 0.3} ] - h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, results1) - h2 = ConceptHierarchy.build_concept_hierarchy_from_results(2, results2) - assert "1" in ConceptHierarchy._graph_cache - assert "2" in ConceptHierarchy._graph_cache + h1 = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results1) + h2 = ConceptHierarchy.build_concept_hierarchy_from_results(2, 'condition_occurrence', results2) + assert "1-condition_occurrence-0-None" in ConceptHierarchy._graph_cache + assert "2-condition_occurrence-0-None" in ConceptHierarchy._graph_cache h12 = h1.union(h2) h21 = h2.union(h1) - assert h12.identifier == "1+2" - assert h21.identifier == "1+2" + assert h12.identifier == "1-condition_occurrence-0-None+2-condition_occurrence-0-None" + assert h21.identifier == "1-condition_occurrence-0-None+2-condition_occurrence-0-None" assert h12 is h21 def test_traversal_and_serialization(): @@ -105,7 +114,7 @@ def test_traversal_and_serialization(): "concept_name": "Child", "concept_code": "C", "count_in_cohort": 2, "prevalence": 0.2} ] - h = ConceptHierarchy.build_concept_hierarchy_from_results(1, results) + h = ConceptHierarchy.build_concept_hierarchy_from_results(1, 'condition_occurrence', results) # roots roots = h.get_root_nodes()