diff --git a/mssql_python/pybind/ddbc_bindings.cpp b/mssql_python/pybind/ddbc_bindings.cpp index f0a5de75..96d38c6c 100644 --- a/mssql_python/pybind/ddbc_bindings.cpp +++ b/mssql_python/pybind/ddbc_bindings.cpp @@ -68,6 +68,35 @@ inline std::string GetEffectiveCharDecoding(const std::string& userEncoding) { #endif } +// Returns true if VARCHAR columns should be fetched as SQL_C_WCHAR (UTF-16LE) +// instead of SQL_C_CHAR to avoid the lossy ACP conversion on Windows. +// +// On Windows, the ODBC driver converts SQL_C_CHAR data from the server's encoding +// to the system's ANSI code page (e.g., CP1252). This is lossy for characters +// outside the ACP range. When the user requests UTF-8 decoding for SQL_CHAR, +// we fetch as SQL_C_WCHAR (UTF-16LE) which the ODBC driver converts losslessly, +// then decode from UTF-16LE to Python str. +// +// On Linux/macOS, the ODBC driver already returns UTF-8 for SQL_C_CHAR based +// on the system locale, so this workaround is not needed. +inline bool ShouldFetchCharAsWChar(const std::string& charEncoding) { +#if defined(_WIN32) + // Normalize: lowercase and strip '-' and '_' to match all Python codec + // variants ("utf-8", "UTF-8", "utf8", "Utf_8", "UTF_8", etc.) + std::string normalized; + normalized.reserve(charEncoding.size()); + for (char c : charEncoding) { + if (c != '-' && c != '_') { + normalized += static_cast<char>(std::tolower(static_cast<unsigned char>(c))); + } + } + return normalized == "utf8"; +#else + (void)charEncoding; + return false; +#endif +} + namespace PythonObjectCache { py::object get_time_class(); } @@ -275,38 +304,38 @@ struct ArrowSchemaPrivateData { #define ARROW_FLAG_MAP_KEYS_SORTED 4 struct ArrowSchema { - // Array type description - const char* format; - const char* name; - const char* metadata; - int64_t flags; - int64_t n_children; - struct ArrowSchema** children; - struct ArrowSchema* dictionary; - - // Release callback - void (*release)(struct ArrowSchema*); - // Opaque producer-specific data - // Only our child-arrays will set this, so we can give it the correct type - ArrowSchemaPrivateData* private_data; + // Array type description + const char* format; + const char* name; + const char* metadata; + int64_t flags; + int64_t n_children; + struct ArrowSchema** children; + struct ArrowSchema* dictionary; + + // Release callback + void (*release)(struct ArrowSchema*); + // Opaque producer-specific data + // Only our child-arrays will set this, so we can give it the correct type + ArrowSchemaPrivateData* private_data; }; struct ArrowArray { - // Array data description - int64_t length; - int64_t null_count; - int64_t offset; - int64_t n_buffers; - int64_t n_children; - const void** buffers; - struct ArrowArray** children; - struct ArrowArray* dictionary; - - // Release callback - void (*release)(struct ArrowArray*); - // Opaque producer-specific data - // Only our child-arrays will set this, so we can give it the correct type - ArrowArrayPrivateData* private_data; + // Array data description + int64_t length; + int64_t null_count; + int64_t offset; + int64_t n_buffers; + int64_t n_children; + const void** buffers; + struct ArrowArray** children; + struct ArrowArray* dictionary; + + // Release callback + void (*release)(struct ArrowArray*); + // Opaque producer-specific data + // Only our child-arrays will set this, so we can give it the correct type + ArrowArrayPrivateData* private_data; }; #endif // ARROW_C_DATA_INTERFACE @@ -1828,8 +1857,8 @@ SQLRETURN SQLTables_wrap(SqlHandlePtr StatementHandle, const std::wstring& catal { // Release the
GIL during the blocking ODBC catalog call py::gil_scoped_release release; - ret = SQLTables_ptr(StatementHandle->get(), catalogPtr, catalogLen, schemaPtr, - schemaLen, tablePtr, tableLen, tableTypePtr, tableTypeLen); + ret = SQLTables_ptr(StatementHandle->get(), catalogPtr, catalogLen, schemaPtr, schemaLen, + tablePtr, tableLen, tableTypePtr, tableTypeLen); } LOG("SQLTables: Catalog metadata query %s - SQLRETURN=%d", @@ -2036,8 +2065,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, while (offset < totalBytes) { size_t len = std::min(chunkBytes, totalBytes - offset); - rc = putData((SQLPOINTER)(dataPtr + offset), - static_cast<SQLLEN>(len)); + rc = putData((SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len)); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPutData failed for " "SQL_C_CHAR chunk - offset=%zu", @@ -2058,8 +2086,7 @@ SQLRETURN SQLExecute_wrap(const SqlHandlePtr statementHandle, const size_t chunkSize = DAE_CHUNK_SIZE; for (size_t offset = 0; offset < totalBytes; offset += chunkSize) { size_t len = std::min(chunkSize, totalBytes - offset); - rc = putData((SQLPOINTER)(dataPtr + offset), - static_cast<SQLLEN>(len)); + rc = putData((SQLPOINTER)(dataPtr + offset), static_cast<SQLLEN>(len)); if (!SQL_SUCCEEDED(rc)) { LOG("SQLExecute: SQLPutData failed for " "binary/bytes chunk - offset=%zu", @@ -3320,11 +3347,91 @@ SQLRETURN SQLGetData_wrap(SqlHandlePtr StatementHandle, SQLUSMALLINT colCount, p case SQL_LONGVARCHAR: { if (columnSize == SQL_NO_TOTAL || columnSize == 0 || columnSize > SQL_MAX_LOB_SIZE) { - LOG("SQLGetData: Streaming LOB for column %d (SQL_C_CHAR) " "- columnSize=%lu", - i, (unsigned long)columnSize); - row.append( - FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, charEncoding)); + // LOB path: stream the data + if (ShouldFetchCharAsWChar(charEncoding)) { + // LCOV_EXCL_START - Windows-only: ShouldFetchCharAsWChar always false on Linux + // On Windows with UTF-8, fetch LOB VARCHAR as WCHAR to avoid + // lossy ACP conversion + LOG("SQLGetData: Streaming LOB for column %d (SQL_C_WCHAR via " + "UTF-8 workaround) - columnSize=%lu", + i, (unsigned long)columnSize); + row.append( + FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, charEncoding)); + // LCOV_EXCL_STOP + } else { + LOG("SQLGetData: Streaming LOB for column %d (SQL_C_CHAR) " + "- columnSize=%lu", + i, (unsigned long)columnSize); + row.append( + FetchLobColumnData(hStmt, i, SQL_C_CHAR, false, false, charEncoding)); + } + } else if (ShouldFetchCharAsWChar(charEncoding)) { + // LCOV_EXCL_START - Windows-only: ShouldFetchCharAsWChar always false on Linux + // On Windows with UTF-8 decoding: fetch VARCHAR as SQL_C_WCHAR + // to bypass the ODBC driver's lossy ACP (e.g. CP1252) conversion. + // The ODBC driver converts losslessly to UTF-16LE for SQL_C_WCHAR.
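// A minimal standalone sketch (not part of the patch) of how
// ShouldFetchCharAsWChar, defined earlier in this diff, classifies the
// codec-name spellings Python accepts; the demo function name is
// hypothetical.
#include <cassert>
inline void DemoShouldFetchCharAsWChar() {
#if defined(_WIN32)
    // Every UTF-8 spelling normalizes to "utf8" and enables the workaround.
    assert(ShouldFetchCharAsWChar("utf-8"));
    assert(ShouldFetchCharAsWChar("UTF_8"));
    assert(ShouldFetchCharAsWChar("utf8"));
    // Other codecs keep the plain SQL_C_CHAR fetch path.
    assert(!ShouldFetchCharAsWChar("cp1252"));
    assert(!ShouldFetchCharAsWChar("latin-1"));
#else
    // On Linux/macOS the helper is compiled to return false unconditionally.
    assert(!ShouldFetchCharAsWChar("utf-8"));
#endif
}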
+ uint64_t wcharBufSize = (columnSize + 1); // in SQLWCHAR units + std::vector<SQLWCHAR> wdataBuffer(wcharBufSize); + SQLLEN dataLen; + ret = SQLGetData_ptr(hStmt, i, SQL_C_WCHAR, wdataBuffer.data(), + wcharBufSize * sizeof(SQLWCHAR), &dataLen); + if (SQL_SUCCEEDED(ret)) { + if (dataLen > 0) { + uint64_t numCharsInData = dataLen / sizeof(SQLWCHAR); + if (numCharsInData <= columnSize) { +#if defined(_WIN32) + PyObject* pyStr = PyUnicode_FromWideChar( + reinterpret_cast<wchar_t*>(wdataBuffer.data()), numCharsInData); +#else + PyObject* pyStr = PyUnicode_DecodeUTF16( + reinterpret_cast<const char*>(wdataBuffer.data()), + numCharsInData * sizeof(SQLWCHAR), NULL, NULL); +#endif + if (pyStr) { + row.append(py::reinterpret_steal<py::object>(pyStr)); + LOG("SQLGetData: CHAR column %d fetched as WCHAR (UTF-8 " + "workaround), %zu bytes -> decoded", + i, (size_t)dataLen); + } else { + PyErr_Clear(); + LOG_ERROR("SQLGetData: Failed to decode WCHAR data for " + "CHAR column %d", + i); + row.append(py::none()); + } + } else { + // Buffer too small, fallback to LOB streaming + LOG("SQLGetData: CHAR column %d WCHAR data truncated, " + "using streaming LOB", + i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, + charEncoding)); + } + } else if (dataLen == SQL_NULL_DATA) { + LOG("SQLGetData: Column %d is NULL (CHAR via WCHAR)", i); + row.append(py::none()); + } else if (dataLen == 0) { + row.append(py::str("")); + } else if (dataLen == SQL_NO_TOTAL) { + LOG("SQLGetData: SQL_NO_TOTAL for column %d (CHAR via WCHAR), " + "falling back to LOB", + i); + row.append(FetchLobColumnData(hStmt, i, SQL_C_WCHAR, true, false, + charEncoding)); + } else if (dataLen < 0) { + LOG("SQLGetData: Unexpected negative data length " + "for column %d (CHAR via WCHAR) - dataLen=%ld", + i, (long)dataLen); + ThrowStdException("SQLGetData returned an unexpected negative " + "data length"); + } + } else { + LOG("SQLGetData: Error retrieving WCHAR data for CHAR column %d " + "- SQLRETURN=%d, returning NULL", + i, ret); + row.append(py::none()); + } + // LCOV_EXCL_STOP } else { // Allocate columnSize * 4 + 1 on ALL platforms (no #if guard). // @@ -3846,9 +3953,12 @@ SQLRETURN SQLFetchScroll_wrap(SqlHandlePtr StatementHandle, SQLSMALLINT FetchOri // For column in the result set, binds a buffer to retrieve column data // TODO: Move to anonymous namespace, since it is not used outside this file +// charEncoding default is "" so callers that don't pass it (e.g. Arrow path) +// will NOT trigger the WCHAR workaround for VARCHAR columns. SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& columnNames, - SQLUSMALLINT numCols, int fetchSize) { + SQLUSMALLINT numCols, int fetchSize, const std::string& charEncoding = "") { SQLRETURN ret = SQL_SUCCESS; + const bool fetchCharAsWChar = ShouldFetchCharAsWChar(charEncoding); // Bind columns based on their data types for (SQLUSMALLINT col = 1; col <= numCols; col++) { auto columnMeta = columnNames[col - 1].cast<py::dict>(); @@ -3862,29 +3972,41 @@ SQLRETURN SQLBindColums(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& column // TODO: handle variable length data correctly. This logic won't // suffice HandleZeroColumnSizeAtFetch(columnSize); - // Use columnSize * 4 + 1 on Linux/macOS to accommodate UTF-8 - // expansion. The ODBC driver returns UTF-8 for SQL_C_CHAR where - // each character can be up to 4 bytes.
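// A minimal sketch (not part of the patch) of the three buffer-sizing
// regimes this hunk juggles. SQLWCHAR comes from the ODBC headers, the
// helper name is hypothetical, and columnSize is the column's character
// capacity as reported by SQLDescribeCol.
#include <cstdint>
inline uint64_t CharColumnBytesPerRow(uint64_t columnSize, bool fetchAsWChar) {
    if (fetchAsWChar) {
        // Windows UTF-8 workaround: bound as SQL_C_WCHAR, sized in
        // 2-byte SQLWCHAR units plus a null terminator.
        return (columnSize + 1) * sizeof(SQLWCHAR);
    }
#if defined(__APPLE__) || defined(__linux__)
    // SQL_C_CHAR carries UTF-8 here: up to 4 bytes per character.
    return columnSize * 4 + 1;
#else
    // SQL_C_CHAR in the ANSI code page: 1 byte per character.
    return columnSize + 1;
#endif
}
// E.g. VARCHAR(50): 201 bytes per row on Linux/macOS, 51 on Windows without
// the workaround, and 102 when the column is fetched as SQL_C_WCHAR.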
+ if (fetchCharAsWChar) { + // LCOV_EXCL_START - Windows-only: fetchCharAsWChar always false on Linux + // On Windows with UTF-8: bind VARCHAR as SQL_C_WCHAR to + // bypass the ODBC driver's lossy ACP conversion. + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + buffers.wcharBuffers[col - 1].resize(fetchSize * fetchBufferSize); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_WCHAR, buffers.wcharBuffers[col - 1].data(), + fetchBufferSize * sizeof(SQLWCHAR), buffers.indicators[col - 1].data()); + // LCOV_EXCL_STOP + } else { + // Use columnSize * 4 + 1 on Linux/macOS to accommodate UTF-8 + // expansion. The ODBC driver returns UTF-8 for SQL_C_CHAR where + // each character can be up to 4 bytes. #if defined(__APPLE__) || defined(__linux__) - uint64_t fetchBufferSize = columnSize * 4 + 1 /*null-terminator*/; + uint64_t fetchBufferSize = columnSize * 4 + 1 /*null-terminator*/; #else - uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; + uint64_t fetchBufferSize = columnSize + 1 /*null-terminator*/; #endif - // TODO: For LONGVARCHAR/BINARY types, columnSize is returned as - // 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB. - // fetchSize=1 if columnSize>1GB. So we'll allocate a vector of - // size 2GB. If a query fetches multiple (say N) LONG... - // columns, we will have allocated multiple (N) 2GB sized - // vectors. This will make driver very slow. And if the N is - // high enough, we could hit the OS limit for heap memory that - // we can allocate, & hence get a std::bad_alloc. The process - // could also be killed by OS for consuming too much memory. - // Hence this will be revisited in beta to not allocate 2GB+ - // memory, & use streaming instead - buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); - ret = SQLBindCol_ptr(hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), - fetchBufferSize * sizeof(SQLCHAR), - buffers.indicators[col - 1].data()); + // TODO: For LONGVARCHAR/BINARY types, columnSize is returned as + // 2GB-1 by SQLDescribeCol. So fetchBufferSize = 2GB. + // fetchSize=1 if columnSize>1GB. So we'll allocate a vector of + // size 2GB. If a query fetches multiple (say N) LONG... + // columns, we will have allocated multiple (N) 2GB sized + // vectors. This will make driver very slow. And if the N is + // high enough, we could hit the OS limit for heap memory that + // we can allocate, & hence get a std::bad_alloc. The process + // could also be killed by OS for consuming too much memory. 
+ // Hence this will be revisited in beta to not allocate 2GB+ + // memory, & use streaming instead + buffers.charBuffers[col - 1].resize(fetchSize * fetchBufferSize); + ret = SQLBindCol_ptr( + hStmt, col, SQL_C_CHAR, buffers.charBuffers[col - 1].data(), + fetchBufferSize * sizeof(SQLCHAR), buffers.indicators[col - 1].data()); + } break; } case SQL_WCHAR: @@ -4043,6 +4165,7 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum bool isLob; }; std::vector<ColumnInfo> columnInfos(numCols); + const bool fetchCharAsWChar = ShouldFetchCharAsWChar(charEncoding); for (SQLUSMALLINT col = 0; col < numCols; col++) { const auto& columnMeta = columnNames[col].cast<py::dict>(); columnInfos[col].dataType = columnMeta["DataType"].cast<SQLSMALLINT>(); @@ -4051,22 +4174,31 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum std::find(lobColumns.begin(), lobColumns.end(), col + 1) != lobColumns.end(); columnInfos[col].processedColumnSize = columnInfos[col].columnSize; HandleZeroColumnSizeAtFetch(columnInfos[col].processedColumnSize); - // On Linux/macOS, the ODBC driver returns UTF-8 for SQL_C_CHAR where - // each character can be up to 4 bytes. Must match SQLBindColums buffer. -#if defined(__APPLE__) || defined(__linux__) + SQLSMALLINT dt = columnInfos[col].dataType; bool isCharType = (dt == SQL_CHAR || dt == SQL_VARCHAR || dt == SQL_LONGVARCHAR); - if (isCharType) { - columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize * 4 + - 1; // *4 for UTF-8, +1 for null terminator - } else { + + if (fetchCharAsWChar && isCharType) { + // When fetching VARCHAR as WCHAR (UTF-8 workaround on Windows), + // fetchBufferSize is in SQLWCHAR units to match SQLBindColums columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize + 1; // +1 for null terminator - } + } else { + // On Linux/macOS, the ODBC driver returns UTF-8 for SQL_C_CHAR where + // each character can be up to 4 bytes. Must match SQLBindColums buffer. +#if defined(__APPLE__) || defined(__linux__) + if (isCharType) { + columnInfos[col].fetchBufferSize = columnInfos[col].processedColumnSize * 4 + + 1; // *4 for UTF-8, +1 for null terminator + } else { + columnInfos[col].fetchBufferSize = + columnInfos[col].processedColumnSize + 1; // +1 for null terminator + } #else - columnInfos[col].fetchBufferSize = - columnInfos[col].processedColumnSize + 1; // +1 for null terminator + columnInfos[col].fetchBufferSize = + columnInfos[col].processedColumnSize + 1; // +1 for null terminator #endif + } } // Performance: Build function pointer dispatch table (once per batch) @@ -4118,7 +4250,13 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: - columnProcessors[col] = ColumnProcessors::ProcessChar; + // When fetchCharAsWChar is active, VARCHAR data is in wcharBuffers + // (bound as SQL_C_WCHAR) so use the WCHAR processor for decoding.
+ if (fetchCharAsWChar) { + columnProcessors[col] = ColumnProcessors::ProcessWChar; // LCOV_EXCL_LINE - Windows-only + } else { + columnProcessors[col] = ColumnProcessors::ProcessChar; + } break; case SQL_WCHAR: case SQL_WVARCHAR: @@ -4361,7 +4499,8 @@ SQLRETURN FetchBatchData(SQLHSTMT hStmt, ColumnBuffers& buffers, py::list& colum // Given a list of columns that are a part of single row in the result set, // calculates the max size of the row // TODO: Move to anonymous namespace, since it is not used outside this file -size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { +size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols, + const std::string& charEncoding = "") { size_t rowSize = 0; for (SQLUSMALLINT col = 1; col <= numCols; col++) { auto columnMeta = columnNames[col - 1].cast<py::dict>(); @@ -4372,7 +4511,14 @@ size_t calculateRowSize(py::list& columnNames, SQLUSMALLINT numCols) { case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: - rowSize += columnSize; + // When UTF-8 WCHAR workaround is active on Windows, + // VARCHAR is bound as SQL_C_WCHAR (2 bytes per char). + // Account for this in memory estimation. + if (ShouldFetchCharAsWChar(charEncoding)) { + rowSize += columnSize * sizeof(SQLWCHAR); // LCOV_EXCL_LINE - Windows-only + } else { + rowSize += columnSize; + } break; case SQL_SS_XML: case SQL_WCHAR: @@ -4521,7 +4667,7 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch ColumnBuffers buffers(numCols, fetchSize); // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize, charEncoding); if (!SQL_SUCCEEDED(ret)) { LOG("FetchMany_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; @@ -4559,15 +4705,12 @@ SQLRETURN FetchMany_wrap(SqlHandlePtr StatementHandle, py::list& rows, int fetch // @param indicator: Pointer to indicator value (SQL_NULL_DATA for NULL, or data length) // // @return SQLRETURN: SQL_SUCCESS on success, or error code on failure -template <typename T> -SQLRETURN GetDataVar(SQLHSTMT hStmt, - SQLUSMALLINT colNumber, - SQLSMALLINT cType, - std::vector<T>& dataVec, - SQLLEN* indicator) { +template <typename T> +SQLRETURN GetDataVar(SQLHSTMT hStmt, SQLUSMALLINT colNumber, SQLSMALLINT cType, + std::vector<T>& dataVec, SQLLEN* indicator) { size_t start = 0; size_t end = 0; - + // Determine null terminator size based on data type size_t sizeNullTerminator = 0; switch (cType) { @@ -4581,7 +4724,7 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, default: ThrowStdException("GetDataVar only supports SQL_C_CHAR, SQL_C_WCHAR, and SQL_C_BINARY"); } - + // Ensure initial buffer has space for at least the null terminator if (dataVec.size() < sizeNullTerminator) { dataVec.resize(sizeNullTerminator); @@ -4590,13 +4733,9 @@ SQLRETURN GetDataVar(SQLHSTMT hStmt, while (true) { SQLLEN localInd = 0; SQLRETURN ret = SQLGetData_ptr( - hStmt, - colNumber, - cType, - reinterpret_cast<SQLPOINTER>(dataVec.data() + start), + hStmt, colNumber, cType, reinterpret_cast<SQLPOINTER>(dataVec.data() + start), sizeof(T) * (dataVec.size() - start), // Available buffer size from start position - &localInd - ); + &localInd); // Handle NULL data if (localInd == SQL_NULL_DATA) { @@ -4630,10 +4769,10 @@ assert(localInd % sizeof(T) == 0); end = start + static_cast<size_t>(localInd) / sizeof(T) + sizeNullTerminator; } - + // The next read starts where the null terminator would have been placed start = dataVec.size() - sizeNullTerminator; - + // Resize buffer for next iteration
dataVec.resize(end); } else { @@ -4670,17 +4809,14 @@ int32_t days_from_civil(int y, int m, int d) { // Returns number of days since Unix epoch (1970-01-01) y -= m <= 2; const int era = (y >= 0 ? y : y - 399) / 400; - const unsigned yoe = static_cast<unsigned>(y - era * 400); // [0, 399] - const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365] - const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // [0, 146096] + const unsigned yoe = static_cast<unsigned>(y - era * 400); // [0, 399] + const unsigned doy = (153 * (m + (m > 2 ? -3 : 9)) + 2) / 5 + d - 1; // [0, 365] + const unsigned doe = yoe * 365 + yoe / 4 - yoe / 100 + doy; // [0, 146096] return era * 146097 + static_cast<int>(doe) - 719468; } -SQLRETURN FetchArrowBatch_wrap( - SqlHandlePtr StatementHandle, - py::list& capsules, - int arrowBatchSize -) { +SQLRETURN FetchArrowBatch_wrap(SqlHandlePtr StatementHandle, py::list& capsules, + int arrowBatchSize) { // An overly large fetch size doesn't seem to help performance int fetchSize = 64; @@ -4724,15 +4860,14 @@ SQLRETURN FetchArrowBatch_wrap( columnSizes[i] = columnSize; columnNullable[i] = (nullable != SQL_NO_NULLS); - if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || - dataType == SQL_VARCHAR || dataType == SQL_LONGVARCHAR || - dataType == SQL_VARBINARY || dataType == SQL_LONGVARBINARY || - dataType == SQL_SS_XML || dataType == SQL_SS_UDT) && + if ((dataType == SQL_WVARCHAR || dataType == SQL_WLONGVARCHAR || dataType == SQL_VARCHAR || + dataType == SQL_LONGVARCHAR || dataType == SQL_VARBINARY || + dataType == SQL_LONGVARBINARY || dataType == SQL_SS_XML || dataType == SQL_SS_UDT) && (columnSize == 0 || columnSize == SQL_NO_TOTAL || columnSize > SQL_MAX_LOB_SIZE)) { - hasLobColumns = true; - if (fetchSize > 1) { - fetchSize = 1; // LOBs require row-by-row fetch - } + hasLobColumns = true; + if (fetchSize > 1) { + fetchSize = 1; // LOBs require row-by-row fetch + } } std::string columnName = colMeta["ColumnName"].cast<std::string>(); @@ -4741,7 +4876,7 @@ SQLRETURN FetchArrowBatch_wrap( std::memcpy(arrowSchemaPrivateData[i]->name.get(), columnName.c_str(), nameLen); std::string format = ""; - switch(dataType) { + switch (dataType) { case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: @@ -4804,7 +4939,8 @@ case SQL_DECIMAL: case SQL_NUMERIC: { std::ostringstream formatStream; - formatStream << "d:" << columnSize << "," << colMeta["DecimalDigits"].cast<SQLSMALLINT>(); + formatStream << "d:" << columnSize << "," + << colMeta["DecimalDigits"].cast<SQLSMALLINT>(); std::string formatStr = formatStream.str(); size_t formatLen = formatStr.length() + 1; arrowSchemaPrivateData[i]->format = std::make_unique<char[]>(formatLen); @@ -4844,13 +4980,14 @@ break; default: std::ostringstream errorString; - errorString << "Unsupported data type for Arrow batch fetch for column - " << columnName.c_str() - << ", Type - " << dataType << ", column ID - " << (i + 1); + errorString << "Unsupported data type for Arrow batch fetch for column - " + << columnName.c_str() << ", Type - " << dataType << ", column ID - " + << (i + 1); LOG(errorString.str().c_str()); ThrowStdException(errorString.str()); break; } - + // Store format string if not already stored. // For non-decimal types, format is now a static string.
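// A quick standalone check (not part of the patch) of days_from_civil,
// defined earlier in this diff, exercising its Unix-epoch anchoring; the
// demo function name is hypothetical.
#include <cassert>
inline void DemoDaysFromCivil() {
    assert(days_from_civil(1970, 1, 1) == 0);      // Unix epoch
    assert(days_from_civil(1970, 1, 2) == 1);      // next day
    assert(days_from_civil(1969, 12, 31) == -1);   // day before the epoch
    assert(days_from_civil(2000, 3, 1) == 11017);  // 10957 + 31 + 29 (leap year)
}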
if (!arrowSchemaPrivateData[i]->format) { @@ -4869,13 +5006,16 @@ SQLRETURN FetchArrowBatch_wrap( if (!hasLobColumns && fetchSize > 0) { // Bind columns + // Arrow path intentionally omits charEncoding (defaults to "") so that + // ShouldFetchCharAsWChar returns false and VARCHAR stays bound as + // SQL_C_CHAR. Arrow has its own char-to-UTF-8 processing pipeline. ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); if (!SQL_SUCCEEDED(ret)) { LOG("Error when binding columns"); return ret; } } - + SQLULEN numRowsFetched = 0; FetchStateGuard fetchStateGuard(hStmt, &numRowsFetched, fetchSize); @@ -4893,7 +5033,7 @@ SQLRETURN FetchArrowBatch_wrap( ret = SQLFetch_ptr(hStmt); } if (ret == SQL_NO_DATA) { - ret = SQL_SUCCESS; // Normal completion + ret = SQL_SUCCESS; // Normal completion break; } if (!SQL_SUCCEEDED(ret)) { @@ -4912,18 +5052,14 @@ SQLRETURN FetchArrowBatch_wrap( if (hasLobColumns) { assert(idxRowSql == 0 && "GetData only works one row at a time"); - switch(dataType) { + switch (dataType) { case SQL_SS_UDT: case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_BINARY, - buffers.charBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_BINARY, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching BINARY LOB for column %d", idxCol + 1); return ret; @@ -4933,13 +5069,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_CHAR, - buffers.charBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching CHAR LOB for column %d", idxCol + 1); return ret; @@ -4950,13 +5082,9 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_WCHAR: case SQL_WVARCHAR: case SQL_WLONGVARCHAR: { - ret = GetDataVar( - hStmt, - idxCol + 1, - SQL_C_WCHAR, - buffers.wcharBuffers[idxCol], - buffers.indicators[idxCol].data() - ); + ret = GetDataVar(hStmt, idxCol + 1, SQL_C_WCHAR, + buffers.wcharBuffers[idxCol], + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching WCHAR LOB data for column %d", idxCol + 1); return ret; @@ -4966,11 +5094,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_INTEGER: { buffers.intBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SLONG, - buffers.intBuffers[idxCol].data(), - sizeof(SQLINTEGER), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_SLONG, buffers.intBuffers[idxCol].data(), + sizeof(SQLINTEGER), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SLONG data for column %d", idxCol + 1); return ret; @@ -4979,12 +5104,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SMALLINT: { buffers.smallIntBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SSHORT, - buffers.smallIntBuffers[idxCol].data(), - sizeof(SQLSMALLINT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SSHORT, + buffers.smallIntBuffers[idxCol].data(), + sizeof(SQLSMALLINT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SSHORT data for column %d", idxCol + 1); return ret; @@ -4993,12 +5116,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_TINYINT: { buffers.charBuffers[idxCol].resize(1); - ret = 
SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TINYINT, - buffers.charBuffers[idxCol].data(), - sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + ret = + SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TINYINT, + buffers.charBuffers[idxCol].data(), sizeof(SQLCHAR), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TINYINT data for column %d", idxCol + 1); return ret; @@ -5008,11 +5129,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BIT: { buffers.charBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_BIT, - buffers.charBuffers[idxCol].data(), - sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_BIT, buffers.charBuffers[idxCol].data(), + sizeof(SQLCHAR), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching BIT data for column %d", idxCol + 1); return ret; @@ -5022,11 +5140,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_REAL: { buffers.realBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_FLOAT, - buffers.realBuffers[idxCol].data(), - sizeof(SQLREAL), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_FLOAT, buffers.realBuffers[idxCol].data(), + sizeof(SQLREAL), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching FLOAT data for column %d", idxCol + 1); return ret; @@ -5036,12 +5151,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_DECIMAL: case SQL_NUMERIC: { buffers.charBuffers[idxCol].resize(MAX_DIGITS_IN_NUMERIC); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_CHAR, - buffers.charBuffers[idxCol].data(), - MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_CHAR, + buffers.charBuffers[idxCol].data(), + MAX_DIGITS_IN_NUMERIC * sizeof(SQLCHAR), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching CHAR data for column %d", idxCol + 1); return ret; @@ -5051,12 +5164,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_DOUBLE: case SQL_FLOAT: { buffers.doubleBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_DOUBLE, - buffers.doubleBuffers[idxCol].data(), - sizeof(SQLDOUBLE), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_DOUBLE, + buffers.doubleBuffers[idxCol].data(), + sizeof(SQLDOUBLE), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching DOUBLE data for column %d", idxCol + 1); return ret; @@ -5067,12 +5178,10 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { buffers.timestampBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, - buffers.timestampBuffers[idxCol].data(), - sizeof(SQL_TIMESTAMP_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TYPE_TIMESTAMP, + buffers.timestampBuffers[idxCol].data(), + sizeof(SQL_TIMESTAMP_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_TIMESTAMP data for column %d", idxCol + 1); return ret; @@ -5081,12 +5190,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_BIGINT: { buffers.bigIntBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SBIGINT, - buffers.bigIntBuffers[idxCol].data(), - sizeof(SQLBIGINT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SBIGINT, + buffers.bigIntBuffers[idxCol].data(), + sizeof(SQLBIGINT), + 
buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching SBIGINT data for column %d", idxCol + 1); return ret; @@ -5095,12 +5202,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_TYPE_DATE: { buffers.dateBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_TYPE_DATE, - buffers.dateBuffers[idxCol].data(), - sizeof(SQL_DATE_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_TYPE_DATE, + buffers.dateBuffers[idxCol].data(), + sizeof(SQL_DATE_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_DATE data for column %d", idxCol + 1); return ret; @@ -5109,12 +5214,10 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SS_TIME2: { buffers.timeBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SS_TIME2, - buffers.timeBuffers[idxCol].data(), - sizeof(SQL_SS_TIME2_STRUCT), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SS_TIME2, + buffers.timeBuffers[idxCol].data(), + sizeof(SQL_SS_TIME2_STRUCT), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching TYPE_TIME data for column %d", idxCol + 1); return ret; @@ -5124,11 +5227,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_GUID: { buffers.guidBuffers[idxCol].resize(1); ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_GUID, - buffers.guidBuffers[idxCol].data(), - sizeof(SQLGUID), - buffers.indicators[idxCol].data() - ); + hStmt, idxCol + 1, SQL_C_GUID, buffers.guidBuffers[idxCol].data(), + sizeof(SQLGUID), buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { LOG("Error fetching GUID data for column %d", idxCol + 1); return ret; @@ -5137,14 +5237,13 @@ SQLRETURN FetchArrowBatch_wrap( } case SQL_SS_TIMESTAMPOFFSET: { buffers.datetimeoffsetBuffers[idxCol].resize(1); - ret = SQLGetData_ptr( - hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, - buffers.datetimeoffsetBuffers[idxCol].data(), - sizeof(DateTimeOffset), - buffers.indicators[idxCol].data() - ); + ret = SQLGetData_ptr(hStmt, idxCol + 1, SQL_C_SS_TIMESTAMPOFFSET, + buffers.datetimeoffsetBuffers[idxCol].data(), + sizeof(DateTimeOffset), + buffers.indicators[idxCol].data()); if (!SQL_SUCCEEDED(ret)) { - LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", idxCol + 1); + LOG("Error fetching SS_TIMESTAMPOFFSET data for column %d", + idxCol + 1); return ret; } break; @@ -5170,8 +5269,7 @@ SQLRETURN FetchArrowBatch_wrap( // Value buffer for variable length data types needs to be set appropriately // as it will be used by the next non null value - switch (dataType) - { + switch (dataType) { case SQL_CHAR: case SQL_VARCHAR: case SQL_LONGVARCHAR: @@ -5184,7 +5282,8 @@ SQLRETURN FetchArrowBatch_wrap( case SQL_BINARY: case SQL_VARBINARY: case SQL_LONGVARBINARY: - arrowColumnProducer->varVal[idxRowArrow + 1] = arrowColumnProducer->varVal[idxRowArrow]; + arrowColumnProducer->varVal[idxRowArrow + 1] = + arrowColumnProducer->varVal[idxRowArrow]; break; default: break; @@ -5194,7 +5293,9 @@ SQLRETURN FetchArrowBatch_wrap( continue; } else if (indicator < 0) { // Negative value is unexpected, log column index, SQL type & raise exception - LOG("Unexpected negative data length. Column ID - %d, SQL Type - %d, Data Length - %lld", idxCol + 1, dataType, (long long)indicator); + LOG("Unexpected negative data length. 
Column ID - %d, SQL Type - %d, Data " "Length - %lld", + idxCol + 1, dataType, (long long)indicator); ThrowStdException("Unexpected negative data length."); } auto dataLen = static_cast<size_t>(indicator); @@ -5211,7 +5312,9 @@ target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + std::memcpy(&(*target_vec)[start], + &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], + dataLen); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } @@ -5229,7 +5332,9 @@ target_vec->resize(target_vec->size() * 2); } - std::memcpy(&(*target_vec)[start], &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], dataLen); + std::memcpy(&(*target_vec)[start], + &buffers.charBuffers[idxCol][idxRowSql * fetchBufferSize], + dataLen); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLen; break; } @@ -5239,16 +5344,21 @@ case SQL_WLONGVARCHAR: { assert(dataLen % sizeof(SQLWCHAR) == 0); auto dataLenW = dataLen / sizeof(SQLWCHAR); - auto wcharSource = &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; + auto wcharSource = + &buffers.wcharBuffers[idxCol][idxRowSql * (columnSize + 1)]; auto start = arrowColumnProducer->varVal[idxRowArrow]; auto target_vec = &arrowColumnProducer->varData; #if defined(_WIN32) // Convert wide string - int dataLenConverted = WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), NULL, 0, NULL, NULL); + int dataLenConverted = + WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), + NULL, 0, NULL, NULL); while (target_vec->size() < start + dataLenConverted) { target_vec->resize(target_vec->size() * 2); } - WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), reinterpret_cast<char*>(&(*target_vec)[start]), dataLenConverted, NULL, NULL); + WideCharToMultiByte(CP_UTF8, 0, wcharSource, static_cast<int>(dataLenW), + reinterpret_cast<char*>(&(*target_vec)[start]), + dataLenConverted, NULL, NULL); arrowColumnProducer->varVal[idxRowArrow + 1] = start + dataLenConverted; #else // On Unix, use the SQLWCHARToWString utility and then convert to UTF-8 @@ -5262,8 +5372,9 @@ break; } case SQL_GUID: { - // GUID is stored as a 36-character string in Arrow (e.g., "550e8400-e29b-41d4-a716-446655440000") - // Each GUID is exactly 36 bytes in UTF-8 + // GUID is stored as a 36-character string in Arrow (e.g., + // "550e8400-e29b-41d4-a716-446655440000") Each GUID is exactly 36 bytes in + // UTF-8 auto target_vec = &arrowColumnProducer->varData; auto start = arrowColumnProducer->varVal[idxRowArrow]; @@ -5277,37 +5388,40 @@ // Convert GUID to string format: xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx snprintf(reinterpret_cast<char*>(&target_vec->data()[start]), 37, - "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", - guidValue.Data1, - guidValue.Data2, - guidValue.Data3, - guidValue.Data4[0], guidValue.Data4[1], - guidValue.Data4[2], guidValue.Data4[3], - guidValue.Data4[4], guidValue.Data4[5], - guidValue.Data4[6], guidValue.Data4[7]); + "%08X-%04X-%04X-%02X%02X-%02X%02X%02X%02X%02X%02X", + guidValue.Data1, guidValue.Data2, guidValue.Data3, + guidValue.Data4[0], guidValue.Data4[1], guidValue.Data4[2], + guidValue.Data4[3], guidValue.Data4[4], guidValue.Data4[5], + guidValue.Data4[6], guidValue.Data4[7]); // Update offset for next row, ignoring null terminator arrowColumnProducer->varVal[idxRowArrow + 1] =
start + 36; break; } case SQL_TINYINT: - arrowColumnProducer->uint8Val[idxRowArrow] = buffers.charBuffers[idxCol][idxRowSql]; + arrowColumnProducer->uint8Val[idxRowArrow] = + buffers.charBuffers[idxCol][idxRowSql]; break; case SQL_SMALLINT: - arrowColumnProducer->int16Val[idxRowArrow] = buffers.smallIntBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int16Val[idxRowArrow] = + buffers.smallIntBuffers[idxCol][idxRowSql]; break; case SQL_INTEGER: - arrowColumnProducer->int32Val[idxRowArrow] = buffers.intBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int32Val[idxRowArrow] = + buffers.intBuffers[idxCol][idxRowSql]; break; case SQL_BIGINT: - arrowColumnProducer->int64Val[idxRowArrow] = buffers.bigIntBuffers[idxCol][idxRowSql]; + arrowColumnProducer->int64Val[idxRowArrow] = + buffers.bigIntBuffers[idxCol][idxRowSql]; break; case SQL_REAL: - arrowColumnProducer->float32Val[idxRowArrow] = buffers.realBuffers[idxCol][idxRowSql]; + arrowColumnProducer->float32Val[idxRowArrow] = + buffers.realBuffers[idxCol][idxRowSql]; break; case SQL_FLOAT: case SQL_DOUBLE: - arrowColumnProducer->float64Val[idxRowArrow] = buffers.doubleBuffers[idxCol][idxRowSql]; + arrowColumnProducer->float64Val[idxRowArrow] = + buffers.doubleBuffers[idxCol][idxRowSql]; break; case SQL_DECIMAL: case SQL_NUMERIC: { @@ -5321,23 +5435,23 @@ if (digitChar == '-') { sign = -1; } else if (digitChar >= '0' && digitChar <= '9') { - decimalValue = decimalValue.multiply_by_10() + (uint64_t)(digitChar - '0'); + decimalValue = + decimalValue.multiply_by_10() + (uint64_t)(digitChar - '0'); } } - arrowColumnProducer->decimalVal[idxRowArrow] = (sign > 0) ? decimalValue : -decimalValue; + arrowColumnProducer->decimalVal[idxRowArrow] = + (sign > 0) ? decimalValue : -decimalValue; break; } case SQL_TIMESTAMP: case SQL_TYPE_TIMESTAMP: case SQL_DATETIME: { - SQL_TIMESTAMP_STRUCT sql_value = buffers.timestampBuffers[idxCol][idxRowSql]; - int64_t days = days_from_civil( - sql_value.year, - sql_value.month, - sql_value.day - ); - arrowColumnProducer->tsMicroVal[idxRowArrow] = - days * 86400 * 1000000 + + SQL_TIMESTAMP_STRUCT sql_value = + buffers.timestampBuffers[idxCol][idxRowSql]; + int64_t days = + days_from_civil(sql_value.year, sql_value.month, sql_value.day); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + static_cast<int64_t>(sql_value.hour) * 3600 * 1000000 + static_cast<int64_t>(sql_value.minute) * 60 * 1000000 + static_cast<int64_t>(sql_value.second) * 1000000 + @@ -5346,29 +5460,30 @@ } case SQL_SS_TIMESTAMPOFFSET: { DateTimeOffset sql_value = buffers.datetimeoffsetBuffers[idxCol][idxRowSql]; - int64_t days = days_from_civil( - sql_value.year, - sql_value.month, - sql_value.day - ); - arrowColumnProducer->tsMicroVal[idxRowArrow] = - days * 86400 * 1000000 + - (static_cast<int64_t>(sql_value.hour) - static_cast<int64_t>(sql_value.timezone_hour)) * 3600 * 1000000 + - (static_cast<int64_t>(sql_value.minute) - static_cast<int64_t>(sql_value.timezone_minute)) * 60 * 1000000 + + int64_t days = + days_from_civil(sql_value.year, sql_value.month, sql_value.day); + arrowColumnProducer->tsMicroVal[idxRowArrow] = + days * 86400 * 1000000 + + (static_cast<int64_t>(sql_value.hour) - + static_cast<int64_t>(sql_value.timezone_hour)) * + 3600 * 1000000 + + (static_cast<int64_t>(sql_value.minute) - + static_cast<int64_t>(sql_value.timezone_minute)) * + 60 * 1000000 + static_cast<int64_t>(sql_value.second) * 1000000 + static_cast<int64_t>(sql_value.fraction) / 1000; break; } case SQL_TYPE_DATE: - arrowColumnProducer->dateVal[idxRowArrow] = days_from_civil(
buffers.dateBuffers[idxCol][idxRowSql].year, - buffers.dateBuffers[idxCol][idxRowSql].month, - buffers.dateBuffers[idxCol][idxRowSql].day - ); + arrowColumnProducer->dateVal[idxRowArrow] = + days_from_civil(buffers.dateBuffers[idxCol][idxRowSql].year, + buffers.dateBuffers[idxCol][idxRowSql].month, + buffers.dateBuffers[idxCol][idxRowSql].day); break; case SQL_SS_TIME2: { - const SQL_SS_TIME2_STRUCT& timeValue = buffers.timeBuffers[idxCol][idxRowSql]; - arrowColumnProducer->timeNanoVal[idxRowArrow] = + const SQL_SS_TIME2_STRUCT& timeValue = + buffers.timeBuffers[idxCol][idxRowSql]; + arrowColumnProducer->timeNanoVal[idxRowArrow] = static_cast<int64_t>(timeValue.hour) * 3600 * 1000000000 + static_cast<int64_t>(timeValue.minute) * 60 * 1000000000 + static_cast<int64_t>(timeValue.second) * 1000000000 + @@ -5379,11 +5494,11 @@ // SQL_BIT is stored as a single bit in Arrow's bitmap format // Get the boolean value from the buffer bool bitValue = buffers.charBuffers[idxCol][idxRowSql] != 0; - + // Set the bit in the Arrow bitmap size_t byteIndex = idxRowArrow / 8; size_t bitIndex = idxRowArrow % 8; - + if (bitValue) { // Set bit to 1 arrowColumnProducer->bitVal[byteIndex] |= (1 << bitIndex); @@ -5419,7 +5534,7 @@ // Second, transfer ownership to arrowSchemaBatch // No unhandled exceptions until the pycapsule owns the arrowSchemaBatch to avoid memory leaks - + for (SQLSMALLINT i = 0; i < numCols; i++) { *arrowSchemaBatchChildPointers[i] = { arrowSchemaPrivateData[i]->format.get(), @@ -5434,7 +5549,7 @@ assert(schema->release != nullptr); assert(schema->private_data != nullptr); assert(schema->children == nullptr && schema->n_children == 0); - delete schema->private_data; // Frees format and name + delete schema->private_data; // Frees format and name schema->release = nullptr; }, arrowSchemaPrivateData[i].release(), @@ -5477,13 +5592,14 @@ // Finally, transfer ownership of arrowSchemaBatch and its pointer to pycapsule py::capsule arrowSchemaBatchCapsule; try { - arrowSchemaBatchCapsule = py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { - auto arrowSchema = static_cast<ArrowSchema*>(ptr); - if (arrowSchema->release) { - arrowSchema->release(arrowSchema); - } - delete arrowSchema; - }); + arrowSchemaBatchCapsule = + py::capsule(arrowSchemaBatch.get(), "arrow_schema", [](void* ptr) { + auto arrowSchema = static_cast<ArrowSchema*>(ptr); + if (arrowSchema->release) { + arrowSchema->release(arrowSchema); + } + delete arrowSchema; + }); } catch (...)
{ arrowSchemaBatch->release(arrowSchemaBatch.get()); throw; } @@ -5527,7 +5643,7 @@ SQLRETURN FetchArrowBatch_wrap( assert(array->release != nullptr); assert(array->children == nullptr); assert(array->n_children == 0); - delete array->private_data; // Frees all buffer entries + delete array->private_data; // Frees all buffer entries assert(array->buffers != nullptr); array->release = nullptr; }, @@ -5664,7 +5780,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, // No LOBs detected - use binding path with batch fetching // Define a memory limit (1 GB) const size_t memoryLimit = 1ULL * 1024 * 1024 * 1024; - size_t totalRowSize = calculateRowSize(columnNames, numCols); + size_t totalRowSize = calculateRowSize(columnNames, numCols, charEncoding); // Calculate fetch size based on the total row size and memory limit size_t numRowsInMemLimit; @@ -5705,7 +5821,7 @@ SQLRETURN FetchAll_wrap(SqlHandlePtr StatementHandle, py::list& rows, ColumnBuffers buffers(numCols, fetchSize); // Bind columns - ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize); + ret = SQLBindColums(hStmt, buffers, columnNames, numCols, fetchSize, charEncoding); if (!SQL_SUCCEEDED(ret)) { LOG("FetchAll_wrap: Error when binding columns - SQLRETURN=%d", ret); return ret; @@ -5895,7 +6011,8 @@ PYBIND11_MODULE(ddbc_bindings, m) { py::class_<SqlHandle, SqlHandlePtr>(m, "SqlHandle") .def("free", &SqlHandle::free, "Free the handle") - .def("_close_cursor", &SqlHandle::close_cursor, "Internal: close the cursor without freeing the prepared statement"); + .def("_close_cursor", &SqlHandle::close_cursor, + "Internal: close the cursor without freeing the prepared statement"); py::class_(m, "Connection") .def(py::init(), py::arg("conn_str"), @@ -5936,9 +6053,11 @@ PYBIND11_MODULE(ddbc_bindings, m) { m.def("DDBCSQLFetchAll", &FetchAll_wrap, "Fetch all rows from the result set", py::arg("StatementHandle"), py::arg("rows"), py::arg("charEncoding") = "utf-8", py::arg("wcharEncoding") = "utf-16le"); - m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, "Fetch an arrow batch of given length from the result set"); + m.def("DDBCSQLFetchArrowBatch", &FetchArrowBatch_wrap, + "Fetch an arrow batch of given length from the result set"); m.def("DDBCSQLFreeHandle", &SQLFreeHandle_wrap, "Free a handle"); - m.def("DDBCSQLResetStmt", &SQLResetStmt_wrap, "Close cursor and unbind params without freeing HSTMT"); + m.def("DDBCSQLResetStmt", &SQLResetStmt_wrap, + "Close cursor and unbind params without freeing HSTMT"); m.def("DDBCSQLCheckError", &SQLCheckError_Wrap, "Check for driver errors"); m.def("DDBCSQLGetAllDiagRecords", &SQLGetAllDiagRecords, "Get all diagnostic records for a handle", py::arg("handle")); diff --git a/tests/test_013_encoding_decoding.py b/tests/test_013_encoding_decoding.py index 034afae6..31b24daf 100644 --- a/tests/test_013_encoding_decoding.py +++ b/tests/test_013_encoding_decoding.py @@ -7256,5 +7256,347 @@ def test_dae_encoding_large_string(db_connection): cursor.close() +def test_varchar_utf8_collation_unicode_roundtrip(db_connection): + """Test that VARCHAR columns with UTF-8 collation properly round-trip Unicode data. + + This tests the scenario where a VARCHAR column uses a UTF-8 collation + (e.g., Latin1_General_100_CI_AS_SC_UTF8) which enables storing full Unicode + in VARCHAR. The ODBC driver on Windows converts SQL_C_CHAR data to the + system ANSI code page (e.g., CP1252), which is lossy for non-Latin characters. + The fix fetches such columns as SQL_C_WCHAR (UTF-16LE) to preserve all Unicode.
+ + Covers: fetchone, fetchall, fetchmany paths. + """ + cursor = db_connection.cursor() + + try: + # Create table with UTF-8 collation on VARCHAR column + cursor.execute(""" + CREATE TABLE #test_varchar_utf8_collation ( + id INT PRIMARY KEY, + varchar_utf8 VARCHAR(200) COLLATE Latin1_General_100_CI_AS_SC_UTF8, + nvarchar_ref NVARCHAR(200) + ) + """) + + # Configure UTF-8 decoding for SQL_CHAR (VARCHAR) + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + db_connection.setdecoding(SQL_WCHAR, encoding="utf-16le") + + # Test cases covering BMP and supplementary plane characters + test_cases = [ + (1, "Hello World"), # ASCII baseline + (2, "Grüße"), # German - extended Latin (in CP1252 range) + (3, "你好世界"), # Chinese - outside CP1252 + (4, "こんにちは"), # Japanese Hiragana - outside CP1252 + (5, "Привет"), # Russian Cyrillic - outside CP1252 + (6, "Hello 世界"), # Mixed ASCII + CJK + (7, "😀😃😄😁"), # Emoji - supplementary plane (4-byte UTF-8) + (8, "Ελληνικά"), # Greek + (9, "مرحبا"), # Arabic + (10, "café résumé naïve"), # French accented + ] + + # Insert using parameterized queries + for id_val, text in test_cases: + cursor.execute( + "INSERT INTO #test_varchar_utf8_collation (id, varchar_utf8, nvarchar_ref) " + "VALUES (?, ?, ?)", + id_val, + text, + text, + ) + + # ---- Test fetchone path ---- + for id_val, expected_text in test_cases: + cursor.execute( + "SELECT varchar_utf8, nvarchar_ref FROM #test_varchar_utf8_collation WHERE id = ?", + id_val, + ) + row = cursor.fetchone() + assert row is not None, f"No row returned for id={id_val}" + + varchar_result = row[0] + nvarchar_result = row[1] + + # NVARCHAR should always work (baseline check) + assert nvarchar_result == expected_text, ( + f"NVARCHAR mismatch for id={id_val}: " + f"expected {expected_text!r}, got {nvarchar_result!r}" + ) + + # VARCHAR with UTF-8 collation should also return correct str + assert isinstance(varchar_result, str), ( + f"VARCHAR UTF-8 returned {type(varchar_result).__name__} instead of str " + f"for id={id_val} ({expected_text!r}): got {varchar_result!r}" + ) + assert varchar_result == expected_text, ( + f"VARCHAR UTF-8 mismatch for id={id_val}: " + f"expected {expected_text!r}, got {varchar_result!r}" + ) + + # ---- Test fetchall path ---- + cursor.execute( + "SELECT id, varchar_utf8, nvarchar_ref " "FROM #test_varchar_utf8_collation ORDER BY id" + ) + all_rows = cursor.fetchall() + assert len(all_rows) == len( + test_cases + ), f"fetchall row count mismatch: expected {len(test_cases)}, got {len(all_rows)}" + for row, (expected_id, expected_text) in zip(all_rows, test_cases): + assert row[1] == expected_text, ( + f"fetchall VARCHAR UTF-8 mismatch for id={expected_id}: " + f"expected {expected_text!r}, got {row[1]!r}" + ) + assert row[2] == expected_text, ( + f"fetchall NVARCHAR mismatch for id={expected_id}: " + f"expected {expected_text!r}, got {row[2]!r}" + ) + + # ---- Test fetchmany path ---- + cursor.execute( + "SELECT id, varchar_utf8, nvarchar_ref " "FROM #test_varchar_utf8_collation ORDER BY id" + ) + many_rows = cursor.fetchmany(5) + assert len(many_rows) == 5, f"fetchmany(5) returned {len(many_rows)} rows" + for row, (expected_id, expected_text) in zip(many_rows, test_cases[:5]): + assert row[1] == expected_text, ( + f"fetchmany VARCHAR UTF-8 mismatch for id={expected_id}: " + f"expected {expected_text!r}, got {row[1]!r}" + ) + + finally: + try: + cursor.execute("DROP TABLE #test_varchar_utf8_collation") + except: + pass + cursor.close() + + +def test_varchar_utf8_collation_null_and_empty(db_connection): 
+ """Test NULL and empty string handling for VARCHAR UTF-8 collation columns. + + Covers the NULL (SQL_NULL_DATA) and empty string (dataLen==0) branches + in the WCHAR fetch path of SQLGetData_wrap and the batch paths. + """ + cursor = db_connection.cursor() + + try: + cursor.execute(""" + CREATE TABLE #test_utf8_null_empty ( + id INT PRIMARY KEY, + varchar_utf8 VARCHAR(200) COLLATE Latin1_General_100_CI_AS_SC_UTF8, + nvarchar_ref NVARCHAR(200) + ) + """) + + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + db_connection.setdecoding(SQL_WCHAR, encoding="utf-16le") + + # Insert NULL and empty string rows + cursor.execute( + "INSERT INTO #test_utf8_null_empty (id, varchar_utf8, nvarchar_ref) " + "VALUES (?, ?, ?)", + 1, + None, + None, + ) + cursor.execute( + "INSERT INTO #test_utf8_null_empty (id, varchar_utf8, nvarchar_ref) " + "VALUES (?, ?, ?)", + 2, + "", + "", + ) + cursor.execute( + "INSERT INTO #test_utf8_null_empty (id, varchar_utf8, nvarchar_ref) " + "VALUES (?, ?, ?)", + 3, + "hello", + "hello", + ) + + # ---- Test fetchone path (covers lines 3291-3295) ---- + # NULL row + cursor.execute("SELECT varchar_utf8, nvarchar_ref FROM #test_utf8_null_empty WHERE id = 1") + row = cursor.fetchone() + assert row is not None + assert row[0] is None, f"Expected NULL varchar, got {row[0]!r}" + assert row[1] is None, f"Expected NULL nvarchar, got {row[1]!r}" + + # Empty string row + cursor.execute("SELECT varchar_utf8, nvarchar_ref FROM #test_utf8_null_empty WHERE id = 2") + row = cursor.fetchone() + assert row is not None + assert row[0] == "", f"Expected empty varchar, got {row[0]!r}" + assert row[1] == "", f"Expected empty nvarchar, got {row[1]!r}" + + # Normal row + cursor.execute("SELECT varchar_utf8, nvarchar_ref FROM #test_utf8_null_empty WHERE id = 3") + row = cursor.fetchone() + assert row is not None + assert row[0] == "hello", f"Expected 'hello', got {row[0]!r}" + + # ---- Test fetchall path (covers batch bind lines 3854-3859, dispatch line 4127) ---- + cursor.execute("SELECT id, varchar_utf8 FROM #test_utf8_null_empty ORDER BY id") + all_rows = cursor.fetchall() + assert len(all_rows) == 3 + assert all_rows[0][1] is None, f"fetchall: expected NULL, got {all_rows[0][1]!r}" + assert all_rows[1][1] == "", f"fetchall: expected empty, got {all_rows[1][1]!r}" + assert all_rows[2][1] == "hello", f"fetchall: expected 'hello', got {all_rows[2][1]!r}" + + # ---- Test fetchmany path ---- + cursor.execute("SELECT id, varchar_utf8 FROM #test_utf8_null_empty ORDER BY id") + many_rows = cursor.fetchmany(3) + assert len(many_rows) == 3 + assert many_rows[0][1] is None + assert many_rows[1][1] == "" + assert many_rows[2][1] == "hello" + + finally: + try: + cursor.execute("DROP TABLE #test_utf8_null_empty") + except: + pass + cursor.close() + + +def test_varchar_utf8_collation_lob_streaming(db_connection): + """Test VARCHAR(MAX) with UTF-8 collation triggers LOB streaming path. + + VARCHAR(MAX) columns have columnSize=0 or SQL_NO_TOTAL, which routes + through FetchLobColumnData with SQL_C_WCHAR when UTF-8 WCHAR workaround + is active. Covers lines 3237-3241 (LOB WCHAR path). 
+ """ + cursor = db_connection.cursor() + + try: + cursor.execute(""" + CREATE TABLE #test_utf8_lob ( + id INT PRIMARY KEY, + varchar_max_utf8 VARCHAR(MAX) COLLATE Latin1_General_100_CI_AS_SC_UTF8, + nvarchar_max_ref NVARCHAR(MAX) + ) + """) + + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + db_connection.setdecoding(SQL_WCHAR, encoding="utf-16le") + + # Test cases: short, medium, and large Unicode strings + test_cases = [ + (1, "Hello World"), + (2, "你好世界 こんにちは Привет"), + (3, "café résumé naïve Grüße"), + (4, "😀😃😄😁🌍🌎🌏"), + # Large string to exercise LOB streaming loop + (5, "Unicode混合テスト" * 500), + ] + + for id_val, text in test_cases: + cursor.execute( + "INSERT INTO #test_utf8_lob (id, varchar_max_utf8, nvarchar_max_ref) " + "VALUES (?, ?, ?)", + id_val, + text, + text, + ) + + # ---- Test fetchone path (LOB streaming per-row) ---- + for id_val, expected_text in test_cases: + cursor.execute( + "SELECT varchar_max_utf8, nvarchar_max_ref " "FROM #test_utf8_lob WHERE id = ?", + id_val, + ) + row = cursor.fetchone() + assert row is not None, f"No row for id={id_val}" + assert isinstance( + row[0], str + ), f"LOB VARCHAR(MAX) returned {type(row[0]).__name__} for id={id_val}" + assert row[0] == expected_text, ( + f"LOB VARCHAR(MAX) mismatch for id={id_val}: " + f"expected len={len(expected_text)}, got len={len(row[0])}" + ) + assert row[1] == expected_text, f"LOB NVARCHAR(MAX) mismatch for id={id_val}" + + # ---- Test fetchall path (LOB triggers per-row SQLGetData fallback) ---- + cursor.execute("SELECT id, varchar_max_utf8 FROM #test_utf8_lob ORDER BY id") + all_rows = cursor.fetchall() + assert len(all_rows) == len(test_cases) + for row, (expected_id, expected_text) in zip(all_rows, test_cases): + assert row[0] == expected_id + assert row[1] == expected_text, f"fetchall LOB mismatch for id={expected_id}" + + # ---- Test NULL in VARCHAR(MAX) ---- + cursor.execute( + "INSERT INTO #test_utf8_lob (id, varchar_max_utf8, nvarchar_max_ref) " + "VALUES (?, ?, ?)", + 99, + None, + None, + ) + cursor.execute("SELECT varchar_max_utf8 FROM #test_utf8_lob WHERE id = 99") + row = cursor.fetchone() + assert row is not None + assert row[0] is None, f"Expected NULL for VARCHAR(MAX), got {row[0]!r}" + + finally: + try: + cursor.execute("DROP TABLE #test_utf8_lob") + except: + pass + cursor.close() + + +def test_varchar_utf8_collation_encoding_variants(db_connection): + """Test that various UTF-8 encoding name variants all activate the WCHAR workaround. + + Python codec names are case-insensitive and can use '-' or '_' separators. + ShouldFetchCharAsWChar must handle: utf-8, UTF-8, utf8, Utf_8, UTF_8, etc. 
+ """ + cursor = db_connection.cursor() + + try: + cursor.execute(""" + CREATE TABLE #test_utf8_variants ( + id INT PRIMARY KEY, + varchar_utf8 VARCHAR(100) COLLATE Latin1_General_100_CI_AS_SC_UTF8 + ) + """) + + unicode_text = "你好世界" + + # Test multiple encoding name variants + encoding_variants = ["utf-8", "UTF-8", "utf8"] + + for i, enc in enumerate(encoding_variants): + db_connection.setdecoding(SQL_CHAR, encoding=enc) + + cursor.execute("DELETE FROM #test_utf8_variants") + cursor.execute( + "INSERT INTO #test_utf8_variants (id, varchar_utf8) VALUES (?, ?)", + 1, + unicode_text, + ) + cursor.execute("SELECT varchar_utf8 FROM #test_utf8_variants WHERE id = 1") + row = cursor.fetchone() + assert row is not None + assert isinstance( + row[0], str + ), f"Encoding variant '{enc}' returned {type(row[0]).__name__} instead of str" + assert ( + row[0] == unicode_text + ), f"Encoding variant '{enc}' mismatch: expected {unicode_text!r}, got {row[0]!r}" + + # Restore default + db_connection.setdecoding(SQL_CHAR, encoding="utf-8") + + finally: + try: + cursor.execute("DROP TABLE #test_utf8_variants") + except: + pass + cursor.close() + + if __name__ == "__main__": pytest.main([__file__, "-v"])