// Type codes stored in NumpyColType.type by get_numpy_col_type().
// In that function the kind character 'm' selects NUMPY_TIMEDELTA,
// 'M' selects NUMPY_DATETIME and 'O' selects NUMPY_OBJECT.
3535#define NUMPY_TIMEDELTA 12
3636#define NUMPY_DATETIME 13
3737#define NUMPY_OBJECT 14
// Selected when the kind character is 'S'; the element length is then
// parsed from the digits that follow the kind character.
38+ #define NUMPY_BYTES 15
3839
3940#define MYSQL_FLAG_NOT_NULL 1
4041#define MYSQL_FLAG_PRI_KEY 2
339340
// Jump to the enclosing function's `error` label when x evaluates to a
// negative value.
// NOTE(review): the expansion is a bare `if` that already ends in `;`, so
// `CHECKRC(f());` leaves an extra empty statement and the macro cannot be
// used safely as the body of an unbraced if/else. Wrapping it in
// `do { ... } while (0)` would fix that, but requires every call site to
// supply its own semicolon -- confirm against the callers before changing.
340341#define CHECKRC (x ) if ((x) < 0) goto error;
341342
// Element type descriptor returned by get_numpy_col_type().
343+ typedef struct {
// One of the NUMPY_* type codes. 0 is the "unknown / error" value that
// callers test for.
344+ int type ;
// Element width in bytes: a fixed value per numeric type, or the length
// parsed from the type string for NUMPY_BYTES.
345+ Py_ssize_t length ;
346+ } NumpyColType ;
347+
342348typedef struct {
343349 int results_type ;
344350 int parse_json ;
@@ -2646,8 +2652,8 @@ static char *get_array_base_address(PyObject *py_array) {
26462652}
26472653
26482654
// get_numpy_col_type: classify the element type of a NumPy array.
//
// This change widens the return value from a bare int type code to a
// NumpyColType that carries both the NUMPY_* code (`type`) and the element
// width in bytes (`length`).
//
// The visible code dispatches on str[1] (the kind character) and, for the
// 'i', 'u' and 'f' kinds, on str[2] (a size digit). How `str` is obtained
// lies in lines hidden by the hunk boundary; presumably it is the array's
// __array_interface__ typestr (e.g. "<i8", "|S10") -- TODO confirm.
//
// Returns the filled-in NumpyColType. On the error path both fields are 0,
// and callers in this file treat type == 0 as "unable to get column type".
2649- static int get_numpy_col_type (PyObject * py_array ) {
2650- int out = 0 ;
2655+ static NumpyColType get_numpy_col_type (PyObject * py_array ) {
// Zero-initialised so that any path which never assigns a type reports 0.
2656+ NumpyColType out = { 0 } ;
26512657 char * str = NULL ;
26522658 PyObject * py_array_interface = NULL ;
26532659 PyObject * py_typestr = NULL ;
@@ -2665,58 +2671,79 @@ static int get_numpy_col_type(PyObject *py_array) {
26652671
// str[0] is skipped here; str[1] selects the kind.
26662672 switch (str [1 ]) {
26672673 case 'b' :
2668- out = NUMPY_BOOL ;
2674+ out .type = NUMPY_BOOL ;
2675+ out .length = 1 ;
26692676 break ;
26702677 case 'i' :
// Signed integers: str[2] is the width digit. There is no default case,
// so an unrecognised digit leaves `out` as initialised (type 0) unless
// the hidden lines change it.
26712678 switch (str [2 ]) {
26722679 case '1' :
2673- out = NUMPY_INT8 ;
2680+ out .type = NUMPY_INT8 ;
2681+ out .length = 1 ;
26742682 break ;
26752683 case '2' :
2676- out = NUMPY_INT16 ;
2684+ out .type = NUMPY_INT16 ;
2685+ out .length = 2 ;
26772686 break ;
26782687 case '4' :
2679- out = NUMPY_INT32 ;
2688+ out .type = NUMPY_INT32 ;
2689+ out .length = 4 ;
26802690 break ;
26812691 case '8' :
2682- out = NUMPY_INT64 ;
2692+ out .type = NUMPY_INT64 ;
2693+ out .length = 8 ;
26832694 break ;
26842695 }
26852696 break ;
26862697 case 'u' :
// Unsigned integers: same layout as the signed case, also without a
// default branch.
26872698 switch (str [2 ]) {
26882699 case '1' :
2689- out = NUMPY_UINT8 ;
2700+ out .type = NUMPY_UINT8 ;
2701+ out .length = 1 ;
26902702 break ;
26912703 case '2' :
2692- out = NUMPY_UINT16 ;
2704+ out .type = NUMPY_UINT16 ;
2705+ out .length = 2 ;
26932706 break ;
26942707 case '4' :
2695- out = NUMPY_UINT32 ;
2708+ out .type = NUMPY_UINT32 ;
2709+ out .length = 4 ;
26962710 break ;
26972711 case '8' :
2698- out = NUMPY_UINT64 ;
2712+ out .type = NUMPY_UINT64 ;
2713+ out .length = 8 ;
26992714 break ;
27002715 }
27012716 break ;
27022717 case 'f' :
// Floats: only 4- and 8-byte widths are recognised.
27032718 switch (str [2 ]) {
27042719 case '4' :
2705- out = NUMPY_FLOAT32 ;
2720+ out .type = NUMPY_FLOAT32 ;
2721+ out .length = 4 ;
27062722 break ;
27072723 case '8' :
2708- out = NUMPY_FLOAT64 ;
2724+ out .type = NUMPY_FLOAT64 ;
2725+ out .length = 8 ;
27092726 break ;
27102727 }
27112728 break ;
27122729 case 'O' :
// NOTE(review): length is hard-coded to 8, presumably the size of a
// PyObject * slot on 64-bit builds -- confirm whether
// sizeof(PyObject *) is what is intended.
2713- out = NUMPY_OBJECT ;
2730+ out .type = NUMPY_OBJECT ;
2731+ out .length = 8 ;
27142732 break ;
27152733 case 'm' :
2716- out = NUMPY_TIMEDELTA ;
2734+ out .type = NUMPY_TIMEDELTA ;
2735+ out .length = 8 ;
27172736 break ;
27182737 case 'M' :
2719- out = NUMPY_DATETIME ;
2738+ out .type = NUMPY_DATETIME ;
2739+ out .length = 8 ;
2740+ break ;
2741+ case 'S' :
// Bytes kind: everything after the kind character is parsed as a
// base-10 length.
// NOTE(review): strtol is called with a NULL end pointer and errno is
// not checked, so a missing or non-numeric suffix yields length 0
// rather than an error; only a negative result is rejected.
2742+ out .type = NUMPY_BYTES ;
2743+ out .length = (Py_ssize_t )strtol (str + 2 , NULL , 10 );
2744+ if (out .length < 0 ) {
2745+ goto error ;
2746+ }
27202747 break ;
27212748 default :
// Any other kind character is unsupported.
27222749 goto error ;
@@ -2730,7 +2757,8 @@ static int get_numpy_col_type(PyObject *py_array) {
27302757 return out ;
27312758
27322759error :
// Error path: clear both fields so callers see type == 0, then jump to the
// `exit` label (not visible in these hunks).
2733- out = 0 ;
2760+ out .type = 0 ;
2761+ out .length = 0 ;
27342762 goto exit ;
27352763}
27362764
@@ -2774,7 +2802,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
27742802 unsigned long long j = 0 ;
27752803 char * * cols = NULL ;
27762804 char * * masks = NULL ;
2777- int * col_types = NULL ;
2805+ NumpyColType * col_types = NULL ;
27782806 int64_t * row_ids = NULL ;
27792807
27802808 // Parse function args.
@@ -2847,7 +2875,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
28472875 // Get column array memory
28482876 cols = calloc (sizeof (char * ), n_cols );
28492877 if (!cols ) goto error ;
2850- col_types = calloc (sizeof (int ), n_cols );
2878+ col_types = calloc (sizeof (NumpyColType ), n_cols );
28512879 if (!col_types ) goto error ;
28522880 masks = calloc (sizeof (char * ), n_cols );
28532881 if (!masks ) goto error ;
@@ -2865,7 +2893,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
28652893 }
28662894
28672895 col_types [i ] = get_numpy_col_type (py_data );
2868- if (!col_types [i ]) {
2896+ if (!col_types [i ]. type ) {
28692897 PyErr_SetString (PyExc_ValueError , "unable to get column type of data column" );
28702898 goto error ;
28712899 }
@@ -2874,7 +2902,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
28742902 if (!py_mask ) goto error ;
28752903
28762904 masks [i ] = get_array_base_address (py_mask );
2877- if (masks [i ] && get_numpy_col_type (py_mask ) != NUMPY_BOOL ) {
2905+ if (masks [i ] && get_numpy_col_type (py_mask ). type != NUMPY_BOOL ) {
28782906 PyErr_SetString (PyExc_ValueError , "mask must only contain boolean values" );
28792907 goto error ;
28802908 }
@@ -2958,7 +2986,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
29582986
29592987 case MYSQL_TYPE_TINY :
29602988 CHECKMEM (1 );
2961- switch (col_types [i ]) {
2989+ switch (col_types [i ]. type ) {
29622990 case NUMPY_BOOL :
29632991 i8 = * (int8_t * )(cols [i ] + j * 1 );
29642992 CHECK_TINYINT (i8 , 0 );
@@ -3025,7 +3053,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
30253053 // Use negative to indicate unsigned
30263054 case - MYSQL_TYPE_TINY :
30273055 CHECKMEM (1 );
3028- switch (col_types [i ]) {
3056+ switch (col_types [i ]. type ) {
30293057 case NUMPY_BOOL :
30303058 i8 = * (int8_t * )(cols [i ] + j * 1 );
30313059 CHECK_UNSIGNED_TINYINT (i8 , 0 );
@@ -3091,7 +3119,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
30913119
30923120 case MYSQL_TYPE_SHORT :
30933121 CHECKMEM (2 );
3094- switch (col_types [i ]) {
3122+ switch (col_types [i ]. type ) {
30953123 case NUMPY_BOOL :
30963124 i8 = * (int8_t * )(cols [i ] + j * 1 );
30973125 CHECK_SMALLINT (i8 , 0 );
@@ -3158,7 +3186,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
31583186 // Use negative to indicate unsigned
31593187 case - MYSQL_TYPE_SHORT :
31603188 CHECKMEM (2 );
3161- switch (col_types [i ]) {
3189+ switch (col_types [i ]. type ) {
31623190 case NUMPY_BOOL :
31633191 i8 = * (int8_t * )(cols [i ] + j * 1 );
31643192 CHECK_UNSIGNED_SMALLINT (i8 , 0 );
@@ -3224,7 +3252,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
32243252
32253253 case MYSQL_TYPE_INT24 :
32263254 CHECKMEM (4 );
3227- switch (col_types [i ]) {
3255+ switch (col_types [i ]. type ) {
32283256 case NUMPY_BOOL :
32293257 i8 = * (int8_t * )(cols [i ] + j * 1 );
32303258 CHECK_MEDIUMINT (i8 , 0 );
@@ -3290,7 +3318,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
32903318
32913319 case MYSQL_TYPE_LONG :
32923320 CHECKMEM (4 );
3293- switch (col_types [i ]) {
3321+ switch (col_types [i ]. type ) {
32943322 case NUMPY_BOOL :
32953323 i8 = * (int8_t * )(cols [i ] + j * 1 );
32963324 CHECK_INT (i8 , 0 );
@@ -3357,7 +3385,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
33573385 // Use negative to indicate unsigned
33583386 case - MYSQL_TYPE_INT24 :
33593387 CHECKMEM (4 );
3360- switch (col_types [i ]) {
3388+ switch (col_types [i ]. type ) {
33613389 case NUMPY_BOOL :
33623390 i8 = * (int8_t * )(cols [i ] + j * 1 );
33633391 CHECK_UNSIGNED_MEDIUMINT (i8 , 0 );
@@ -3424,7 +3452,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
34243452 // Use negative to indicate unsigned
34253453 case - MYSQL_TYPE_LONG :
34263454 CHECKMEM (4 );
3427- switch (col_types [i ]) {
3455+ switch (col_types [i ]. type ) {
34283456 case NUMPY_BOOL :
34293457 i8 = * (int8_t * )(cols [i ] + j * 1 );
34303458 CHECK_UNSIGNED_INT (i8 , 0 );
@@ -3490,7 +3518,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
34903518
34913519 case MYSQL_TYPE_LONGLONG :
34923520 CHECKMEM (8 );
3493- switch (col_types [i ]) {
3521+ switch (col_types [i ]. type ) {
34943522 case NUMPY_BOOL :
34953523 i8 = * (int8_t * )(cols [i ] + j * 1 );
34963524 CHECK_BIGINT (i8 , 0 );
@@ -3557,7 +3585,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
35573585 // Use negative to indicate unsigned
35583586 case - MYSQL_TYPE_LONGLONG :
35593587 CHECKMEM (8 );
3560- switch (col_types [i ]) {
3588+ switch (col_types [i ]. type ) {
35613589 case NUMPY_BOOL :
35623590 i8 = * (int8_t * )(cols [i ] + j * 1 );
35633591 CHECK_UNSIGNED_BIGINT (i8 , 0 );
@@ -3623,7 +3651,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
36233651
36243652 case MYSQL_TYPE_FLOAT :
36253653 CHECKMEM (4 );
3626- switch (col_types [i ]) {
3654+ switch (col_types [i ]. type ) {
36273655 case NUMPY_BOOL :
36283656 flt = (float )((is_null ) ? 0 : * (int8_t * )(cols [i ] + j * 1 ));
36293657 break ;
@@ -3667,7 +3695,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
36673695
36683696 case MYSQL_TYPE_DOUBLE :
36693697 CHECKMEM (8 );
3670- switch (col_types [i ]) {
3698+ switch (col_types [i ]. type ) {
36713699 case NUMPY_BOOL :
36723700 dbl = (double )((is_null ) ? 0 : * (int8_t * )(cols [i ] + j * 1 ));
36733701 break ;
@@ -3742,7 +3770,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
37423770
37433771 case MYSQL_TYPE_YEAR :
37443772 CHECKMEM (2 );
3745- switch (col_types [i ]) {
3773+ switch (col_types [i ]. type ) {
37463774 case NUMPY_BOOL :
37473775 i8 = * (int8_t * )(cols [i ] + j * 1 );
37483776 CHECK_YEAR (i8 );
@@ -3817,7 +3845,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
38173845 case MYSQL_TYPE_MEDIUM_BLOB :
38183846 case MYSQL_TYPE_LONG_BLOB :
38193847 case MYSQL_TYPE_BLOB :
3820- if (col_types [i ] != NUMPY_OBJECT ) {
3848+ if (col_types [i ]. type != NUMPY_OBJECT ) {
38213849 PyErr_SetString (PyExc_ValueError , "unsupported numpy data type for character output types" );
38223850 goto error ;
38233851 }
@@ -3873,7 +3901,7 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
38733901 case - MYSQL_TYPE_MEDIUM_BLOB :
38743902 case - MYSQL_TYPE_LONG_BLOB :
38753903 case - MYSQL_TYPE_BLOB :
3876- if (col_types [i ] != NUMPY_OBJECT ) {
3904+ if (col_types [i ]. type != NUMPY_OBJECT && col_types [ i ]. type != NUMPY_BYTES ) {
38773905 PyErr_SetString (PyExc_ValueError , "unsupported numpy data type for binary output types" );
38783906 goto error ;
38793907 }
@@ -3884,6 +3912,24 @@ static PyObject *dump_rowdat_1_numpy(PyObject *self, PyObject *args, PyObject *k
38843912 memcpy (out + out_idx , & i64 , 8 );
38853913 out_idx += 8 ;
38863914
3915+ } else if (col_types [i ].type == NUMPY_BYTES ) {
3916+ void * bytes = (void * )(cols [i ] + j * 8 );
3917+
3918+ if (bytes == NULL ) {
3919+ CHECKMEM (8 );
3920+ i64 = 0 ;
3921+ memcpy (out + out_idx , & i64 , 8 );
3922+ out_idx += 8 ;
3923+ } else {
3924+ Py_ssize_t str_l = col_types [i ].length ;
3925+ CHECKMEM (8 + str_l );
3926+ i64 = str_l ;
3927+ memcpy (out + out_idx , & i64 , 8 );
3928+ out_idx += 8 ;
3929+ memcpy (out + out_idx , bytes , str_l );
3930+ out_idx += str_l ;
3931+ }
3932+
38873933 } else {
38883934 u64 = * (uint64_t * )(cols [i ] + j * 8 );
38893935
0 commit comments