diff --git a/README.md b/README.md index e4a1e35d..7b426a6f 100644 --- a/README.md +++ b/README.md @@ -151,6 +151,7 @@ The following switches have different behavior in this version of `sqlcmd` compa - To provide the value of the host name in the server certificate when using strict encryption, pass the host name with `-F`. Example: `-Ns -F myhost.domain.com` - More information about client/server encryption negotiation can be found at - `-u` The generated Unicode output file will have the UTF16 Little-Endian Byte-order mark (BOM) written to it. +- `-f` Specifies the code page for input and output files. See [Code Page Support](#code-page-support) below for details and examples. - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. - All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. - `-i` doesn't handle a comma `,` in a file name correctly unless the file name argument is triple quoted. For example: @@ -255,6 +256,68 @@ To see a list of available styles along with colored syntax samples, use this co :list color ``` +### Code Page Support + +The `-f` flag specifies the code page for reading input files and writing output. This is useful when working with SQL scripts saved in legacy encodings or when output needs to be in a specific encoding. + +#### Format + +``` +-f codepage # Set both input and output to the same codepage +-f i:codepage # Set input codepage only +-f o:codepage # Set output codepage only +-f i:codepage,o:codepage # Set input and output to different codepages +-f o:codepage,i:codepage # Same as above (order doesn't matter) +``` + +#### Common Code Pages + +| Code Page | Name | Description | +|-----------|------|-------------| +| 65001 | UTF-8 | Unicode (UTF-8) - default for most modern systems | +| 1200 | UTF-16LE | Unicode (UTF-16 Little-Endian) | +| 1201 | UTF-16BE | Unicode (UTF-16 Big-Endian) | +| 1252 | Windows-1252 | Western European (Windows) | +| 932 | Shift_JIS | Japanese | +| 936 | GBK | Chinese Simplified | +| 949 | EUC-KR | Korean | +| 950 | Big5 | Chinese Traditional | +| 437 | CP437 | OEM United States (DOS) | + +#### Examples + +**Run a script saved in Windows-1252 encoding:** +```bash +sqlcmd -S myserver -i legacy_script.sql -f 1252 +``` + +**Read UTF-16 input file and write UTF-8 output:** +```bash +sqlcmd -S myserver -i unicode_script.sql -o results.txt -f i:1200,o:65001 +``` + +**Process a Japanese Shift-JIS encoded script:** +```bash +sqlcmd -S myserver -i japanese_data.sql -f 932 +``` + +**Write output in Windows-1252 for legacy applications:** +```bash +sqlcmd -S myserver -Q "SELECT * FROM Products" -o report.txt -f o:1252 +``` + +**List all supported code pages:** +```bash +sqlcmd --list-codepages +``` + +#### Notes + +- When no `-f` flag is specified, sqlcmd auto-detects UTF-8/UTF-16LE/UTF-16BE BOM (Byte Order Mark) in input files and switches to the appropriate decoder. If no BOM is present, UTF-8 is assumed. +- UTF-8 input files with BOM are handled automatically. +- On Windows, additional codepages installed on the system are available via the Windows API, even if not shown by `--list-codepages`. +- Use `--list-codepages` to see the built-in code pages with their names and descriptions. 
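+
+The output code page also applies to files opened with the `:Out` and `:Error` commands during an interactive session. A minimal sketch (server, table, and file names are illustrative):
+
+```bash
+sqlcmd -S myserver -f o:1252
+1> :Out legacy_report.txt
+2> SELECT ProductName FROM Products
+3> GO
+```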
+ ### Packages #### sqlcmd executable diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index 7d69b24b..5bf59494 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -82,6 +82,11 @@ type SQLCmdArguments struct { ChangePassword string ChangePasswordAndExit string TraceFile string + CodePage string + // codePageSettings stores the parsed CodePageSettings after validation. + // This avoids parsing CodePage twice (in Validate and run). + codePageSettings *sqlcmd.CodePageSettings + ListCodePages bool // Keep Help at the end of the list Help bool } @@ -171,6 +176,12 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true) case a.ServerCertificate != "" && !encryptConnectionAllowsTLS(a.EncryptConnection): err = localizer.Errorf("The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict).") + case a.CodePage != "": + if codePageSettings, parseErr := sqlcmd.ParseCodePage(a.CodePage); parseErr != nil { + err = localizer.Errorf(`'-f %s': %v`, a.CodePage, parseErr) + } else { + a.codePageSettings = codePageSettings + } } } if err != nil { @@ -239,6 +250,17 @@ func Execute(version string) { listLocalServers() os.Exit(0) } + // List supported codepages + if args.ListCodePages { + fmt.Println(localizer.Sprintf("Supported Code Pages:")) + fmt.Println() + fmt.Printf("%-8s %-20s %s\n", "Code", "Name", "Description") + fmt.Printf("%-8s %-20s %s\n", "----", "----", "-----------") + for _, cp := range sqlcmd.SupportedCodePages() { + fmt.Printf("%-8d %-20s %s\n", cp.CodePage, cp.Name, cp.Description) + } + os.Exit(0) + } if len(argss) > 0 { fmt.Printf("%s'%s': Unknown command. Enter '--help' for command help.", sqlcmdErrorPrefix, argss[0]) os.Exit(1) @@ -479,6 +501,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().BoolVarP(&args.EnableColumnEncryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption")) rootCmd.Flags().StringVarP(&args.ChangePassword, "change-password", "z", "", localizer.Sprintf("New password")) rootCmd.Flags().StringVarP(&args.ChangePasswordAndExit, "change-password-exit", "Z", "", localizer.Sprintf("New password and exit")) + rootCmd.Flags().StringVarP(&args.CodePage, "code-page", "f", "", localizer.Sprintf("Specifies the code page for input/output. Use 65001 for UTF-8. 
Format: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage]")) + rootCmd.Flags().BoolVar(&args.ListCodePages, "list-codepages", false, localizer.Sprintf("List supported code pages and exit")) } func setScriptVariable(v string) string { @@ -817,6 +841,11 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { defer s.StopCloseHandler() s.UnicodeOutputFile = args.UnicodeOutputFile + // Apply codepage settings (already parsed and validated in Validate) + if args.codePageSettings != nil { + s.CodePage = args.codePageSettings + } + if args.DisableCmd != nil { s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 511816b2..cfdbcf31 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -123,6 +123,22 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool { return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem" }}, + // Codepage flag tests + {[]string{"-f", "65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "65001" + }}, + {[]string{"-f", "i:1252,o:65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "i:1252,o:65001" + }}, + {[]string{"-f", "o:65001,i:1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "o:65001,i:1252" + }}, + {[]string{"--code-page", "1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "1252" + }}, + {[]string{"--list-codepages"}, func(args SQLCmdArguments) bool { + return args.ListCodePages + }}, } for _, test := range commands { @@ -178,6 +194,11 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "disable", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "strict", "-F", "myserver.domain.com", "-J", "/path/to/cert.pem"}, "The -F and the -J options are mutually exclusive."}, + // Codepage validation tests + {[]string{"-f", "invalid"}, `'-f invalid': invalid codepage: invalid`}, + {[]string{"-f", "99999"}, `'-f 99999': unsupported codepage 99999`}, + {[]string{"-f", "i:invalid"}, `'-f i:invalid': invalid input codepage: i:invalid`}, + {[]string{"-f", "x:1252"}, `'-f x:1252': invalid codepage: x:1252`}, } for _, test := range commands { diff --git a/pkg/sqlcmd/codepage.go b/pkg/sqlcmd/codepage.go new file mode 100644 index 00000000..fc44e5b5 --- /dev/null +++ b/pkg/sqlcmd/codepage.go @@ -0,0 +1,214 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "sort" + "strconv" + "strings" + + "github.com/microsoft/go-sqlcmd/internal/localizer" + "golang.org/x/text/encoding" + "golang.org/x/text/encoding/charmap" + "golang.org/x/text/encoding/japanese" + "golang.org/x/text/encoding/korean" + "golang.org/x/text/encoding/simplifiedchinese" + "golang.org/x/text/encoding/traditionalchinese" + "golang.org/x/text/encoding/unicode" +) + +// codepageEntry defines a codepage with its encoding and metadata +type codepageEntry struct { + encoding encoding.Encoding // nil for UTF-8 (Go's native encoding) + name string + description string +} + +// codepageRegistry is the single source of truth for all supported codepages +// that work cross-platform. 
Both GetEncoding and SupportedCodePages use this +// registry. On Windows, additional codepages installed on the system are also +// available via the Windows API fallback in GetEncoding. +var codepageRegistry = map[int]codepageEntry{ + // Unicode + 65001: {nil, "UTF-8", "Unicode (UTF-8)"}, + 1200: {unicode.UTF16(unicode.LittleEndian, unicode.UseBOM), "UTF-16LE", "Unicode (UTF-16 Little-Endian)"}, + 1201: {unicode.UTF16(unicode.BigEndian, unicode.UseBOM), "UTF-16BE", "Unicode (UTF-16 Big-Endian)"}, + + // OEM/DOS codepages + 437: {charmap.CodePage437, "CP437", "OEM United States"}, + 850: {charmap.CodePage850, "CP850", "OEM Multilingual Latin 1"}, + 852: {charmap.CodePage852, "CP852", "OEM Latin 2"}, + 855: {charmap.CodePage855, "CP855", "OEM Cyrillic"}, + 858: {charmap.CodePage858, "CP858", "OEM Multilingual Latin 1 + Euro"}, + 860: {charmap.CodePage860, "CP860", "OEM Portuguese"}, + 862: {charmap.CodePage862, "CP862", "OEM Hebrew"}, + 863: {charmap.CodePage863, "CP863", "OEM Canadian French"}, + 865: {charmap.CodePage865, "CP865", "OEM Nordic"}, + 866: {charmap.CodePage866, "CP866", "OEM Russian"}, + + // Windows codepages + 874: {charmap.Windows874, "Windows-874", "Thai"}, + 1250: {charmap.Windows1250, "Windows-1250", "Central European"}, + 1251: {charmap.Windows1251, "Windows-1251", "Cyrillic"}, + 1252: {charmap.Windows1252, "Windows-1252", "Western European"}, + 1253: {charmap.Windows1253, "Windows-1253", "Greek"}, + 1254: {charmap.Windows1254, "Windows-1254", "Turkish"}, + 1255: {charmap.Windows1255, "Windows-1255", "Hebrew"}, + 1256: {charmap.Windows1256, "Windows-1256", "Arabic"}, + 1257: {charmap.Windows1257, "Windows-1257", "Baltic"}, + 1258: {charmap.Windows1258, "Windows-1258", "Vietnamese"}, + + // ISO-8859 codepages + 28591: {charmap.ISO8859_1, "ISO-8859-1", "Latin 1 (Western European)"}, + 28592: {charmap.ISO8859_2, "ISO-8859-2", "Latin 2 (Central European)"}, + 28593: {charmap.ISO8859_3, "ISO-8859-3", "Latin 3 (South European)"}, + 28594: {charmap.ISO8859_4, "ISO-8859-4", "Latin 4 (North European)"}, + 28595: {charmap.ISO8859_5, "ISO-8859-5", "Cyrillic"}, + 28596: {charmap.ISO8859_6, "ISO-8859-6", "Arabic"}, + 28597: {charmap.ISO8859_7, "ISO-8859-7", "Greek"}, + 28598: {charmap.ISO8859_8, "ISO-8859-8", "Hebrew"}, + 28599: {charmap.ISO8859_9, "ISO-8859-9", "Turkish"}, + 28600: {charmap.ISO8859_10, "ISO-8859-10", "Nordic"}, + 28603: {charmap.ISO8859_13, "ISO-8859-13", "Baltic"}, + 28604: {charmap.ISO8859_14, "ISO-8859-14", "Celtic"}, + 28605: {charmap.ISO8859_15, "ISO-8859-15", "Latin 9 (Western European with Euro)"}, + 28606: {charmap.ISO8859_16, "ISO-8859-16", "Latin 10 (South-Eastern European)"}, + + // Cyrillic + 20866: {charmap.KOI8R, "KOI8-R", "Russian"}, + 21866: {charmap.KOI8U, "KOI8-U", "Ukrainian"}, + + // Macintosh + 10000: {charmap.Macintosh, "Macintosh", "Mac Roman"}, + 10007: {charmap.MacintoshCyrillic, "x-mac-cyrillic", "Mac Cyrillic"}, + + // EBCDIC + 37: {charmap.CodePage037, "IBM037", "EBCDIC US-Canada"}, + 1047: {charmap.CodePage1047, "IBM1047", "EBCDIC Latin 1/Open System"}, + 1140: {charmap.CodePage1140, "IBM01140", "EBCDIC US-Canada with Euro"}, + + // Japanese + 932: {japanese.ShiftJIS, "Shift_JIS", "Japanese (Shift-JIS)"}, + 20932: {japanese.EUCJP, "EUC-JP", "Japanese (EUC)"}, + 50220: {japanese.ISO2022JP, "ISO-2022-JP", "Japanese (JIS)"}, + 50221: {japanese.ISO2022JP, "csISO2022JP", "Japanese (JIS-Allow 1 byte Kana)"}, + 50222: {japanese.ISO2022JP, "ISO-2022-JP", "Japanese (JIS-Allow 1 byte Kana SO/SI)"}, + + // Korean + 949: {korean.EUCKR, "EUC-KR", 
"Korean"}, + 51949: {korean.EUCKR, "EUC-KR", "Korean (EUC)"}, + + // Simplified Chinese + 936: {simplifiedchinese.GBK, "GBK", "Chinese Simplified (GBK)"}, + 54936: {simplifiedchinese.GB18030, "GB18030", "Chinese Simplified (GB18030)"}, + 52936: {simplifiedchinese.HZGB2312, "HZ-GB-2312", "Chinese Simplified (HZ)"}, + + // Traditional Chinese + 950: {traditionalchinese.Big5, "Big5", "Chinese Traditional (Big5)"}, +} + +// CodePageSettings holds the input and output codepage settings +type CodePageSettings struct { + InputCodePage int + OutputCodePage int +} + +// ParseCodePage parses the -f codepage argument +// Format: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage] +func ParseCodePage(arg string) (*CodePageSettings, error) { + if arg == "" { + return nil, nil + } + + settings := &CodePageSettings{} + parts := strings.Split(arg, ",") + + for _, part := range parts { + part = strings.TrimSpace(part) + if part == "" { + continue + } + + if strings.HasPrefix(strings.ToLower(part), "i:") { + // Input codepage + cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "i:")) + if err != nil { + return nil, localizer.Errorf("invalid input codepage: %s", part) + } + settings.InputCodePage = cp + } else if strings.HasPrefix(strings.ToLower(part), "o:") { + // Output codepage + cp, err := strconv.Atoi(strings.TrimPrefix(strings.ToLower(part), "o:")) + if err != nil { + return nil, localizer.Errorf("invalid output codepage: %s", part) + } + settings.OutputCodePage = cp + } else { + // Both input and output + cp, err := strconv.Atoi(part) + if err != nil { + return nil, localizer.Errorf("invalid codepage: %s", part) + } + settings.InputCodePage = cp + settings.OutputCodePage = cp + } + } + + // If a non-empty argument was provided but no codepage was parsed, + // treat this as an error rather than silently disabling codepage handling. + if settings.InputCodePage == 0 && settings.OutputCodePage == 0 { + return nil, localizer.Errorf("invalid codepage: %s", arg) + } + + // Validate codepages + if settings.InputCodePage != 0 { + if _, err := GetEncoding(settings.InputCodePage); err != nil { + return nil, err + } + } + if settings.OutputCodePage != 0 { + if _, err := GetEncoding(settings.OutputCodePage); err != nil { + return nil, err + } + } + + return settings, nil +} + +// GetEncoding returns the encoding for a given Windows codepage number. +// Returns nil for UTF-8 (65001) since Go uses UTF-8 natively. +// If the codepage is not in the built-in registry, falls back to +// OS-specific support (Windows API on Windows, error on other platforms). 
+func GetEncoding(codepage int) (encoding.Encoding, error) { + entry, ok := codepageRegistry[codepage] + if !ok { + // Fallback to system-provided codepage support + return getSystemCodePageEncoding(codepage) + } + return entry.encoding, nil +} + +// CodePageInfo describes a supported codepage +type CodePageInfo struct { + CodePage int + Name string + Description string +} + +// SupportedCodePages returns a list of all supported codepages with descriptions +func SupportedCodePages() []CodePageInfo { + result := make([]CodePageInfo, 0, len(codepageRegistry)) + for cp, entry := range codepageRegistry { + result = append(result, CodePageInfo{ + CodePage: cp, + Name: entry.name, + Description: entry.description, + }) + } + // Sort by codepage number for consistent output + sort.Slice(result, func(i, j int) bool { + return result[i].CodePage < result[j].CodePage + }) + return result +} diff --git a/pkg/sqlcmd/codepage_other.go b/pkg/sqlcmd/codepage_other.go new file mode 100644 index 00000000..2908778f --- /dev/null +++ b/pkg/sqlcmd/codepage_other.go @@ -0,0 +1,23 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +//go:build !windows + +package sqlcmd + +import ( + "strconv" + + "github.com/microsoft/go-sqlcmd/internal/localizer" + "golang.org/x/text/encoding" +) + +// getSystemCodePageEncoding returns an error on non-Windows platforms +// since we don't have access to Windows API for codepage conversion. +// The built-in codepageRegistry covers the most common codepages. +// For additional codepages (e.g., Japanese EBCDIC), use Windows. +func getSystemCodePageEncoding(codepage int) (encoding.Encoding, error) { + // Use %s with strconv.Itoa to avoid locale-based number formatting + // that would add thousands separators (e.g., "99,999" instead of "99999") + return nil, localizer.Errorf("unsupported codepage %s", strconv.Itoa(codepage)) +} diff --git a/pkg/sqlcmd/codepage_test.go b/pkg/sqlcmd/codepage_test.go new file mode 100644 index 00000000..6844b0af --- /dev/null +++ b/pkg/sqlcmd/codepage_test.go @@ -0,0 +1,391 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +package sqlcmd + +import ( + "strconv" + "testing" + + "github.com/stretchr/testify/assert" + "golang.org/x/text/transform" +) + +func TestParseCodePage(t *testing.T) { + tests := []struct { + name string + arg string + wantInput int + wantOutput int + wantErr bool + errContains string + }{ + { + name: "empty string", + arg: "", + wantInput: 0, + wantOutput: 0, + wantErr: false, + }, + { + name: "single codepage sets both", + arg: "65001", + wantInput: 65001, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input only", + arg: "i:1252", + wantInput: 1252, + wantOutput: 0, + wantErr: false, + }, + { + name: "output only", + arg: "o:65001", + wantInput: 0, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input and output", + arg: "i:1252,o:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "output and input reversed", + arg: "o:65001,i:1252", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "uppercase prefix", + arg: "I:1252,O:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "invalid codepage number", + arg: "abc", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "invalid input codepage", + arg: "i:abc", + wantErr: true, + errContains: "invalid input codepage", + }, + { + name: "invalid output codepage", + arg: "o:xyz", + wantErr: true, + errContains: "invalid output codepage", + }, + { + name: "unsupported codepage", + arg: "99999", + wantErr: true, + errContains: "codepage", // Error message varies by platform + }, + { + name: "comma only produces no codepage", + arg: ",", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "whitespace only produces no codepage", + arg: " ", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "multiple commas produce no codepage", + arg: ",,,", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "Japanese Shift JIS", + arg: "932", + wantInput: 932, + wantOutput: 932, + wantErr: false, + }, + { + name: "Chinese GBK", + arg: "936", + wantInput: 936, + wantOutput: 936, + wantErr: false, + }, + { + name: "Korean", + arg: "949", + wantInput: 949, + wantOutput: 949, + wantErr: false, + }, + { + name: "Traditional Chinese Big5", + arg: "950", + wantInput: 950, + wantOutput: 950, + wantErr: false, + }, + { + name: "EBCDIC", + arg: "37", + wantInput: 37, + wantOutput: 37, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings, err := ParseCodePage(tt.arg) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + if tt.arg == "" { + assert.Nil(t, settings) + return + } + assert.NotNil(t, settings) + assert.Equal(t, tt.wantInput, settings.InputCodePage) + assert.Equal(t, tt.wantOutput, settings.OutputCodePage) + }) + } +} + +func TestGetEncoding(t *testing.T) { + tests := []struct { + codepage int + wantNil bool // UTF-8 returns nil encoding + wantErr bool + }{ + // Unicode + {65001, true, false}, // UTF-8 + {1200, false, false}, // UTF-16LE + {1201, false, false}, // UTF-16BE + + // OEM/DOS + {437, false, false}, + {850, false, false}, + {866, false, false}, + + // Windows + {874, false, false}, + {1250, false, false}, + {1251, false, false}, + {1252, false, false}, + {1253, false, false}, + {1254, false, false}, + {1255, false, false}, + {1256, false, false}, + {1257, false, false}, + {1258, false, false}, + + // ISO-8859 + {28591, 
false, false}, + {28592, false, false}, + {28605, false, false}, + + // Cyrillic + {20866, false, false}, + {21866, false, false}, + + // Macintosh + {10000, false, false}, + {10007, false, false}, + + // EBCDIC + {37, false, false}, + {1047, false, false}, + {1140, false, false}, + + // CJK + {932, false, false}, // Japanese Shift JIS + {20932, false, false}, // EUC-JP + {50220, false, false}, // ISO-2022-JP + {949, false, false}, // Korean EUC-KR + {936, false, false}, // Chinese GBK + {54936, false, false}, // GB18030 + {950, false, false}, // Big5 + + // Invalid + {99999, false, true}, + {12345, false, true}, + } + + for _, tt := range tests { + t.Run(strconv.Itoa(tt.codepage), func(t *testing.T) { + enc, err := GetEncoding(tt.codepage) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + if tt.wantNil { + assert.Nil(t, enc, "UTF-8 should return nil encoding") + } else { + assert.NotNil(t, enc, "non-UTF-8 codepage should return encoding") + } + }) + } +} + +func TestSupportedCodePages(t *testing.T) { + cps := SupportedCodePages() + + // Should have entries + assert.Greater(t, len(cps), 0, "should return codepages") + + // Each returned codepage should be valid in GetEncoding + for _, cp := range cps { + _, err := GetEncoding(cp.CodePage) + assert.NoError(t, err, "SupportedCodePages entry %d should be valid in GetEncoding", cp.CodePage) + assert.NotEmpty(t, cp.Name, "codepage %d should have a name", cp.CodePage) + assert.NotEmpty(t, cp.Description, "codepage %d should have a description", cp.CodePage) + } + + // Result should be sorted by codepage number + for i := 1; i < len(cps); i++ { + assert.Less(t, cps[i-1].CodePage, cps[i].CodePage, "codepages should be sorted") + } + + // Check some well-known codepages are present + known := map[int]bool{ + 65001: false, // UTF-8 + 1252: false, // Windows Western + 437: false, // DOS US + 932: false, // Japanese + } + for _, cp := range cps { + if _, ok := known[cp.CodePage]; ok { + known[cp.CodePage] = true + } + } + for cp, found := range known { + assert.True(t, found, "well-known codepage %d should be in list", cp) + } +} + +func TestGetEncodingWindowsFallback(t *testing.T) { + // Japanese EBCDIC (20290) is not in our built-in registry but is available on Windows + // This test verifies that the Windows API fallback works for codepages not in our registry + cp := 20290 // IBM EBCDIC Japanese Katakana Extended + + enc, err := GetEncoding(cp) + + // On Windows, this should succeed because the Windows API can handle this codepage + // On other platforms, this should fail with a helpful error message + if err != nil { + // Expected on non-Windows platforms + assert.Contains(t, err.Error(), "codepage") + } else { + // Expected on Windows - verify the encoding works + assert.NotNil(t, enc) + + // Test round-trip encoding/decoding + // EBCDIC 'A' is 0xC1 + decoder := enc.NewDecoder() + decoded, err := decoder.String(string([]byte{0xC1})) + assert.NoError(t, err, "decoder should work") + assert.Equal(t, "A", decoded, "EBCDIC 0xC1 should decode to 'A'") + + encoder := enc.NewEncoder() + encoded, err := encoder.String("A") + assert.NoError(t, err, "encoder should work") + assert.Equal(t, []byte{0xC1}, []byte(encoded), "'A' should encode to EBCDIC 0xC1") + } + + // Also test that a completely made-up codepage fails on all platforms + _, err = GetEncoding(99999) + assert.Error(t, err, "invalid codepage should fail on all platforms") + assert.Contains(t, err.Error(), "codepage") +} + +func TestWindowsEncodingStreaming(t 
*testing.T) { + // This test exercises that the Windows API fallback encoding can be used in + // streaming-like scenarios and that it handles single-byte data and + // incomplete UTF-8 input correctly. + + // Japanese EBCDIC (20290) is a good test case as it's only available via Windows API + cp := 20290 // IBM EBCDIC Japanese Katakana Extended + + enc, err := GetEncoding(cp) + if err != nil { + t.Skip("Codepage 20290 not available on this platform") + } + + // Test decoder streaming with transform.Reader + t.Run("decoder streaming", func(t *testing.T) { + // Create a simple EBCDIC encoded string: "ABC" = 0xC1 0xC2 0xC3 + ebcdicData := []byte{0xC1, 0xC2, 0xC3} + + decoder := enc.NewDecoder() + + // Simulate streaming by processing one byte at a time + var result []byte + for i := 0; i < len(ebcdicData); i++ { + decoder.Reset() // Reset between chunks for clean state + dst := make([]byte, 32) + nDst, _, err := decoder.Transform(dst, ebcdicData[i:i+1], i == len(ebcdicData)-1) + if err != nil && err != transform.ErrShortSrc { + t.Fatalf("Transform failed at byte %d: %v", i, err) + } + result = append(result, dst[:nDst]...) + } + assert.Equal(t, "ABC", string(result), "streaming decode should produce 'ABC'") + }) + + // Test encoder streaming + t.Run("encoder streaming", func(t *testing.T) { + // Test encoding "ABC" one character at a time + input := "ABC" + encoder := enc.NewEncoder() + + var result []byte + for i := 0; i < len(input); i++ { + encoder.Reset() // Reset between chunks for clean state + dst := make([]byte, 32) + nDst, _, err := encoder.Transform(dst, []byte(input[i:i+1]), i == len(input)-1) + if err != nil && err != transform.ErrShortSrc { + t.Fatalf("Transform failed at char %d: %v", i, err) + } + result = append(result, dst[:nDst]...) + } + expected := []byte{0xC1, 0xC2, 0xC3} // "ABC" in EBCDIC + assert.Equal(t, expected, result, "streaming encode should produce EBCDIC ABC") + }) + + // Test encoder handles incomplete UTF-8 correctly + t.Run("encoder incomplete UTF-8", func(t *testing.T) { + encoder := enc.NewEncoder() + dst := make([]byte, 32) + + // Send first byte of a 2-byte UTF-8 sequence (é = 0xC3 0xA9) + incompleteUTF8 := []byte{0xC3} // First byte of é + _, _, err := encoder.Transform(dst, incompleteUTF8, false) + // Should return ErrShortSrc because the sequence is incomplete + assert.Equal(t, transform.ErrShortSrc, err, "incomplete UTF-8 should return ErrShortSrc when not at EOF") + + // At EOF, incomplete sequence should be an error + encoder.Reset() + _, _, err = encoder.Transform(dst, incompleteUTF8, true) + assert.Error(t, err, "incomplete UTF-8 at EOF should return error") + }) +} diff --git a/pkg/sqlcmd/codepage_windows.go b/pkg/sqlcmd/codepage_windows.go new file mode 100644 index 00000000..3eb1a746 --- /dev/null +++ b/pkg/sqlcmd/codepage_windows.go @@ -0,0 +1,315 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+ +//go:build windows + +package sqlcmd + +import ( + "errors" + "strconv" + "unicode/utf16" + "unicode/utf8" + "unsafe" + + "github.com/microsoft/go-sqlcmd/internal/localizer" + "golang.org/x/sys/windows" + "golang.org/x/text/encoding" + "golang.org/x/text/transform" +) + +const ( + // MB_ERR_INVALID_CHARS causes MultiByteToWideChar to fail if it encounters + // an invalid character in the source string (including incomplete sequences) + mbErrInvalidChars = 0x00000008 + // Maximum bytes that might form a single character in any Windows codepage + // (most DBCS codepages use 2 bytes, but we use 4 for safety) + maxMultibyteCharLen = 4 +) + +var ( + kernel32 = windows.NewLazySystemDLL("kernel32.dll") + procMultiByteToWideChar = kernel32.NewProc("MultiByteToWideChar") + procWideCharToMultiByte = kernel32.NewProc("WideCharToMultiByte") +) + +// windowsCodePageEncoding implements encoding.Encoding using Windows API +type windowsCodePageEncoding struct { + codepage uint32 +} + +func (e *windowsCodePageEncoding) NewDecoder() *encoding.Decoder { + return &encoding.Decoder{Transformer: &windowsDecoder{codepage: e.codepage}} +} + +func (e *windowsCodePageEncoding) NewEncoder() *encoding.Encoder { + return &encoding.Encoder{Transformer: &windowsEncoder{codepage: e.codepage}} +} + +// windowsDecoder converts from a Windows codepage to UTF-8. +// It buffers incomplete multibyte sequences between Transform calls. +type windowsDecoder struct { + codepage uint32 + buf [maxMultibyteCharLen]byte // buffer for incomplete sequences + bufLen int // number of bytes in buffer +} + +func (d *windowsDecoder) Reset() { + d.bufLen = 0 +} + +func (d *windowsDecoder) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + // Prepend any buffered bytes from previous call + var input []byte + if d.bufLen > 0 { + input = make([]byte, d.bufLen+len(src)) + copy(input, d.buf[:d.bufLen]) + copy(input[d.bufLen:], src) + } else { + input = src + } + + if len(input) == 0 { + return 0, 0, nil + } + + // Try to convert with MB_ERR_INVALID_CHARS to detect incomplete sequences + n, _, errno := procMultiByteToWideChar.Call( + uintptr(d.codepage), + mbErrInvalidChars, + uintptr(unsafe.Pointer(&input[0])), + uintptr(len(input)), + 0, + 0, + ) + + // If conversion failed, it might be due to incomplete trailing sequence + if n == 0 && errno == windows.ERROR_NO_UNICODE_TRANSLATION { + if atEOF { + // At EOF with incomplete sequence - this is an error + d.bufLen = 0 + return 0, len(src), errors.New("incomplete multibyte sequence at end of input") + } + + // Not at EOF - try removing bytes from the end until conversion succeeds + // This finds the incomplete trailing sequence + for trimLen := 1; trimLen <= len(input) && trimLen <= maxMultibyteCharLen; trimLen++ { + tryLen := len(input) - trimLen + if tryLen <= 0 { + // Need more input - buffer what we have + if len(input) <= maxMultibyteCharLen { + copy(d.buf[:], input) + d.bufLen = len(input) + return 0, len(src), transform.ErrShortSrc + } + break + } + + n, _, errno = procMultiByteToWideChar.Call( + uintptr(d.codepage), + mbErrInvalidChars, + uintptr(unsafe.Pointer(&input[0])), + uintptr(tryLen), + 0, + 0, + ) + if n > 0 || errno != windows.ERROR_NO_UNICODE_TRANSLATION { + // Found a valid prefix - buffer the trailing bytes + trailingBytes := input[tryLen:] + copy(d.buf[:], trailingBytes) + d.bufLen = len(trailingBytes) + input = input[:tryLen] + break + } + } + + // If still failing, buffer everything and wait for more + if n == 0 { + if len(input) <= maxMultibyteCharLen 
{ + copy(d.buf[:], input) + d.bufLen = len(input) + return 0, len(src), transform.ErrShortSrc + } + // Input is larger than max char length but still invalid - real error + d.bufLen = 0 + return 0, len(src), errors.New("invalid multibyte sequence") + } + } else if n == 0 { + if errno != windows.ERROR_SUCCESS { + d.bufLen = 0 + return 0, 0, errno + } + d.bufLen = 0 + return 0, 0, errors.New("MultiByteToWideChar failed") + } else { + // Success - clear buffer since we'll consume all input + d.bufLen = 0 + } + + // Allocate wide char buffer and do the actual conversion + wideChars := make([]uint16, n) + n, _, errno = procMultiByteToWideChar.Call( + uintptr(d.codepage), + 0, // Don't use MB_ERR_INVALID_CHARS here - we already validated + uintptr(unsafe.Pointer(&input[0])), + uintptr(len(input)), + uintptr(unsafe.Pointer(&wideChars[0])), + uintptr(len(wideChars)), + ) + if n == 0 { + if errno != windows.ERROR_SUCCESS { + return 0, 0, errno + } + return 0, 0, errors.New("MultiByteToWideChar failed") + } + + // Convert UTF-16 to UTF-8 + runes := utf16.Decode(wideChars[:n]) + utf8Bytes := []byte(string(runes)) + + if len(utf8Bytes) > len(dst) { + return 0, 0, transform.ErrShortDst + } + + copy(dst, utf8Bytes) + return len(utf8Bytes), len(src), err +} + +// windowsEncoder converts from UTF-8 to a Windows codepage. +// It buffers incomplete UTF-8 sequences between Transform calls. +type windowsEncoder struct { + codepage uint32 + buf [utf8.UTFMax]byte // buffer for incomplete UTF-8 sequences + bufLen int // number of bytes in buffer +} + +func (e *windowsEncoder) Reset() { + e.bufLen = 0 +} + +func (e *windowsEncoder) Transform(dst, src []byte, atEOF bool) (nDst, nSrc int, err error) { + // Prepend any buffered bytes from previous call + var input []byte + if e.bufLen > 0 { + input = make([]byte, e.bufLen+len(src)) + copy(input, e.buf[:e.bufLen]) + copy(input[e.bufLen:], src) + } else { + input = src + } + + if len(input) == 0 { + return 0, 0, nil + } + + // Find the last complete UTF-8 sequence + validLen := len(input) + for validLen > 0 && !utf8.Valid(input[:validLen]) { + validLen-- + } + + // Check for incomplete trailing sequence + if validLen < len(input) { + trailingBytes := input[validLen:] + if atEOF { + // At EOF with incomplete UTF-8 - this is an error + e.bufLen = 0 + return 0, len(src), errors.New("incomplete UTF-8 sequence at end of input") + } + // Buffer the incomplete trailing bytes for next call + if len(trailingBytes) <= utf8.UTFMax { + copy(e.buf[:], trailingBytes) + e.bufLen = len(trailingBytes) + } else { + // Shouldn't happen with valid partial UTF-8, but handle it + e.bufLen = 0 + return 0, len(src), errors.New("invalid UTF-8 sequence") + } + input = input[:validLen] + } else { + e.bufLen = 0 + } + + if len(input) == 0 { + // Only incomplete sequence - need more input + return 0, len(src), transform.ErrShortSrc + } + + // Convert UTF-8 to UTF-16 + runes := []rune(string(input)) + wideChars := utf16.Encode(runes) + + if len(wideChars) == 0 { + return 0, len(src), nil + } + + // First call to get required buffer size + n, _, errno := procWideCharToMultiByte.Call( + uintptr(e.codepage), + 0, + uintptr(unsafe.Pointer(&wideChars[0])), + uintptr(len(wideChars)), + 0, + 0, + 0, + 0, + ) + if n == 0 { + if errno != windows.ERROR_SUCCESS { + return 0, 0, errno + } + return 0, 0, errors.New("WideCharToMultiByte failed") + } + + if int(n) > len(dst) { + return 0, 0, transform.ErrShortDst + } + + // Convert to multibyte + n, _, errno = procWideCharToMultiByte.Call( + uintptr(e.codepage), 
+ 0, + uintptr(unsafe.Pointer(&wideChars[0])), + uintptr(len(wideChars)), + uintptr(unsafe.Pointer(&dst[0])), + uintptr(len(dst)), + 0, + 0, + ) + if n == 0 { + if errno != windows.ERROR_SUCCESS { + return 0, 0, errno + } + return 0, 0, errors.New("WideCharToMultiByte failed") + } + + return int(n), len(src), err +} + +// isCodePageValid checks if a codepage is valid/installed on Windows +func isCodePageValid(codepage uint32) bool { + // Try to convert a simple byte - if the codepage is invalid, this will fail + src := []byte{0x41} // 'A' + n, _, _ := procMultiByteToWideChar.Call( + uintptr(codepage), + 0, + uintptr(unsafe.Pointer(&src[0])), + 1, + 0, + 0, + ) + return n > 0 +} + +// getSystemCodePageEncoding returns an encoding using Windows API for codepages +// not in our built-in registry. If the codepage is not available, it returns +// a nil encoding and a non-nil error. +func getSystemCodePageEncoding(codepage int) (encoding.Encoding, error) { + cp := uint32(codepage) + if !isCodePageValid(cp) { + // Use %s with strconv.Itoa to avoid locale-based number formatting + // that would add thousands separators (e.g., "99,999" instead of "99999") + return nil, localizer.Errorf("unsupported codepage %s", strconv.Itoa(codepage)) + } + return &windowsCodePageEncoding{codepage: cp}, nil +} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..c3938997 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -6,6 +6,7 @@ package sqlcmd import ( "flag" "fmt" + "io" "os" "regexp" "sort" @@ -13,10 +14,28 @@ import ( "strings" "github.com/microsoft/go-sqlcmd/internal/color" + "github.com/microsoft/go-sqlcmd/internal/localizer" "golang.org/x/text/encoding/unicode" "golang.org/x/text/transform" ) +// transformWriteCloser wraps a transform.Writer and ensures the underlying +// file is closed when Close() is called. +type transformWriteCloser struct { + *transform.Writer + underlying io.Closer +} + +// Close flushes the transform writer and closes the underlying file. +func (t *transformWriteCloser) Close() error { + // Close the transform writer (flushes pending data) + if err := t.Writer.Close(); err != nil { + _ = t.underlying.Close() + return err + } + return t.underlying.Close() +} + // Command defines a sqlcmd action which can be intermixed with the SQL batch // Commands for sqlcmd are defined at https://docs.microsoft.com/sql/tools/sqlcmd-utility#sqlcmd-commands type Command struct { @@ -324,8 +343,29 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { // ODBC sqlcmd doesn't write a BOM but we will. // Maybe the endian-ness should be configurable. 
win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) - encoder := transform.NewWriter(o, win16le.NewEncoder()) + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, win16le.NewEncoder()), + underlying: o, + } s.SetOutput(encoder) + } else if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + // Use specified output codepage + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + _ = o.Close() + return err + } + if enc != nil { + // Transform from UTF-8 to specified encoding + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, enc.NewEncoder()), + underlying: o, + } + s.SetOutput(encoder) + } else { + // UTF-8, no transformation needed + s.SetOutput(o) + } } else { s.SetOutput(o) } @@ -352,7 +392,28 @@ func errorCommand(s *Sqlcmd, args []string, line uint) error { if err != nil { return InvalidFileError(err, args[0]) } - s.SetError(o) + // Apply output codepage if configured + if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + if cerr := o.Close(); cerr != nil { + return localizer.Errorf("%v; additionally, closing error file %q failed: %v", err, args[0], cerr) + } + return err + } + if enc == nil { + // UTF-8 (or default) encoding: write directly without transform + s.SetError(o) + } else { + encoder := &transformWriteCloser{ + Writer: transform.NewWriter(o, enc.NewEncoder()), + underlying: o, + } + s.SetError(encoder) + } + } else { + s.SetError(o) + } } return nil } diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..dc28333c 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -458,3 +458,99 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestOutputCodePageCommand(t *testing.T) { + tests := []struct { + name string + codepage int + expectedBytes []byte + inputText string + }{ + { + name: "UTF-8 output", + codepage: 65001, + inputText: "café", + expectedBytes: []byte("café"), + }, + { + name: "Windows-1252 output", + codepage: 1252, + inputText: "café", + expectedBytes: []byte{0x63, 0x61, 0x66, 0xe9}, // "café" in Windows-1252 + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Set up codepage + s.CodePage = &CodePageSettings{ + OutputCodePage: tt.codepage, + } + + // Create temp file for output + file, err := os.CreateTemp("", "sqlcmdout") + require.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + fileName := file.Name() + _ = file.Close() + + // Run the OUT command + err = outCommand(s, []string{fileName}, 1) + require.NoError(t, err, "outCommand") + + // Write some text + _, err = s.GetOutput().Write([]byte(tt.inputText)) + require.NoError(t, err, "Write") + + // Close to flush + if closer, ok := s.GetOutput().(interface{ Close() error }); ok { + require.NoError(t, closer.Close(), "Close output") + } + + // Read the file and check encoding + content, err := os.ReadFile(fileName) + require.NoError(t, err, "ReadFile") + assert.Equal(t, tt.expectedBytes, content, "Output encoding mismatch") + }) + } +} + +func TestErrorCodePageCommand(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Set up codepage for Windows-1252 + s.CodePage = &CodePageSettings{ + OutputCodePage: 1252, + } + + // Create temp file for error output + file, err := os.CreateTemp("", "sqlcmderr") + require.NoError(t, err, "os.CreateTemp") + 
defer os.Remove(file.Name()) + fileName := file.Name() + _ = file.Close() + + // Run the ERROR command + err = errorCommand(s, []string{fileName}, 1) + require.NoError(t, err, "errorCommand") + + // Write some text with special characters + _, err = s.err.Write([]byte("Error: café")) + require.NoError(t, err, "Write") + + // Close to flush + if closer, ok := s.err.(interface{ Close() error }); ok { + require.NoError(t, closer.Close(), "Close error") + } + + // Read the file and check encoding + content, err := os.ReadFile(fileName) + require.NoError(t, err, "ReadFile") + // "Error: café" in Windows-1252 + expected := []byte{0x45, 0x72, 0x72, 0x6f, 0x72, 0x3a, 0x20, 0x63, 0x61, 0x66, 0xe9} + assert.Equal(t, expected, content, "Error output encoding mismatch") +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 5e572a94..a76d2ac5 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -86,6 +86,8 @@ type Sqlcmd struct { UnicodeOutputFile bool // EchoInput tells the GO command to print the batch text before running the query EchoInput bool + // CodePage specifies input/output file encoding + CodePage *CodePageSettings colorizer color.Colorizer termchan chan os.Signal } @@ -331,9 +333,38 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error { } defer f.Close() b := s.batch.batchline - utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) - unicodeReader := transform.NewReader(f, utf16bom) - scanner := bufio.NewReader(unicodeReader) + + // Set up the reader with appropriate encoding + var reader io.Reader + if s.CodePage != nil && s.CodePage.InputCodePage != 0 { + // Use specified input codepage + enc, err := GetEncoding(s.CodePage.InputCodePage) + if err != nil { + return err + } + if enc != nil { + // Transform from specified encoding to UTF-8 + // For UTF-16 codepages, wrap with BOMOverride to strip BOM if present + if s.CodePage.InputCodePage == 1200 || s.CodePage.InputCodePage == 1201 { + // UTF-16 LE/BE: use BOMOverride to handle BOM gracefully + reader = transform.NewReader(f, unicode.BOMOverride(enc.NewDecoder())) + } else { + reader = transform.NewReader(f, enc.NewDecoder()) + } + } else { + // UTF-8 codepage (65001): BOMOverride detects UTF-8/UTF-16LE/UTF-16BE BOMs and + // switches decoder accordingly, falling back to UTF-8 when no BOM is present + utf8bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) + reader = transform.NewReader(f, utf8bom) + } + } else { + // Default (no -f flag): BOMOverride detects UTF-8/UTF-16LE/UTF-16BE BOMs at + // the start of input and switches decoder accordingly; falls back to UTF-8 + // when no BOM is present (see golang.org/x/text/encoding/unicode.BOMOverride) + utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder()) + reader = transform.NewReader(f, utf16bom) + } + scanner := bufio.NewReader(reader) curLine := s.batch.read echoFileLines := s.echoFileLines ln := make([]byte, 0, 2*1024*1024) diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index dfe97d1a..a28958ec 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -19,6 +19,7 @@ import ( "github.com/google/uuid" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) const oneRowAffected = "(1 row affected)" @@ -232,6 +233,69 @@ func TestIncludeFileQuotedIdentifiers(t *testing.T) { } } +func TestIncludeFileWithInputCodePage(t *testing.T) { + tests := []struct { + name string + codepage int + fileContent []byte + expectedText string + }{ + { + name: "Windows-1252 input", + codepage: 1252, + 
fileContent: []byte{0x73, 0x65, 0x6c, 0x65, 0x63, 0x74, 0x20, 0x27, 0x63, 0x61, 0x66, 0xe9, 0x27}, // "select 'café'" in Windows-1252 + expectedText: "select 'café'", + }, + { + name: "UTF-16 LE with BOM", + codepage: 1200, + fileContent: []byte{0xFF, 0xFE, 0x68, 0x00, 0x69, 0x00}, // BOM + "hi" in UTF-16 LE + expectedText: "hi", + }, + { + name: "UTF-16 LE without BOM", + codepage: 1200, + fileContent: []byte{0x68, 0x00, 0x69, 0x00}, // "hi" in UTF-16 LE (no BOM) + expectedText: "hi", + }, + { + name: "UTF-16 BE with BOM", + codepage: 1201, + fileContent: []byte{0xFE, 0xFF, 0x00, 0x68, 0x00, 0x69}, // BOM + "hi" in UTF-16 BE + expectedText: "hi", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + // Create temp file with encoded content + file, err := os.CreateTemp("", "sqlcmdinput*.sql") + require.NoError(t, err, "os.CreateTemp") + defer os.Remove(file.Name()) + + _, err = file.Write(tt.fileContent) + require.NoError(t, err, "Write") + err = file.Close() + require.NoError(t, err, "Close") + + // Set up Sqlcmd with InputCodePage + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + s.CodePage = &CodePageSettings{ + InputCodePage: tt.codepage, + } + + // Include the file but don't execute (processAll=false) + err = s.IncludeFile(file.Name(), false) + require.NoError(t, err, "IncludeFile") + + // Check that the batch contains the expected decoded text + batchText := s.batch.String() + assert.Contains(t, batchText, tt.expectedText, "batch should contain decoded text") + }) + } +} + func TestGetRunnableQuery(t *testing.T) { v := InitializeVariables(false) v.Set("var1", "v1")
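
For reviewers: a minimal sketch of how the new `ParseCodePage`/`GetEncoding` API composes with `golang.org/x/text/transform` outside of `IncludeFile`, mirroring the reader setup in this change. The file name and error handling are illustrative only.

```go
package main

import (
	"bufio"
	"fmt"
	"log"
	"os"

	"github.com/microsoft/go-sqlcmd/pkg/sqlcmd"
	"golang.org/x/text/transform"
)

func main() {
	// Parse a -f style argument; "i:1252" sets only the input code page.
	settings, err := sqlcmd.ParseCodePage("i:1252")
	if err != nil {
		log.Fatal(err)
	}

	f, err := os.Open("legacy_script.sql") // illustrative file name
	if err != nil {
		log.Fatal(err)
	}
	defer f.Close()

	// GetEncoding returns a nil encoding for UTF-8 (65001), in which case the
	// file can be read directly; otherwise wrap it in the code page decoder.
	reader := bufio.NewReader(f)
	if enc, err := sqlcmd.GetEncoding(settings.InputCodePage); err != nil {
		log.Fatal(err)
	} else if enc != nil {
		reader = bufio.NewReader(transform.NewReader(f, enc.NewDecoder()))
	}

	line, _ := reader.ReadString('\n')
	fmt.Print(line) // now UTF-8 regardless of the file's original code page
}
```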