Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 27 additions & 0 deletions .github/workflows/tests_workflow.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
# CI workflow: installs the package with pip and runs the pytest suite in tests/.
name: tests_workflow

# Execute this workflow automatically whenever we push to any branch.
on: [push]

jobs:
tests:
runs-on: ubuntu-latest
steps:
# Check out the repository so the following steps can see its files.
- uses: actions/checkout@v4
# Set up the Python interpreter version requested below.
- uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
pip install pytest
# Install the package from the repository root.
- name: Install PARC
run: |
pip install .
# NOTE(review): the step name mentions "-x", but the command below does not pass that flag; confirm which is intended.
- name: Run Python tests -x
run: |
pytest tests/
continue-on-error: false

# Runs are grouped by git ref, with cancel-in-progress enabled for the group.
concurrency:
group: ci-${{ github.ref }}
cancel-in-progress: true
105 changes: 62 additions & 43 deletions parc/_parc.py
Original file line number Diff line number Diff line change
Expand Up @@ -113,24 +113,24 @@ def __init__(
self,
x_data: np.ndarray,
y_data_true: np.ndarray | None = None,
knn: int = 30,
n_iter_leiden: int = 5,
random_seed: int = 42,
distance_metric: str = "l2",
n_threads: int = -1,
hnsw_param_ef_construction: int = 150,
neighbor_graph: csr_matrix | None = None,
knn_struct: hnswlib.Index | None = None,
l2_std_factor: float = 3,
jac_threshold_type: str = "median",
jac_std_factor: float = 0.15,
jac_weighted_edges: bool = True,
do_prune_local: bool | None = None,
large_community_factor: float = 0.4,
small_community_size: int = 10,
jac_weighted_edges: bool = True,
knn: int = 30,
n_iter_leiden: int = 5,
random_seed: int = 42,
n_threads: int = -1,
distance_metric: str = "l2",
small_community_timeout: float = 15,
partition_type: str = "ModularityVP",
resolution_parameter: float = 1.0,
knn_struct: hnswlib.Index | None = None,
neighbor_graph: csr_matrix | None = None,
hnsw_param_ef_construction: int = 150
partition_type: str = "ModularityVP"
):
self.x_data = x_data
self.y_data_true = y_data_true
Expand Down Expand Up @@ -573,7 +573,10 @@ def run_toobig_subPARC(
distance_metric="l2"
)
if n_samples <= 10:
logger.message("Consider increasing the large_community_factor")
logger.message(
f"Large community is small with only {n_samples} nodes. "
f"Consider increasing the large_community_factor = {self.large_community_factor}."
)
if n_samples > self.knn:
knnbig = self.knn
else:
Expand Down Expand Up @@ -601,6 +604,7 @@ def run_toobig_subPARC(
small_cluster_list = []
small_community_exists = False
node_communities = np.unique(list(node_communities.flatten()), return_inverse=True)[1]
logger.message("Stating small community detection...")
for cluster in set(node_communities):
population = len(np.where(node_communities == cluster)[0])
if population < self.small_community_size:
Expand All @@ -621,8 +625,7 @@ def run_toobig_subPARC(
node_communities[single_cell] = best_group

time_start = time.time()
logger.message("Handling fragments...")
while small_community_exists & (time.time() - time_start < self.small_community_timeout):
while small_community_exists and (time.time() - time_start < self.small_community_timeout):
small_pop_list = []
small_community_exists = False
for cluster in set(list(node_communities.flatten())):
Expand Down Expand Up @@ -664,7 +667,7 @@ def run_parc(self):
neighbor_array = np.split(csr_array.indices, csr_array.indptr)[1:-1]
else:
if self.knn_struct is None:
logger.message("knn struct was not available, creating new one")
logger.message("Creating knn_struct...")
if n_samples < 10000:
ef_query = min(n_samples - 10, 500)
else:
Expand Down Expand Up @@ -702,26 +705,36 @@ def run_parc(self):

# The 0th cluster is the largest one.
# So, if cluster 0 is not too big, then the others won't be too big either
community_indices = np.where(node_communities == 0)[0]
large_community_id = 0
community_indices = np.where(node_communities == large_community_id)[0]
community_size = len(community_indices)
if community_size > large_community_factor * n_samples: # 0.4

if community_size > large_community_factor * n_samples:
logger.message(
f"\nCommunity 0 is too large and has size:\n"
f"{community_size} > large_community_factor * n_samples = "
f"{large_community_factor} * {n_samples} = {large_community_factor * n_samples}\n"
f"Starting large community expansion..."
)
too_big = True
large_community_indices = community_indices
list_pop_too_bigs = [community_size]
else:
logger.message(
f"\nCommunity 0 is not too large and has size:\n"
f"{community_size} <= large_community_factor * n_samples = "
f"{large_community_factor} * {n_samples} = {large_community_factor * n_samples}\n"
"Skipping large community expansion."
)

while too_big:
logger.message(f"Expanding large community {large_community_id}...")
node_communities_big = self.run_toobig_subPARC(
x_data=x_data[large_community_indices, :]
)
node_communities_big = node_communities_big + 100000
pop_list = []

for item in set(list(node_communities_big.flatten())):
pop_list.append([item, list(node_communities_big.flatten()).count(item)])

logger.message(f"pop of big clusters {pop_list}")
jj = 0
logger.message(f"shape node_communities {node_communities.shape}")
for j in large_community_indices:
node_communities[j] = node_communities_big[jj]
jj = jj + 1
Expand All @@ -730,39 +743,39 @@ def run_parc(self):
)[1]

too_big = False
set_node_communities = set(node_communities)
logger.message(f"New set of labels {set_node_communities}")

node_communities = np.asarray(node_communities)
for community_id in set_node_communities:
for community_id in set(node_communities):
community_indices = np.where(node_communities == community_id)[0]
community_size = len(community_indices)
not_yet_expanded = community_size not in list_pop_too_bigs
if community_size > large_community_factor * n_samples and not_yet_expanded:
too_big = True
logger.message(
f"Cluster {community_id} is too big and has population {community_size}."
f"Community {community_id} is too big and has population {community_size}."
)
large_community_indices = community_indices
large_community_id = community_id
large_community_size = community_size
if too_big:
list_pop_too_bigs.append(large_community_size)
logger.message(
f"Cluster {large_community_id} is too big and has population "
f"Community {large_community_id} is too big and has population "
f"{large_community_size}. It will be expanded."
)
node_communities = np.unique(list(node_communities.flatten()), return_inverse=True)[1]

logger.message("Starting small community detection...")
small_pop_list = []
small_cluster_list = []
small_community_exists = False

for cluster in set(node_communities):
population = len(np.where(node_communities == cluster)[0])

if population < small_community_size: # 10
if population < small_community_size:
logger.message(
f"Community {cluster} is a small community with population {population}"
)
small_community_exists = True

small_pop_list.append(list(np.where(node_communities == cluster)[0]))
small_cluster_list.append(cluster)

Expand All @@ -779,14 +792,16 @@ def run_parc(self):
best_group = max(available_neighbours_list, key=available_neighbours_list.count)
node_communities[single_cell] = best_group
time_start_sc = time.time()
while small_community_exists & (time.time() - time_start_sc) < self.small_community_timeout:
while small_community_exists and (time.time() - time_start_sc) < self.small_community_timeout:
small_pop_list = []
small_community_exists = False
for cluster in set(list(node_communities.flatten())):
population = len(np.where(node_communities == cluster)[0])
if population < small_community_size:
logger.info(
f"Community {cluster} is a small community with population {population}"
)
small_community_exists = True
logger.message(f"Cluster {cluster} has small population of {population}.")
small_pop_list.append(np.where(node_communities == cluster)[0])
for small_cluster in small_pop_list:
for single_cell in small_cluster:
Expand All @@ -800,8 +815,8 @@ def run_parc(self):
node_communities = list(node_communities.flatten())
pop_list = []
for item in set(node_communities):
pop_list.append((item, node_communities.count(item)))
logger.message(f"Cluster labels and populations {len(pop_list)} {pop_list}")
pop_list.append((int(item), node_communities.count(item)))
logger.message(f"Community labels and sizes: {pop_list}")

self.y_data_pred = node_communities
run_time = time.time() - time_start
Expand Down Expand Up @@ -830,10 +845,10 @@ def accuracy(self, target=1):
vals = [t for t in Index_dict[kk]]
majority_val = get_mode(vals)
if majority_val == target:
logger.message(f"Cluster {kk} has majority {target} with population {len(vals)}")
logger.info(f"Cluster {kk} has majority {target} with population {len(vals)}")
if kk == -1:
len_unknown = len(vals)
logger.message(f"len unknown: {len_unknown}")
logger.info(f"len unknown: {len_unknown}")
if (majority_val == target) and (kk != -1):
positive_labels.append(kk)
fp = fp + len([e for e in vals if e != target])
Expand Down Expand Up @@ -903,11 +918,11 @@ def compute_performance_metrics(self, run_time: float):
f1_accumulated = 0
f1_acc_noweighting = 0
for target in targets:
logger.message(f"Target is {target}")
logger.info(f"Target is {target}")
vals_roc, predict_class_array, majority_truth_labels, numclusters_targetval = \
self.accuracy(target=target)
f1_current = vals_roc[1]
logger.message(f"Target {target} has f1-score of {(f1_current * 100):.2f}")
logger.info(f"Target {target} has f1-score of {(f1_current * 100):.2f}")
f1_accumulated = f1_accumulated + \
f1_current * (list(self.y_data_true).count(target)) / n_samples
f1_acc_noweighting = f1_acc_noweighting + f1_current
Expand All @@ -920,18 +935,22 @@ def compute_performance_metrics(self, run_time: float):
)

f1_mean = f1_acc_noweighting / len(targets)
logger.message(f"f1-score (unweighted) mean: {(f1_mean * 100):.2f}")
logger.message(f"f1-score weighted (by population): {(f1_accumulated * 100):.2f}")

df_accuracy = pd.DataFrame(
list_roc,
columns=[
"jac_std_factor", "l2_std_factor", "onevsall-target", "error rate",
"jac_std_factor", "l2_std_factor", "target", "error rate",
"f1-score", "tnr", "fnr", "tpr", "fpr", "precision", "recall", "num_groups",
"population of target", "num clusters", "clustering runtime"
"target population", "num clusters", "clustering runtime"
]
)

logger.message(f"f1-score (unweighted) mean: {(f1_mean * 100):.2f}")
logger.message(f"f1-score weighted (by population): {(f1_accumulated * 100):.2f}")
logger.message(
f"\n{df_accuracy[['target', 'f1-score', 'target population', 'num clusters']]}"
)

self.f1_accumulated = f1_accumulated
self.f1_mean = f1_mean
self.stats_df = df_accuracy
Expand Down
2 changes: 1 addition & 1 deletion parc/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
MIN_LEVEL = logging.DEBUG
MESSAGE = 25
logging.addLevelName(MESSAGE, "MESSAGE")
LOGGING_LEVEL = 20
LOGGING_LEVEL = 25


class LogFilter(logging.Filter):
Expand Down
4 changes: 4 additions & 0 deletions pytest.ini
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
[pytest]
# Show log records on the terminal while the tests run (live logging).
log_cli = true
# 25 is the numeric value of the custom MESSAGE level registered in parc/logger.py.
log_cli_level = 25
# -rP adds the captured output of passed tests to the short test summary.
addopts = -rP
Loading