|
94 | 94 | import scipy |
95 | 95 | import scipy.linalg |
96 | 96 | import scipy.special as special |
97 | | -from scipy.sparse import coo_matrix, csr_matrix, issparse |
| 97 | +from scipy.sparse import coo_array, csr_matrix, issparse |
98 | 98 |
|
99 | 99 | DISABLE_TORCH_KEY = "POT_BACKEND_DISABLE_PYTORCH" |
100 | 100 | DISABLE_JAX_KEY = "POT_BACKEND_DISABLE_JAX" |
@@ -802,9 +802,9 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): |
802 | 802 | r""" |
803 | 803 | Creates a sparse tensor in COOrdinate format. |
804 | 804 |
|
805 | | - This function follows the api from :any:`scipy.sparse.coo_matrix` |
| 805 | + This function follows the api from :any:`scipy.sparse.coo_array` |
806 | 806 |
|
807 | | - See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_matrix.html |
| 807 | + See: https://docs.scipy.org/doc/scipy/reference/generated/scipy.sparse.coo_array.html |
808 | 808 | """ |
809 | 809 | raise NotImplementedError() |
810 | 810 |
|
@@ -1354,9 +1354,9 @@ def randperm(self, size, type_as=None): |
1354 | 1354 |
|
1355 | 1355 | def coo_matrix(self, data, rows, cols, shape=None, type_as=None): |
1356 | 1356 | if type_as is None: |
1357 | | - return coo_matrix((data, (rows, cols)), shape=shape) |
| 1357 | + return coo_array((data, (rows, cols)), shape=shape) |
1358 | 1358 | else: |
1359 | | - return coo_matrix((data, (rows, cols)), shape=shape, dtype=type_as.dtype) |
| 1359 | + return coo_array((data, (rows, cols)), shape=shape, dtype=type_as.dtype) |
1360 | 1360 |
|
1361 | 1361 | def issparse(self, a): |
1362 | 1362 | return issparse(a) |
@@ -1384,9 +1384,9 @@ def todense(self, a): |
1384 | 1384 | return a |
1385 | 1385 |
|
1386 | 1386 | def sparse_coo_data(self, a): |
1387 | | - # Convert to COO format if needed |
1388 | | - if not isinstance(a, coo_matrix): |
1389 | | - a_coo = coo_matrix(a) |
| 1387 | + # Convert to COO array format if needed |
| 1388 | + if not isinstance(a, coo_array): |
| 1389 | + a_coo = coo_array(a) |
1390 | 1390 | else: |
1391 | 1391 | a_coo = a |
1392 | 1392 |
|
@@ -1815,9 +1815,7 @@ def sparse_coo_data(self, a): |
1815 | 1815 | # JAX doesn't support sparse matrices, so this shouldn't be called |
1816 | 1816 | # But if it is, convert the dense array to sparse using scipy |
1817 | 1817 | a_np = self.to_numpy(a) |
1818 | | - from scipy.sparse import coo_matrix |
1819 | | - |
1820 | | - a_coo = coo_matrix(a_np) |
| 1818 | + a_coo = coo_array(a_np) |
1821 | 1819 | return a_coo.row, a_coo.col, a_coo.data, a_coo.shape |
1822 | 1820 |
|
1823 | 1821 | def where(self, condition, x=None, y=None): |
@@ -2804,10 +2802,10 @@ def coo_matrix(self, data, rows, cols, shape=None, type_as=None): |
2804 | 2802 | rows = self.from_numpy(rows) |
2805 | 2803 | cols = self.from_numpy(cols) |
2806 | 2804 | if type_as is None: |
2807 | | - return cupyx.scipy.sparse.coo_matrix((data, (rows, cols)), shape=shape) |
| 2805 | + return cupyx.scipy.sparse.coo_array((data, (rows, cols)), shape=shape) |
2808 | 2806 | else: |
2809 | 2807 | with cp.cuda.Device(type_as.device): |
2810 | | - return cupyx.scipy.sparse.coo_matrix( |
| 2808 | + return cupyx.scipy.sparse.coo_array( |
2811 | 2809 | (data, (rows, cols)), shape=shape, dtype=type_as.dtype |
2812 | 2810 | ) |
2813 | 2811 |
|
|
0 commit comments