Skip to content

Commit 002641d

Browse files
committed
bug fix
1 parent 383f2ea commit 002641d

2 files changed

Lines changed: 37 additions & 63 deletions

File tree

snapatac2/tools/_clustering.py

Lines changed: 35 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,18 @@
44
import scipy.sparse as ss
55
import numpy as np
66

7+
import snapatac2
78
import snapatac2._snapatac2 as internal
89
from snapatac2._utils import get_igraph_from_adjacency, is_anndata
910

1011
def leiden(
1112
adata: internal.AnnData | internal.AnnDataSet | ss.spmatrix,
1213
resolution: float = 1,
13-
objective_function: Literal["CPM", "modularity", "RBConfiguration"] = "modularity",
14+
objective_function: Literal["CPM", "modularity"] = "modularity",
1415
min_cluster_size: int = 5,
1516
n_iterations: int = -1,
1617
random_state: int = 0,
1718
key_added: str = "leiden",
18-
use_leidenalg: bool = False,
1919
weighted: bool = False,
2020
inplace: bool = True,
2121
) -> np.ndarray | None:
@@ -39,7 +39,7 @@ def leiden(
3939
to one that doesn't accept a `resolution_parameter`.
4040
objective_function
4141
whether to use the Constant Potts Model (CPM) or modularity.
42-
Must be either "CPM", "modularity" or "RBConfiguration".
42+
Must be either "CPM" or "modularity".
4343
min_cluster_size
4444
The minimum size of clusters.
4545
n_iterations
@@ -50,8 +50,6 @@ def leiden(
5050
Change the initialization of the optimization.
5151
key_added
5252
`adata.obs` key under which to add the cluster labels.
53-
use_leidenalg
54-
If `True`, `leidenalg` package is used. Otherwise, `python-igraph` is used.
5553
weighted
5654
Whether to use the edge weights in the graph
5755
inplace
@@ -64,8 +62,13 @@ def leiden(
6462
dim (number of samples) that stores the subgroup id
6563
(`'0'`, `'1'`, ...) for each cell. Otherwise, returns the array directly.
6664
"""
65+
from igraph import set_random_number_generator
6766
from collections import Counter
6867
import polars
68+
import random
69+
70+
random.seed(random_state)
71+
set_random_number_generator(random)
6972

7073
if is_anndata(adata):
7174
adjacency = adata.obsp["distances"]
@@ -81,43 +84,14 @@ def leiden(
8184
else:
8285
weights = None
8386

84-
if use_leidenalg or objective_function == "RBConfiguration":
85-
import leidenalg
86-
from leidenalg.VertexPartition import MutableVertexPartition
87-
88-
if objective_function == "modularity":
89-
partition_type = leidenalg.ModularityVertexPartition
90-
elif objective_function == "CPM":
91-
partition_type = leidenalg.CPMVertexPartition
92-
elif objective_function == "RBConfiguration":
93-
partition_type = leidenalg.RBConfigurationVertexPartition
94-
else:
95-
raise ValueError("objective function is not supported: " + partition_type)
96-
97-
partition = leidenalg.find_partition(
98-
gr,
99-
partition_type,
100-
n_iterations=n_iterations,
101-
seed=random_state,
102-
resolution_parameter=resolution,
103-
weights=weights,
104-
)
105-
else:
106-
from igraph import set_random_number_generator
107-
import random
108-
109-
random.seed(random_state)
110-
set_random_number_generator(random)
111-
partition = gr.community_leiden(
112-
objective_function=objective_function,
113-
weights=weights,
114-
resolution=resolution,
115-
beta=0.01,
116-
initial_membership=None,
117-
n_iterations=n_iterations,
118-
)
119-
120-
groups = partition.membership
87+
groups = gr.community_leiden(
88+
objective_function=objective_function,
89+
weights=weights,
90+
resolution=resolution,
91+
beta=0.01,
92+
initial_membership=None,
93+
n_iterations=n_iterations,
94+
).membership
12195

12296
new_cl_id = dict(
12397
[
@@ -134,21 +108,14 @@ def leiden(
134108
groups,
135109
dtype=polars.datatypes.Categorical,
136110
)
137-
# store information on the clustering parameters
138-
# adata.uns['leiden'] = {}
139-
# adata.uns['leiden']['params'] = dict(
140-
# resolution=resolution,
141-
# random_state=random_state,
142-
# n_iterations=n_iterations,
143-
# )
144111
else:
145112
return groups
146113

147114

148115
def leiden_sweep(
149116
adata: internal.AnnData | internal.AnnDataSet | ss.spmatrix,
150117
resolutions: list[float],
151-
use_rep: str = "X_spectral",
118+
use_rep: str | np.ndarray = "X_spectral",
152119
objective_function: Literal["CPM", "modularity", "RBConfiguration"] = "modularity",
153120
min_cluster_size: int = 5,
154121
n_iterations: int = -1,
@@ -192,8 +159,12 @@ def leiden_sweep(
192159
from sklearn.metrics import silhouette_score
193160
from multiprocess import get_context
194161

195-
mat = adata.obsm[use_rep]
196-
distances = adata.obsp["distances"]
162+
if is_anndata(adata):
163+
distances = adata.obsp["distances"]
164+
mat = adata.obsm[use_rep]
165+
else:
166+
mat = use_rep
167+
distances = adata
197168

198169
def _func(resolution):
199170
groups = leiden(
@@ -206,16 +177,19 @@ def _func(resolution):
206177
weighted=weighted,
207178
inplace=False,
208179
)
209-
score = silhouette_score(
210-
mat,
211-
groups,
212-
sample_size=20000,
213-
)
180+
if len(set(groups)) > 1:
181+
score = silhouette_score(
182+
mat,
183+
groups,
184+
sample_size=20000,
185+
)
186+
else:
187+
score = 0
214188
return {
215-
"resolution": resolution,
216-
"n_clusters": len(set(groups)),
217-
"silhouette_score": score,
218-
}
189+
"resolution": resolution,
190+
"n_clusters": len(set(groups)),
191+
"silhouette_score": score,
192+
}
219193

220194
with get_context("spawn").Pool(n_jobs) as p:
221195
return list(p.imap(_func, resolutions))

src/embedding.rs

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -275,7 +275,7 @@ impl Nystrom {
275275
(0..num_threads)
276276
.into_par_iter()
277277
.map(|i| {
278-
let start = i * chunk_size;
278+
let start = (i * chunk_size).min(nrows);
279279
let end = ((i + 1) * chunk_size).min(nrows);
280280
let mut qmat = spmm_dense(start, end, &mat, &self.qmat);
281281
let mut q_sum = qmat.row_sum_tr();
@@ -429,7 +429,7 @@ fn normalize(input: &mut CsrMatrix<f64>, feature_weights: &[f64]) {
429429
}
430430

431431
fn spmm_dense(i: usize, j: usize, mat: &CsrMatrix<f64>, dense: &DMatrix<f64>) -> DMatrix<f64> {
432-
assert!(i < j);
432+
assert!(i <= j, "Invalid row range, {}-{}", i, j);
433433
let mut result = DMatrix::zeros(j - i, dense.ncols());
434434
for row_idx in i..j {
435435
let row = mat.row(row_idx);

0 commit comments

Comments
 (0)