Skip to content

Commit 2ecc008

Browse files
committed
feat: mask support in kronecker product for intermediate result matrix
1 parent 2c419fb commit 2ecc008

File tree

5 files changed

+457
-6
lines changed

5 files changed

+457
-6
lines changed

Source/kronecker/GB_kron.c

Lines changed: 369 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,71 @@
2323
GB_Matrix_free (&T) ; \
2424
}
2525

26+
#define GBI(Ai,p,avlen) ((Ai == NULL) ? ((p) % (avlen)) : Ai [p])
27+
28+
#define GBB(Ab,p) ((Ab == NULL) ? 1 : Ab [p])
29+
30+
#define GBP(Ap,k,avlen) ((Ap == NULL) ? ((k) * (avlen)) : Ap [k])
31+
32+
#define GBH(Ah,k) ((Ah == NULL) ? (k) : Ah [k])
33+
2634
#include "kronecker/GB_kron.h"
2735
#include "mxm/GB_mxm.h"
2836
#include "transpose/GB_transpose.h"
2937
#include "mask/GB_accum_mask.h"
3038

39+
static bool GB_lookup_xoffset (
40+
GrB_Index* p,
41+
GrB_Matrix A,
42+
GrB_Index row,
43+
GrB_Index col
44+
)
45+
{
46+
GrB_Index vector = A->is_csc ? col : row ;
47+
GrB_Index coord = A->is_csc ? row : col ;
48+
49+
if (A->p == NULL)
50+
{
51+
GrB_Index offset = vector * A->vlen + coord ;
52+
if (A->b == NULL || ((int8_t*)A->b)[offset])
53+
{
54+
*p = A->iso ? 0 : offset ;
55+
return true ;
56+
}
57+
return false ;
58+
}
59+
60+
int64_t start, end ;
61+
bool res ;
62+
63+
if (A->h == NULL)
64+
{
65+
start = A->p_is_32 ? ((uint32_t*)A->p)[vector] : ((uint64_t*)A->p)[vector] ;
66+
end = A->p_is_32 ? ((uint32_t*)A->p)[vector + 1] : ((uint64_t*)A->p)[vector + 1] ;
67+
end-- ;
68+
if (start > end) return false ;
69+
res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
70+
if (res) { *p = A->iso ? 0 : start ; }
71+
return res ;
72+
}
73+
else
74+
{
75+
start = 0 ; end = A->plen - 1 ;
76+
res = GB_binary_search(vector, A->h, A->j_is_32, &start, &end) ;
77+
if (!res) return false ;
78+
int64_t k = start ;
79+
start = A->p_is_32 ? ((uint32_t*)A->p)[k] : ((uint64_t*)A->p)[k] ;
80+
end = A->p_is_32 ? ((uint32_t*)A->p)[k+1] : ((uint64_t*)A->p)[k+1] ;
81+
end-- ;
82+
if (start > end) return false ;
83+
res = GB_binary_search(coord, A->i, A->i_is_32, &start, &end) ;
84+
if (res) { *p = A->iso ? 0 : start ; }
85+
return res ;
86+
}
87+
}
88+
89+
#include "emult/GB_emult.h"
90+
3191
GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
3292
(
3393
GrB_Matrix C, // input/output matrix for results
@@ -104,6 +164,314 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
104164
// quick return if an empty mask is complemented
105165
GB_RETURN_IF_QUICK_MASK (C, C_replace, M, Mask_comp, Mask_struct) ;
106166

167+
// check if it's possible to apply mask immediately in kron
168+
// TODO: make MT of same CSR/CSC format as C
169+
170+
GrB_Matrix MT;
171+
if (M != NULL && !Mask_comp)
172+
{
173+
// iterate over mask, count how many elements will be present in MT
174+
// initialize MT->p
175+
176+
GB_MATRIX_WAIT(M);
177+
178+
size_t allocated = 0 ;
179+
bool MT_hypersparse = (A->h != NULL) || (B->h != NULL);
180+
int64_t centries ;
181+
uint64_t nvecs ;
182+
centries = 0 ;
183+
nvecs = 0 ;
184+
185+
uint32_t* MTp32 = NULL ; uint64_t* MTp64 = NULL ;
186+
MTp32 = M->p_is_32 ? GB_calloc_memory (M->vdim + 1, sizeof(uint32_t), &allocated) : NULL ;
187+
MTp64 = M->p_is_32 ? NULL : GB_calloc_memory (M->vdim + 1, sizeof(uint64_t), &allocated) ;
188+
if (MTp32 == NULL && MTp64 == NULL)
189+
{
190+
OUT_OF_MEM_p:
191+
GB_FREE_WORKSPACE ;
192+
return GrB_OUT_OF_MEMORY ;
193+
}
194+
195+
GrB_Type MTtype = op->ztype ;
196+
const size_t MTsize = MTtype->size ;
197+
GB_void MTscalar [GB_VLA(MTsize)] ;
198+
bool MTiso = GB_emult_iso (MTscalar, MTtype, A, B, op) ;
199+
200+
GB_Mp_DECLARE(Mp, ) ;
201+
GB_Mp_PTR(Mp, M) ;
202+
203+
GB_Mh_DECLARE(Mh, ) ;
204+
GB_Mh_PTR(Mh, M) ;
205+
206+
GB_Mi_DECLARE(Mi, ) ;
207+
GB_Mi_PTR(Mi, M) ;
208+
209+
GB_cast_function cast_A = NULL ;
210+
GB_cast_function cast_B = NULL ;
211+
212+
cast_A = GB_cast_factory (op->xtype->code, A->type->code) ;
213+
cast_B = GB_cast_factory (op->ytype->code, B->type->code) ;
214+
215+
int64_t vlen = M->vlen ;
216+
#pragma omp parallel
217+
{
218+
GrB_Index offset ;
219+
220+
#pragma omp for reduction(+:nvecs)
221+
for (GrB_Index k = 0 ; k < M->nvec ; k++)
222+
{
223+
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;
224+
225+
int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
226+
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
227+
bool nonempty = false ;
228+
for (GrB_Index p = pA_start ; p < pA_end ; p++)
229+
{
230+
if (!GBB (M->b, p)) continue ;
231+
232+
int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
233+
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;
234+
235+
// extract elements from A and B, increment MTp
236+
237+
if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p]))
238+
{
239+
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
240+
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);
241+
242+
GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
243+
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);
244+
245+
bool code = GB_lookup_xoffset(&offset, A, arow, acol) ;
246+
if (!code)
247+
{
248+
continue;
249+
}
250+
251+
code = GB_lookup_xoffset(&offset, B, brow, bcol) ;
252+
if (!code)
253+
{
254+
continue;
255+
}
256+
257+
if (M->p_is_32)
258+
{
259+
(MTp32[j])++ ;
260+
}
261+
else
262+
{
263+
(MTp64[j])++ ;
264+
}
265+
nonempty = true ;
266+
}
267+
}
268+
if (nonempty) nvecs++ ;
269+
}
270+
}
271+
272+
// GB_cumsum for MT->p
273+
274+
double work = M->vdim ;
275+
int nthreads_max = GB_Context_nthreads_max ( ) ;
276+
double chunk = GB_Context_chunk ( ) ;
277+
int cumsum_threads = GB_nthreads (work, chunk, nthreads_max) ;
278+
M->p_is_32 ? GB_cumsum(MTp32, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) :
279+
GB_cumsum(MTp64, M->p_is_32, M->vdim, NULL, cumsum_threads, Werk) ;
280+
281+
centries = M->p_is_32 ? MTp32[M->vdim] : MTp64[M->vdim] ;
282+
283+
uint32_t* MTi32 = NULL ; uint64_t* MTi64 = NULL;
284+
MTi32 = M->i_is_32 ? GB_malloc_memory (centries, sizeof(uint32_t), &allocated) : NULL ;
285+
MTi64 = M->i_is_32 ? NULL : GB_malloc_memory (centries, sizeof(uint64_t), &allocated) ;
286+
287+
if (centries > 0 && MTi32 == NULL && MTi64 == NULL)
288+
{
289+
OUT_OF_MEM_i:
290+
if (M->p_is_32) { GB_free_memory (&MTp32, (M->vdim + 1) * sizeof(uint32_t)) ; }
291+
else { GB_free_memory (&MTp64, (M->vdim + 1) * sizeof(uint64_t)) ; }
292+
goto OUT_OF_MEM_p ;
293+
}
294+
295+
void* MTx = NULL ;
296+
if (!MTiso)
297+
{
298+
MTx = GB_malloc_memory (centries, op->ztype->size, &allocated) ;
299+
}
300+
else
301+
{
302+
MTx = GB_malloc_memory (1, op->ztype->size, &allocated) ;
303+
if (MTx == NULL) goto OUT_OF_MEM_x ;
304+
memcpy (MTx, MTscalar, MTsize) ;
305+
}
306+
307+
if (centries > 0 && MTx == NULL)
308+
{
309+
OUT_OF_MEM_x:
310+
if (M->i_is_32) { GB_free_memory (&MTi32, centries * sizeof(uint32_t)) ; }
311+
else { GB_free_memory (&MTi64, centries * sizeof (uint64_t)) ; }
312+
goto OUT_OF_MEM_i ;
313+
}
314+
315+
#pragma omp parallel
316+
{
317+
GrB_Index offset ;
318+
GB_void a_elem[op->xtype->size] ;
319+
GB_void b_elem[op->ytype->size] ;
320+
321+
#pragma omp for
322+
for (GrB_Index k = 0 ; k < M->nvec ; k++)
323+
{
324+
GrB_Index j = Mh32 ? GBH (Mh32, k) : GBH (Mh64, k) ;
325+
326+
int64_t pA_start = Mp32 ? GBP (Mp32, k, vlen) : GBP(Mp64, k, vlen) ;
327+
int64_t pA_end = Mp32 ? GBP (Mp32, k+1, vlen) : GBP(Mp64, k+1, vlen) ;
328+
GrB_Index pos = M->p_is_32 ? MTp32[j] : MTp64[j] ;
329+
for (GrB_Index p = pA_start ; p < pA_end ; p++)
330+
{
331+
if (!GBB (M->b, p)) continue ;
332+
333+
int64_t i = Mi32 ? GBI (Mi32, p, vlen) : GBI (Mi64, p, vlen) ;
334+
GrB_Index Mrow = M->is_csc ? i : j ; GrB_Index Mcol = M->is_csc ? j : i ;
335+
336+
// extract elements from A and B,
337+
// initialize offset in MTi and MTx,
338+
// get result of op, place it in MTx
339+
340+
if (Mask_struct || (M->iso ? ((int8_t*)M->x)[0] : ((int8_t*)M->x)[p]))
341+
{
342+
GrB_Index arow = A_transpose ? (Mcol / bncols) : (Mrow / bnrows);
343+
GrB_Index acol = A_transpose ? (Mrow / bnrows) : (Mcol / bncols);
344+
345+
GrB_Index brow = B_transpose ? (Mcol % bncols) : (Mrow % bnrows);
346+
GrB_Index bcol = B_transpose ? (Mrow % bnrows) : (Mcol % bncols);
347+
348+
bool code = GB_lookup_xoffset (&offset, A, arow, acol) ;
349+
if (!code)
350+
{
351+
continue;
352+
}
353+
if (!MTiso)
354+
cast_A (a_elem, A->x + offset * A->type->size, A->type->size) ;
355+
356+
code = GB_lookup_xoffset (&offset, B, brow, bcol) ;
357+
if (!code)
358+
{
359+
continue;
360+
}
361+
if (!MTiso)
362+
cast_B (b_elem, B->x + offset * B->type->size, B->type->size) ;
363+
364+
if (!MTiso)
365+
{
366+
if (op->binop_function)
367+
{
368+
op->binop_function (MTx + op->ztype->size * pos, a_elem, b_elem) ;
369+
}
370+
else
371+
{
372+
GrB_Index ix, iy, jx, jy ;
373+
ix = A_transpose ? acol : arow ;
374+
iy = A_transpose ? arow : acol ;
375+
jx = B_transpose ? bcol : brow ;
376+
jy = B_transpose ? brow : bcol ;
377+
op->idxbinop_function (MTx + op->ztype->size * pos, a_elem, ix, iy,
378+
b_elem, jx, jy, op->theta) ;
379+
}
380+
}
381+
382+
if (M->i_is_32) { MTi32[pos] = i ; } else { MTi64[pos] = i ; }
383+
pos++ ;
384+
}
385+
}
386+
}
387+
}
388+
389+
#undef GBI
390+
#undef GBB
391+
#undef GBP
392+
#undef GBH
393+
394+
// initialize other fields of MT properly
395+
396+
MT = NULL ;
397+
GrB_Info MTalloc = GB_new_bix (&MT, op->ztype, vlen, M->vdim, GB_ph_null, M->is_csc,
398+
GxB_SPARSE, true, M->hyper_switch, M->vdim, centries, true, MTiso,
399+
M->p_is_32, M->j_is_32, M->i_is_32) ;
400+
if (MTalloc != GrB_SUCCESS)
401+
{
402+
if (MTiso) { GB_free_memory (&MTx, op->ztype->size) ; }
403+
else { GB_free_memory (&MTx, centries * op->ztype->size) ; }
404+
goto OUT_OF_MEM_x ;
405+
}
406+
407+
GB_MATRIX_WAIT(MT) ;
408+
409+
GB_free_memory (&MT->i, MT->i_size) ;
410+
GB_free_memory (&MT->x, MT->x_size) ;
411+
412+
MT->p = M->p_is_32 ? (void*)MTp32 : (void*)MTp64 ;
413+
MT->i = M->i_is_32 ? (void*)MTi32 : (void*)MTi64 ;
414+
MT->x = MTx ;
415+
416+
MT->p_size = (M->p_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * (M->vdim + 1) ;
417+
MT->i_size = ((M->i_is_32 ? sizeof(uint32_t) : sizeof(uint64_t)) * centries) ;
418+
MT->x_size = MT->iso ? op->ztype->size : op->ztype->size * centries ;
419+
MT->magic = GB_MAGIC ;
420+
MT->nvals = centries ;
421+
MT->nvec_nonempty = nvecs ;
422+
423+
// transpose and convert to hyper if needed
424+
425+
if (MT->is_csc != C->is_csc)
426+
{
427+
GrB_Info MTtranspose = GB_transpose_in_place (MT, true, Werk) ;
428+
if (MTtranspose != GrB_SUCCESS)
429+
{
430+
GB_FREE_WORKSPACE ;
431+
GB_Matrix_free (&MT) ;
432+
return MTtranspose ;
433+
}
434+
}
435+
436+
if (MT_hypersparse)
437+
{
438+
uint32_t* MTh32 = NULL ; uint64_t* MTh64 = NULL ;
439+
if (MT->j_is_32)
440+
{
441+
MTh32 = GB_malloc_memory (MT->vdim, sizeof(uint32_t), &allocated) ;
442+
}
443+
else
444+
{
445+
MTh64 = GB_malloc_memory (MT->vdim, sizeof(uint64_t), &allocated) ;
446+
}
447+
448+
if (MTh32 == NULL && MTh64 == NULL)
449+
{
450+
GB_FREE_WORKSPACE ;
451+
GB_Matrix_free (&MT) ;
452+
return GrB_OUT_OF_MEMORY ;
453+
}
454+
455+
#pragma omp parallel for
456+
for (GrB_Index i = 0; i < MT->vdim; i++)
457+
{
458+
if (MT->j_is_32) { MTh32[i] = i ; } else { MTh64[i] = i ; }
459+
}
460+
461+
MT->h = MTh32 ? (void*)MTh32 : (void*)MTh64 ;
462+
463+
GrB_Info MThyperprune = GB_hyper_prune (MT, Werk) ;
464+
if (MThyperprune != GrB_SUCCESS)
465+
{
466+
GB_FREE_WORKSPACE ;
467+
GB_Matrix_free (&MT) ;
468+
return MThyperprune ;
469+
}
470+
}
471+
472+
return (GB_accum_mask (C, M, NULL, accum, &MT, C_replace, Mask_comp, Mask_struct, Werk)) ;
473+
}
474+
107475
//--------------------------------------------------------------------------
108476
// transpose A and B if requested
109477
//--------------------------------------------------------------------------
@@ -153,7 +521,7 @@ GrB_Info GB_kron // C<M> = accum (C, kron(A,B))
153521
GB_CLEAR_MATRIX_HEADER (T, &T_header) ;
154522
GB_OK (GB_kroner (T, T_is_csc, op, flipij,
155523
A_transpose ? AT : A, A_is_pattern,
156-
B_transpose ? BT : B, B_is_pattern, Werk)) ;
524+
B_transpose ? BT : B, B_is_pattern, M, Mask_comp, Mask_struct, Werk)) ;
157525

158526
GB_FREE_WORKSPACE ;
159527
ASSERT_MATRIX_OK (T, "T = kron(A,B)", GB0) ;

0 commit comments

Comments
 (0)