|
23 | 23 | GB_Matrix_free (&T) ; \ |
24 | 24 | } |
25 | 25 |
|
//------------------------------------------------------------------------------
// generic accessors for the four sparsity formats (sparse, hypersparse,
// bitmap, full).  A NULL array selects the bitmap/full (dense) addressing
// rule; otherwise the array is consulted directly.  All arguments are fully
// parenthesized so pointer expressions (e.g. Ai+1) expand correctly.
//------------------------------------------------------------------------------

// GBI: row index of the entry at position p (dense: p mod vector length)
#define GBI(Ai,p,avlen) (((Ai) == NULL) ? ((p) % (avlen)) : (Ai) [p])

// GBB: is the entry at position p present? (no bitmap: always present)
#define GBB(Ab,p) (((Ab) == NULL) ? 1 : (Ab) [p])

// GBP: start offset of vector k (dense: k times the vector length)
#define GBP(Ap,k,avlen) (((Ap) == NULL) ? ((k) * (avlen)) : (Ap) [k])

// GBH: global vector id of the k-th stored vector (non-hyper: identity)
#define GBH(Ah,k) (((Ah) == NULL) ? (k) : (Ah) [k])
26 | 34 | #include "kronecker/GB_kron.h" |
27 | 35 | #include "mxm/GB_mxm.h" |
28 | 36 | #include "transpose/GB_transpose.h" |
29 | 37 | #include "mask/GB_accum_mask.h" |
30 | 38 |
|
| 39 | +static bool GB_lookup_xoffset ( |
| 40 | + GrB_Index* p, |
| 41 | + GrB_Matrix A, |
| 42 | + GrB_Index row, |
| 43 | + GrB_Index col |
| 44 | +) |
| 45 | +{ |
| 46 | + GrB_Index vector = A->is_csc ? col : row ; |
| 47 | + GrB_Index coord = A->is_csc ? row : col ; |
| 48 | + |
| 49 | + if (A->p == NULL) |
| 50 | + { |
| 51 | + GrB_Index offset = vector * A->vlen + coord ; |
| 52 | + if (A->b == NULL || ((int8_t*)A->b)[offset]) |
| 53 | + { |
| 54 | + *p = A->iso ? 0 : offset ; |
| 55 | + return true ; |
| 56 | + } |
| 57 | + return false ; |
| 58 | + } |
| 59 | + |
| 60 | + int64_t start, end ; |
| 61 | + bool res ; |
| 62 | + |
| 63 | + if (A->h == NULL) |
| 64 | + { |
| 65 | + start = A->p_is_32 ? ((uint32_t*)A->p)[vector] : ((uint64_t*)A->p)[vector] ; |
| 66 | + end = A->p_is_32 ? ((uint32_t*)A->p)[vector + 1] : ((uint64_t*)A->p)[vector + 1] ; |
| 67 | + end-- ; |
| 68 | + if (start > end) return false ; |
| 69 | + res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ; |
| 70 | + if (res) { *p = A->iso ? 0 : start ; } |
| 71 | + return res ; |
| 72 | + } |
| 73 | + else |
| 74 | + { |
| 75 | + start = 0 ; end = A->plen - 1 ; |
| 76 | + res = GB_binary_search(vector, A->h, A->j_is_32, &start, &end) ; |
| 77 | + if (!res) return false ; |
| 78 | + int64_t k = start ; |
| 79 | + start = A->p_is_32 ? ((uint32_t*)A->p)[k] : ((uint64_t*)A->p)[k] ; |
| 80 | + end = A->p_is_32 ? ((uint32_t*)A->p)[k+1] : ((uint64_t*)A->p)[k+1] ; |
| 81 | + end-- ; |
| 82 | + if (start > end) return false ; |
| 83 | + res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ; |
| 84 | + if (res) { *p = A->iso ? 0 : start ; } |
| 85 | + return res ; |
| 86 | + } |
| 87 | +} |
| 88 | + |
| 89 | +#include "emult/GB_emult.h" |
| 90 | + |
31 | 91 | GrB_Info GB_kron // C<M> = accum (C, kron(A,B)) |
32 | 92 | ( |
33 | 93 | GrB_Matrix C, // input/output matrix for results |
@@ -104,6 +164,314 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B)) |
104 | 164 | // quick return if an empty mask is complemented |
105 | 165 | GB_RETURN_IF_QUICK_MASK (C, C_replace, M, Mask_comp, Mask_struct) ; |
106 | 166 |
|
| 167 | + // check if it's possible to apply mask immediately in kron |
| 168 | + // TODO: make MT of same CSR/CSC format as C |
| 169 | + |
| 170 | + GrB_Matrix MT; |
| 171 | + if (M != NULL && !Mask_comp) |
| 172 | + { |
| 173 | + // iterate over mask, count how many elements will be present in MT |
| 174 | + // initialize MT->p |
| 175 | + |
| 176 | + GB_MATRIX_WAIT(M); |
| 177 | + |
| 178 | + size_t allocated = 0 ; |
| 179 | + bool MT_hypersparse = (A->h != NULL) || (B->h != NULL); |
| 180 | + int64_t centries ; |
| 181 | + uint64_t nvecs ; |
| 182 | + centries = 0 ; |
| 183 | + nvecs = 0 ; |
| 184 | + |
| 185 | + uint32_t* MTp32 = NULL ; uint64_t* MTp64 = NULL ; |
| 186 | + MTp32 = M->p_is_32 ? GB_calloc_memory (M->vdim + 1, sizeof(uint32_t), &allocated) : NULL ; |
| 187 | + MTp64 = M->p_is_32 ? NULL : GB_calloc_memory (M->vdim + 1, sizeof(uint64_t), &allocated) ; |
| 188 | + if (MTp32 == NULL && MTp64 == NULL) |
| 189 | + { |
| 190 | + OUT_OF_MEM_p: |
| 191 | + GB_FREE_WORKSPACE ; |
| 192 | + return GrB_OUT_OF_MEMORY ; |
| 193 | + } |
| 194 | + |
| 195 | + GrB_Type MTtype = op->ztype ; |
| 196 | + const size_t MTsize = MTtype->size ; |
| 197 | + GB_void MTscalar [GB_VLA(MTsize)] ; |
| 198 | + bool MTiso = GB_emult_iso (MTscalar, MTtype, A, B, op) ; |
| 199 | + |
| 200 | + GB_Mp_DECLARE(Mp, ) ; |
| 201 | + GB_Mp_PTR(Mp, M) ; |
| 202 | + |
| 203 | + GB_Mh_DECLARE(Mh, ) ; |
| 204 | + GB_Mh_PTR(Mh, M) ; |
| 205 | + |
| 206 | + GB_Mi_DECLARE(Mi, ) ; |
| 207 | + GB_Mi_PTR(Mi, M) ; |
| 208 | + |
| 209 | + GB_cast_function cast_A = NULL ; |
| 210 | + GB_cast_function cast_B = NULL ; |
| 211 | + |
| 212 | + cast_A = GB_cast_factory (op->xtype->code, A->type->code) ; |
| 213 | + cast_B = GB_cast_factory (op->ytype->code, B->type->code) ; |
| 214 | + |
| 215 | + int64_t vlen = M->vlen ; |
| 216 | + #pragma omp parallel |
| 217 | + { |
| 218 | + GrB_Index offset ; |
| 219 | + |
| 220 | + #pragma omp for reduction(+:nvecs) |
| 221 | + for (GrB_Index k = 0 ; k < M->nvec ; k++) |
| 222 | + { |
| 223 | + GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ; |
| 224 | + |
| 225 | + int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ; |
| 226 | + int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ; |
| 227 | + bool nonempty = false ; |
| 228 | + for (GrB_Index p = pA_start ; p < pA_end ; p++) |
| 229 | + { |
| 230 | + if (!GBB (M->b, p)) continue ; |
| 231 | + |
| 232 | + int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ; |
| 233 | + GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ; |
| 234 | + |
| 235 | + // extract elements from A and B, increment MTp |
| 236 | + |
| 237 | + if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) |
| 238 | + { |
| 239 | + GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows); |
| 240 | + GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols); |
| 241 | + |
| 242 | + GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows); |
| 243 | + GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols); |
| 244 | + |
| 245 | + bool code = GB_lookup_xoffset(&offset, A, arow, acol) ; |
| 246 | + if (!code) |
| 247 | + { |
| 248 | + continue; |
| 249 | + } |
| 250 | + |
| 251 | + code = GB_lookup_xoffset(&offset, B, brow, bcol) ; |
| 252 | + if (!code) |
| 253 | + { |
| 254 | + continue; |
| 255 | + } |
| 256 | + |
| 257 | + if (M->p_is_32) |
| 258 | + { |
| 259 | + (MTp32[j])++ ; |
| 260 | + } |
| 261 | + else |
| 262 | + { |
| 263 | + (MTp64[j])++ ; |
| 264 | + } |
| 265 | + nonempty = true ; |
| 266 | + } |
| 267 | + } |
| 268 | + if (nonempty) nvecs++ ; |
| 269 | + } |
| 270 | + } |
| 271 | + |
| 272 | + // GB_cumsum for MT->p |
| 273 | + |
| 274 | + double work = M->vdim ; |
| 275 | + int nthreads_max = GB_Context_nthreads_max ( ) ; |
| 276 | + double chunk = GB_Context_chunk ( ) ; |
| 277 | + int cumsum_threads = GB_nthreads (work, chunk, nthreads_max) ; |
| 278 | + M->p_is_32 ? GB_cumsum(MTp32, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) : |
| 279 | + GB_cumsum(MTp64, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) ; |
| 280 | + |
| 281 | + centries = M->p_is_32 ? MTp32[M->vdim] : MTp64[M->vdim] ; |
| 282 | + |
| 283 | + uint32_t* MTi32 = NULL ; uint64_t* MTi64 = NULL; |
| 284 | + MTi32 = M->i_is_32 ? GB_malloc_memory (centries, sizeof(uint32_t), &allocated) : NULL ; |
| 285 | + MTi64 = M->i_is_32 ? NULL : GB_malloc_memory (centries, sizeof(uint64_t), &allocated) ; |
| 286 | + |
| 287 | + if (centries > 0 && MTi32 == NULL && MTi64 == NULL) |
| 288 | + { |
| 289 | + OUT_OF_MEM_i: |
| 290 | + if (M->p_is_32) { GB_free_memory (&MTp32, (M->vdim + 1) * sizeof(uint32_t)) ; } |
| 291 | + else { GB_free_memory (&MTp64, (M->vdim + 1) * sizeof(uint64_t)) ; } |
| 292 | + goto OUT_OF_MEM_p ; |
| 293 | + } |
| 294 | + |
| 295 | + void* MTx = NULL ; |
| 296 | + if (!MTiso) |
| 297 | + { |
| 298 | + MTx = GB_malloc_memory (centries, op->ztype->size, &allocated) ; |
| 299 | + } |
| 300 | + else |
| 301 | + { |
| 302 | + MTx = GB_malloc_memory (1, op->ztype->size, &allocated) ; |
| 303 | + if (MTx == NULL) goto OUT_OF_MEM_x ; |
| 304 | + memcpy (MTx, MTscalar, MTsize) ; |
| 305 | + } |
| 306 | + |
| 307 | + if (centries > 0 && MTx == NULL) |
| 308 | + { |
| 309 | + OUT_OF_MEM_x: |
| 310 | + if (M->i_is_32) { GB_free_memory (&MTi32, centries * sizeof(uint32_t)) ; } |
| 311 | + else { GB_free_memory (&MTi64, centries * sizeof (uint64_t)) ; } |
| 312 | + goto OUT_OF_MEM_i ; |
| 313 | + } |
| 314 | + |
| 315 | + #pragma omp parallel |
| 316 | + { |
| 317 | + GrB_Index offset ; |
| 318 | + GB_void a_elem[op->xtype->size] ; |
| 319 | + GB_void b_elem[op->ytype->size] ; |
| 320 | + |
| 321 | + #pragma omp for |
| 322 | + for (GrB_Index k = 0 ; k < M->nvec ; k++) |
| 323 | + { |
| 324 | + GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ; |
| 325 | + |
| 326 | + int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ; |
| 327 | + int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ; |
| 328 | + GrB_Index pos = M->p_is_32 ? MTp32[j] : MTp64[j] ; |
| 329 | + for (GrB_Index p = pA_start ; p < pA_end ; p++) |
| 330 | + { |
| 331 | + if (!GBB (M->b, p)) continue ; |
| 332 | + |
| 333 | + int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ; |
| 334 | + GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ; |
| 335 | + |
| 336 | + // extract elements from A and B, |
| 337 | + // initialize offset in MTi and MTx, |
| 338 | + // get result of op, place it in MTx |
| 339 | + |
| 340 | + if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p])) |
| 341 | + { |
| 342 | + GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows); |
| 343 | + GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols); |
| 344 | + |
| 345 | + GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows); |
| 346 | + GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols); |
| 347 | + |
| 348 | + bool code = GB_lookup_xoffset (&offset, A, arow, acol) ; |
| 349 | + if (!code) |
| 350 | + { |
| 351 | + continue; |
| 352 | + } |
| 353 | + if (!MTiso) |
| 354 | + cast_A (a_elem, A->x + offset * A->type->size, A->type->size) ; |
| 355 | + |
| 356 | + code = GB_lookup_xoffset (&offset, B, brow, bcol) ; |
| 357 | + if (!code) |
| 358 | + { |
| 359 | + continue; |
| 360 | + } |
| 361 | + if (!MTiso) |
| 362 | + cast_B (b_elem, B->x + offset * B->type->size, B->type->size) ; |
| 363 | + |
| 364 | + if (!MTiso) |
| 365 | + { |
| 366 | + if (op->binop_function) |
| 367 | + { |
| 368 | + op->binop_function (MTx + op->ztype->size * pos, a_elem, b_elem) ; |
| 369 | + } |
| 370 | + else |
| 371 | + { |
| 372 | + GrB_Index ix, iy, jx, jy ; |
| 373 | + ix = A_transpose ? acol : arow ; |
| 374 | + iy = A_transpose ? arow : acol ; |
| 375 | + jx = B_transpose ? bcol : brow ; |
| 376 | + jy = B_transpose ? brow : bcol ; |
| 377 | + op->idxbinop_function (MTx + op->ztype->size * pos, a_elem, ix, iy, |
| 378 | + b_elem, jx, jy, op->theta) ; |
| 379 | + } |
| 380 | + } |
| 381 | + |
| 382 | + if (M->i_is_32) { MTi32[pos] = i ; } else { MTi64[pos] = i ; } |
| 383 | + pos++ ; |
| 384 | + } |
| 385 | + } |
| 386 | + } |
| 387 | + } |
| 388 | + |
| 389 | + #undef GBI |
| 390 | + #undef GBB |
| 391 | + #undef GBP |
| 392 | + #undef GBH |
| 393 | + |
| 394 | + // initialize other fields of MT properly |
| 395 | + |
| 396 | + MT = NULL ; |
| 397 | + GrB_Info MTalloc = GB_new_bix (&MT, op->ztype, vlen, M->vdim, GB_ph_null, M->is_csc, |
| 398 | + GxB_SPARSE, true, M->hyper_switch, M->vdim, centries, true, MTiso, |
| 399 | + M->p_is_32, M->j_is_32, M->i_is_32) ; |
| 400 | + if (MTalloc != GrB_SUCCESS) |
| 401 | + { |
| 402 | + if (MTiso) { GB_free_memory (&MTx, op->ztype->size) ; } |
| 403 | + else { GB_free_memory (&MTx, centries * op->ztype->size) ; } |
| 404 | + goto OUT_OF_MEM_x ; |
| 405 | + } |
| 406 | + |
| 407 | + GB_MATRIX_WAIT(MT) ; |
| 408 | + |
| 409 | + GB_free_memory (&MT->i, MT->i_size) ; |
| 410 | + GB_free_memory (&MT->x, MT->x_size) ; |
| 411 | + |
| 412 | + MT->p = M->p_is_32 ? (void*)MTp32 : (void*)MTp64 ; |
| 413 | + MT->i = M->i_is_32 ? (void*)MTi32 : (void*)MTi64 ; |
| 414 | + MT->x = MTx ; |
| 415 | + |
| 416 | + MT->p_size = (M->p_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * (M->vdim + 1) ; |
| 417 | + MT->i_size = ((M->i_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * centries) ; |
| 418 | + MT->x_size = MT->iso ? op->ztype->size : op->ztype->size * centries ; |
| 419 | + MT->magic = GB_MAGIC ; |
| 420 | + MT->nvals = centries ; |
| 421 | + MT->nvec_nonempty = nvecs ; |
| 422 | + |
| 423 | + // transpose and convert to hyper if needed |
| 424 | + |
| 425 | + if (MT->is_csc != C->is_csc) |
| 426 | + { |
| 427 | + GrB_Info MTtranspose = GB_transpose_in_place (MT, true, Werk) ; |
| 428 | + if (MTtranspose != GrB_SUCCESS) |
| 429 | + { |
| 430 | + GB_FREE_WORKSPACE ; |
| 431 | + GB_Matrix_free (&MT) ; |
| 432 | + return MTtranspose ; |
| 433 | + } |
| 434 | + } |
| 435 | + |
| 436 | + if (MT_hypersparse) |
| 437 | + { |
| 438 | + uint32_t* MTh32 = NULL ; uint64_t* MTh64 = NULL ; |
| 439 | + if (MT->j_is_32) |
| 440 | + { |
| 441 | + MTh32 = GB_malloc_memory (MT->vdim, sizeof(uint32_t), &allocated) ; |
| 442 | + } |
| 443 | + else |
| 444 | + { |
| 445 | + MTh64 = GB_malloc_memory (MT->vdim, sizeof(uint64_t), &allocated) ; |
| 446 | + } |
| 447 | + |
| 448 | + if (MTh32 == NULL && MTh64 == NULL) |
| 449 | + { |
| 450 | + GB_FREE_WORKSPACE ; |
| 451 | + GB_Matrix_free (&MT) ; |
| 452 | + return GrB_OUT_OF_MEMORY ; |
| 453 | + } |
| 454 | + |
| 455 | + #pragma omp parallel for |
| 456 | + for (GrB_Index i = 0; i < MT->vdim; i++) |
| 457 | + { |
| 458 | + if (MT->j_is_32) { MTh32[i] = i ; } else { MTh64[i] = i ; } |
| 459 | + } |
| 460 | + |
| 461 | + MT->h = MTh32 ? (void*)MTh32 : (void*)MTh64 ; |
| 462 | + |
| 463 | + GrB_Info MThyperprune = GB_hyper_prune (MT, Werk) ; |
| 464 | + if (MThyperprune != GrB_SUCCESS) |
| 465 | + { |
| 466 | + GB_FREE_WORKSPACE ; |
| 467 | + GB_Matrix_free (&MT) ; |
| 468 | + return MThyperprune ; |
| 469 | + } |
| 470 | + } |
| 471 | + |
| 472 | + return (GB_accum_mask (C, M, NULL, accum, &MT, C_replace, Mask_comp, Mask_struct, Werk)) ; |
| 473 | + } |
| 474 | + |
107 | 475 | //-------------------------------------------------------------------------- |
108 | 476 | // transpose A and B if requested |
109 | 477 | //-------------------------------------------------------------------------- |
@@ -153,7 +521,7 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B)) |
153 | 521 | GB_CLEAR_MATRIX_HEADER (T, &T_header) ; |
154 | 522 | GB_OK (GB_kroner (T, T_is_csc, op, flipij, |
155 | 523 | A_transpose ? AT : A, A_is_pattern, |
156 | | - B_transpose ? BT : B, B_is_pattern, Werk)) ; |
| 524 | + B_transpose ? BT : B, B_is_pattern, M, Mask_comp, Mask_struct, Werk)) ; |
157 | 525 |
|
158 | 526 | GB_FREE_WORKSPACE ; |
159 | 527 | ASSERT_MATRIX_OK (T, "T = kron(A,B)", GB0) ; |
|
0 commit comments