Skip to content

Commit 4c49769

Browse files
authored
[FIX] Sparse initial pivot (#785)
* Replaced coo_matrix with coo_array for better compatibility and added a test for coo_array functionality * Updated release file * Replaced some more coo_matrix calls * Fix O(n³) performance issue in sparse bipartite graph arc iteration - Added position tracking maps (_arc_to_out_pos, _arc_to_in_pos) for O(1) arc lookups - Modified nextOut() and nextIn() to use position maps instead of linear search * Added changes to release * Fixed PR numbers in RELEASES.md
1 parent 9043960 commit 4c49769

File tree

3 files changed

+75
-30
lines changed

3 files changed

+75
-30
lines changed

RELEASES.md

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@
44

55
This new release adds support for sparse cost matrices in the exact EMD solver. Users can now pass sparse cost matrices (e.g., k-NN graphs, sparse graphs) and receive sparse transport plans, significantly reducing memory footprint for large-scale problems. The implementation is backend-agnostic, automatically handling scipy.sparse for NumPy and torch.sparse for PyTorch, and preserves full gradient computation capabilities for automatic differentiation in PyTorch. This enables efficient solving of OT problems on graphs with millions of nodes where only a sparse subset of edges have finite costs.
66

7-
#### New features
8-
- Add support for sparse cost matrices in exact EMD solver `ot.emd` and `ot.emd2` (PR #778)
9-
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` API (PR #TBD)
7+
#### New features
8+
- Migrate backend from deprecated `scipy.sparse.coo_matrix` to modern `scipy.sparse.coo_array` (PR #782)
9+
- Geomloss function now handles both scalar and slice indices for `i` and `j` using backend-agnostic reshaping, which allows `plan[i, :]` and `plan[:, j]` (PR #785)
10+
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
1011

1112
#### Closed issues
12-
- Add support for sparse cost matrices in EMD solver (PR #778, Issue #397)
13+
- Fix O(n³) performance bottleneck in sparse bipartite graph arc iteration (PR #785)
1314
- Fix deprecated JAX function in `ot.backend.JaxBackend` (PR #771, Issue #770)
1415
- Add test for build from source (PR #772, Issue #764)
1516
- Fix device for batch Ot solver in `ot.batch` (PR #784, Issue #783)

ot/bregman/_geomloss.py

Lines changed: 24 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -54,13 +54,32 @@ def get_sinkhorn_geomloss_lazytensor(
5454
shape = (X_a.shape[0], X_b.shape[0])
5555

5656
def func(i, j, X_a, X_b, f, g, a, b, metric, blur):
57+
X_a_i = X_a[i]
58+
X_b_j = X_b[j]
59+
60+
if X_a_i.ndim == 1:
61+
X_a_i = X_a_i[None, :]
62+
if X_b_j.ndim == 1:
63+
X_b_j = X_b_j[None, :]
64+
5765
if metric == "sqeuclidean":
58-
C = dist(X_a[i], X_b[j], metric=metric) / 2
66+
C = dist(X_a_i, X_b_j, metric=metric) / 2
5967
else:
60-
C = dist(X_a[i], X_b[j], metric=metric)
61-
return nx.exp((f[i, None] + g[None, j] - C) / (blur**2)) * (
62-
a[i, None] * b[None, j]
63-
)
68+
C = dist(X_a_i, X_b_j, metric=metric)
69+
70+
# Robust broadcasting using nx backend (handles both numpy and torch)
71+
# For scalars, slice to keep 1D; for arrays, index directly
72+
f_i = f[i : i + 1] if isinstance(i, int) else f[i]
73+
g_j = g[j : j + 1] if isinstance(j, int) else g[j]
74+
a_i = a[i : i + 1] if isinstance(i, int) else a[i]
75+
b_j = b[j : j + 1] if isinstance(j, int) else b[j]
76+
77+
f_i = nx.reshape(f_i, (-1, 1))
78+
g_j = nx.reshape(g_j, (1, -1))
79+
a_i = nx.reshape(a_i, (-1, 1))
80+
b_j = nx.reshape(b_j, (1, -1))
81+
82+
return nx.squeeze(nx.exp((f_i + g_j - C) / (blur**2)) * a_i * b_j)
6483

6584
T = LazyTensor(
6685
shape, func, X_a=X_a, X_b=X_b, f=f, g=g, a=a, b=b, metric=metric, blur=blur

ot/lp/sparse_bipartitegraph.h

Lines changed: 46 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -43,8 +43,13 @@ namespace lemon {
4343

4444
mutable std::vector<std::vector<Arc>> _in_arcs; // _in_arcs[node] = incoming arc IDs
4545
mutable bool _in_arcs_built;
46+
47+
// Position tracking for O(1) iteration
48+
mutable std::vector<int64_t> _arc_to_out_pos; // _arc_to_out_pos[arc_id] = position in _arc_ids
49+
mutable std::vector<int64_t> _arc_to_in_pos; // _arc_to_in_pos[arc_id] = position in _in_arcs[target]
50+
mutable bool _position_maps_built;
4651

47-
SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false) {}
52+
SparseBipartiteDigraphBase() : _node_num(0), _arc_num(0), _n1(0), _n2(0), _in_arcs_built(false), _position_maps_built(false) {}
4853

4954
void construct(int n1, int n2) {
5055
_node_num = n1 + n2;
@@ -58,6 +63,9 @@ namespace lemon {
5863
_arc_ids.clear();
5964
_in_arcs.clear();
6065
_in_arcs_built = false;
66+
_arc_to_out_pos.clear();
67+
_arc_to_in_pos.clear();
68+
_position_maps_built = false;
6169
}
6270

6371
void build_in_arcs() const {
@@ -72,6 +80,31 @@ namespace lemon {
7280

7381
_in_arcs_built = true;
7482
}
83+
84+
void build_position_maps() const {
85+
if (_position_maps_built) return;
86+
87+
_arc_to_out_pos.resize(_arc_num);
88+
_arc_to_in_pos.resize(_arc_num);
89+
90+
// Build outgoing arc position map from CSR structure
91+
for (int64_t pos = 0; pos < _arc_num; ++pos) {
92+
Arc arc_id = _arc_ids[pos];
93+
_arc_to_out_pos[arc_id] = pos;
94+
}
95+
96+
// Build incoming arc position map
97+
build_in_arcs();
98+
for (Node node = 0; node < _node_num; ++node) {
99+
const std::vector<Arc>& in = _in_arcs[node];
100+
for (size_t pos = 0; pos < in.size(); ++pos) {
101+
Arc arc_id = in[pos];
102+
_arc_to_in_pos[arc_id] = pos;
103+
}
104+
}
105+
106+
_position_maps_built = true;
107+
}
75108

76109
public:
77110

@@ -212,18 +245,14 @@ namespace lemon {
212245

213246
void nextOut(Arc& arc) const {
214247
if (arc < 0) return;
215-
248+
249+
build_position_maps();
250+
251+
int64_t pos = _arc_to_out_pos[arc];
216252
Node src = _arc_sources[arc];
217-
int64_t start = _row_ptr[src];
218253
int64_t end = _row_ptr[src + 1];
219-
220-
for (int64_t i = start; i < end; ++i) {
221-
if (_arc_ids[i] == arc) {
222-
arc = (i + 1 < end) ? _arc_ids[i + 1] : Arc(-1);
223-
return;
224-
}
225-
}
226-
arc = -1;
254+
255+
arc = (pos + 1 < end) ? _arc_ids[pos + 1] : Arc(-1);
227256
}
228257

229258
void firstIn(Arc& arc, const Node& node) const {
@@ -240,18 +269,14 @@ namespace lemon {
240269

241270
void nextIn(Arc& arc) const {
242271
if (arc < 0) return;
243-
272+
273+
build_position_maps();
274+
275+
int64_t pos = _arc_to_in_pos[arc];
244276
Node tgt = _arc_targets[arc];
245277
const std::vector<Arc>& in = _in_arcs[tgt];
246-
247-
// Find current arc in the list and return next one
248-
for (size_t i = 0; i < in.size(); ++i) {
249-
if (in[i] == arc) {
250-
arc = (i + 1 < in.size()) ? in[i + 1] : Arc(-1);
251-
return;
252-
}
253-
}
254-
arc = -1;
278+
279+
arc = (pos + 1 < in.size()) ? in[pos + 1] : Arc(-1);
255280
}
256281
};
257282

0 commit comments

Comments
 (0)