diff --git a/README.md b/README.md index fe26e192..4488ebc5 100644 --- a/README.md +++ b/README.md @@ -122,7 +122,7 @@ The `sqlcmd` project aims to be a complete port of the original ODBC sqlcmd to t - There are new posix-style versions of each flag, such as `--input-file` for `-i`. `sqlcmd -?` will print those parameter names. Those new names do not preserve backward compatibility with ODBC `sqlcmd`. For example, to specify multiple input file names using `--input-file`, the file names must be comma-delimited, not space-delimited. The following switches have different behavior in this version of `sqlcmd` compared to the original ODBC based `sqlcmd`. -- `-R` switch is ignored. The go runtime does not provide access to user locale information, and it's not readily available through syscall on all supported platforms. +- `-R` switch enables regional formatting for numeric, currency, and date/time values based on the user's locale. Formatting includes locale-specific thousand and decimal separators for numeric and currency values, and locale-specific date/time formats. On Windows, the user's default locale is detected from system settings. On Linux/macOS, the locale is detected from environment variables (`LC_ALL`, `LC_MESSAGES`, `LANG`). - `-I` switch is ignored; quoted identifiers are always set on. To disable quoted identifier behavior, add `SET QUOTED IDENTIFIER OFF` in your scripts. - `-N` now takes an optional string value that can be one of `s[trict]`,`t[rue]`,`m[andatory]`, `yes`,`1`, `o[ptional]`,`no`, `0`, `f[alse]`, or `disable` to specify the encryption choice. - If `-N` is passed but no value is provided, `true` is used. @@ -133,6 +133,7 @@ The following switches have different behavior in this version of `sqlcmd` compa - To provide the value of the host name in the server certificate when using strict encryption, pass the host name with `-F`. 
Example: `-Ns -F myhost.domain.com` - More information about client/server encryption negotiation can be found at - `-u` The generated Unicode output file will have the UTF16 Little-Endian Byte-order mark (BOM) written to it. +- `-f` Specifies the code page for input and output files. Format: `codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage]`. Use `65001` for UTF-8. Supported code pages include Unicode (65001, 1200, 1201), Windows (874, 1250-1258), OEM/DOS (437, 850, etc.), ISO-8859 (28591-28600, 28603-28606), CJK (932, 936, 949, 950), and EBCDIC (37, 1047, 1140). Use `--list-codepages` to see all supported code pages. - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. - All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. - `-i` doesn't handle a comma `,` in a file name correctly unless the file name argument is triple quoted. 
For example: diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index ea655b47..94dd17c6 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -82,6 +82,9 @@ type SQLCmdArguments struct { ChangePassword string ChangePasswordAndExit string TraceFile string + CodePage string + ListCodePages bool + UseRegionalSettings bool // Keep Help at the end of the list Help bool } @@ -171,6 +174,10 @@ func (a *SQLCmdArguments) Validate(c *cobra.Command) (err error) { err = rangeParameterError("-t", fmt.Sprint(a.QueryTimeout), 0, 65534, true) case a.ServerCertificate != "" && !encryptConnectionAllowsTLS(a.EncryptConnection): err = localizer.Errorf("The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict).") + case a.CodePage != "": + if _, parseErr := sqlcmd.ParseCodePage(a.CodePage); parseErr != nil { + err = localizer.Errorf(`'-f %s': %v`, a.CodePage, parseErr) + } } } if err != nil { @@ -239,6 +246,17 @@ func Execute(version string) { listLocalServers() os.Exit(0) } + // List supported codepages + if args.ListCodePages { + fmt.Println(localizer.Sprintf("Supported Code Pages:")) + fmt.Println() + fmt.Printf("%-8s %-20s %s\n", "Code", "Name", "Description") + fmt.Printf("%-8s %-20s %s\n", "----", "----", "-----------") + for _, cp := range sqlcmd.SupportedCodePages() { + fmt.Printf("%-8d %-20s %s\n", cp.CodePage, cp.Name, cp.Description) + } + os.Exit(0) + } if len(argss) > 0 { fmt.Printf("%s'%s': Unknown command. Enter '--help' for command help.", sqlcmdErrorPrefix, argss[0]) os.Exit(1) @@ -472,13 +490,15 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().StringVarP(&args.ListServers, listServers, "L", "", localizer.Sprintf("%s List servers. 
Pass %s to omit 'Servers:' output.", "-L[c]", "c")) rootCmd.Flags().BoolVarP(&args.DedicatedAdminConnection, "dedicated-admin-connection", "A", false, localizer.Sprintf("Dedicated administrator connection")) _ = rootCmd.Flags().BoolP("enable-quoted-identifiers", "I", true, localizer.Sprintf("Provided for backward compatibility. Quoted identifiers are always enabled")) - _ = rootCmd.Flags().BoolP("client-regional-setting", "R", false, localizer.Sprintf("Provided for backward compatibility. Client regional settings are not used")) + rootCmd.Flags().BoolVarP(&args.UseRegionalSettings, "client-regional-setting", "R", false, localizer.Sprintf("Use client regional settings for currency, date, and time formatting")) _ = rootCmd.Flags().IntP(removeControlCharacters, "k", 0, localizer.Sprintf("%s Remove control characters from output. Pass 1 to substitute a space per character, 2 for a space per consecutive characters", "-k [1|2]")) rootCmd.Flags().BoolVarP(&args.EchoInput, "echo-input", "e", false, localizer.Sprintf("Echo input")) rootCmd.Flags().IntVarP(&args.QueryTimeout, "query-timeout", "t", 0, "Query timeout") rootCmd.Flags().BoolVarP(&args.EnableColumnEncryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption")) rootCmd.Flags().StringVarP(&args.ChangePassword, "change-password", "z", "", localizer.Sprintf("New password")) rootCmd.Flags().StringVarP(&args.ChangePasswordAndExit, "change-password-exit", "Z", "", localizer.Sprintf("New password and exit")) + rootCmd.Flags().StringVarP(&args.CodePage, "code-page", "f", "", localizer.Sprintf("Specifies the code page for input/output. Use 65001 for UTF-8. 
Format: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage]")) + rootCmd.Flags().BoolVar(&args.ListCodePages, "list-codepages", false, localizer.Sprintf("List supported code pages and exit")) } func setScriptVariable(v string) string { @@ -813,6 +833,15 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { defer s.StopCloseHandler() s.UnicodeOutputFile = args.UnicodeOutputFile + // Parse and apply codepage settings + if args.CodePage != "" { + codePageSettings, err := sqlcmd.ParseCodePage(args.CodePage) + if err != nil { + return 1, localizer.Errorf("Invalid code page: %v", err) + } + s.CodePage = codePageSettings + } + if args.DisableCmd != nil { s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } @@ -828,7 +857,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } s.Connect = &connectConfig - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior()) + s.Format = sqlcmd.NewSQLCmdDefaultFormatterWithRegional(args.TrimSpaces, args.getControlCharacterBehavior(), args.UseRegionalSettings) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) if err != nil { diff --git a/cmd/sqlcmd/sqlcmd_test.go b/cmd/sqlcmd/sqlcmd_test.go index 511816b2..81e77a83 100644 --- a/cmd/sqlcmd/sqlcmd_test.go +++ b/cmd/sqlcmd/sqlcmd_test.go @@ -123,6 +123,29 @@ func TestValidCommandLineToArgsConversion(t *testing.T) { {[]string{"-N", "true", "-J", "/path/to/cert2.pem"}, func(args SQLCmdArguments) bool { return args.EncryptConnection == "true" && args.ServerCertificate == "/path/to/cert2.pem" }}, + // Codepage flag tests + {[]string{"-f", "65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "65001" + }}, + {[]string{"-f", "i:1252,o:65001"}, func(args SQLCmdArguments) bool { + return args.CodePage == "i:1252,o:65001" + }}, + {[]string{"-f", "o:65001,i:1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "o:65001,i:1252" + }}, + 
{[]string{"--code-page", "1252"}, func(args SQLCmdArguments) bool { + return args.CodePage == "1252" + }}, + {[]string{"--list-codepages"}, func(args SQLCmdArguments) bool { + return args.ListCodePages + }}, + // Regional settings flag test + {[]string{"-R"}, func(args SQLCmdArguments) bool { + return args.UseRegionalSettings + }}, + {[]string{"--client-regional-setting"}, func(args SQLCmdArguments) bool { + return args.UseRegionalSettings + }}, } for _, test := range commands { @@ -178,6 +201,11 @@ func TestInvalidCommandLine(t *testing.T) { {[]string{"-N", "optional", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "disable", "-J", "/path/to/cert.pem"}, "The -J parameter requires encryption to be enabled (-N true, -N mandatory, or -N strict)."}, {[]string{"-N", "strict", "-F", "myserver.domain.com", "-J", "/path/to/cert.pem"}, "The -F and the -J options are mutually exclusive."}, + // Codepage validation tests + {[]string{"-f", "invalid"}, `'-f invalid': invalid codepage: invalid`}, + {[]string{"-f", "99999"}, `'-f 99999': unsupported codepage 99999`}, + {[]string{"-f", "i:invalid"}, `'-f i:invalid': invalid input codepage: i:invalid`}, + {[]string{"-f", "x:1252"}, `'-f x:1252': invalid codepage: x:1252`}, } for _, test := range commands { diff --git a/pkg/sqlcmd/codepage.go b/pkg/sqlcmd/codepage.go new file mode 100644 index 00000000..231940cb --- /dev/null +++ b/pkg/sqlcmd/codepage.go @@ -0,0 +1,318 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. 
+
+package sqlcmd
+
+import (
+	"fmt"
+	"strconv"
+	"strings"
+
+	"golang.org/x/text/encoding"
+	"golang.org/x/text/encoding/charmap"
+	"golang.org/x/text/encoding/japanese"
+	"golang.org/x/text/encoding/korean"
+	"golang.org/x/text/encoding/simplifiedchinese"
+	"golang.org/x/text/encoding/traditionalchinese"
+	"golang.org/x/text/encoding/unicode"
+)
+
+// CodePageSettings carries the code pages selected with -f. A zero field means
+// that no code page was given for that direction.
+type CodePageSettings struct {
+	InputCodePage  int
+	OutputCodePage int
+}
+
+// ParseCodePage interprets the value of the -f argument.
+//
+// Accepted forms: codepage | i:codepage[,o:codepage] | o:codepage[,i:codepage].
+// A bare number applies to both input and output; the i:/o: prefixes are
+// matched case-insensitively and empty comma-separated items are ignored.
+// An empty arg returns (nil, nil). Every non-zero code page is checked with
+// GetEncoding, whose error is returned unchanged.
+func ParseCodePage(arg string) (*CodePageSettings, error) {
+	if arg == "" {
+		return nil, nil
+	}
+
+	var parsed CodePageSettings
+	for _, field := range strings.Split(arg, ",") {
+		field = strings.TrimSpace(field)
+		if field == "" {
+			continue
+		}
+
+		lowered := strings.ToLower(field)
+		switch {
+		case strings.HasPrefix(lowered, "i:"):
+			// Input only
+			n, err := strconv.Atoi(strings.TrimPrefix(lowered, "i:"))
+			if err != nil {
+				return nil, fmt.Errorf("invalid input codepage: %s", field)
+			}
+			parsed.InputCodePage = n
+		case strings.HasPrefix(lowered, "o:"):
+			// Output only
+			n, err := strconv.Atoi(strings.TrimPrefix(lowered, "o:"))
+			if err != nil {
+				return nil, fmt.Errorf("invalid output codepage: %s", field)
+			}
+			parsed.OutputCodePage = n
+		default:
+			// No prefix: the same code page is used in both directions
+			n, err := strconv.Atoi(field)
+			if err != nil {
+				return nil, fmt.Errorf("invalid codepage: %s", field)
+			}
+			parsed.InputCodePage, parsed.OutputCodePage = n, n
+		}
+	}
+
+	// Reject code pages GetEncoding does not know; input is checked first.
+	for _, cp := range []int{parsed.InputCodePage, parsed.OutputCodePage} {
+		if cp == 0 {
+			continue
+		}
+		if _, err := GetEncoding(cp); err != nil {
+			return nil, err
+		}
+	}
+
+	return &parsed, nil
+}
+
+// GetEncoding returns the encoding for a
given Windows codepage number. +// Returns nil for UTF-8 (65001) since Go uses UTF-8 natively. +func GetEncoding(codepage int) (encoding.Encoding, error) { + switch codepage { + // Unicode encodings + case 65001: + // UTF-8 - Go's native encoding, return nil to indicate no transformation needed + return nil, nil + case 1200: + // UTF-16LE + return unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM), nil + case 1201: + // UTF-16BE + return unicode.UTF16(unicode.BigEndian, unicode.IgnoreBOM), nil + + // OEM/DOS codepages + case 437: + return charmap.CodePage437, nil + case 850: + return charmap.CodePage850, nil + case 852: + return charmap.CodePage852, nil + case 855: + return charmap.CodePage855, nil + case 858: + return charmap.CodePage858, nil + case 860: + return charmap.CodePage860, nil + case 862: + return charmap.CodePage862, nil + case 863: + return charmap.CodePage863, nil + case 865: + return charmap.CodePage865, nil + case 866: + return charmap.CodePage866, nil + + // Windows codepages + case 874: + return charmap.Windows874, nil + case 1250: + return charmap.Windows1250, nil + case 1251: + return charmap.Windows1251, nil + case 1252: + return charmap.Windows1252, nil + case 1253: + return charmap.Windows1253, nil + case 1254: + return charmap.Windows1254, nil + case 1255: + return charmap.Windows1255, nil + case 1256: + return charmap.Windows1256, nil + case 1257: + return charmap.Windows1257, nil + case 1258: + return charmap.Windows1258, nil + + // ISO-8859 codepages + case 28591: + return charmap.ISO8859_1, nil + case 28592: + return charmap.ISO8859_2, nil + case 28593: + return charmap.ISO8859_3, nil + case 28594: + return charmap.ISO8859_4, nil + case 28595: + return charmap.ISO8859_5, nil + case 28596: + return charmap.ISO8859_6, nil + case 28597: + return charmap.ISO8859_7, nil + case 28598: + return charmap.ISO8859_8, nil + case 28599: + return charmap.ISO8859_9, nil + case 28600: + return charmap.ISO8859_10, nil + case 28603: + return 
charmap.ISO8859_13, nil + case 28604: + return charmap.ISO8859_14, nil + case 28605: + return charmap.ISO8859_15, nil + case 28606: + return charmap.ISO8859_16, nil + + // Cyrillic + case 20866: + return charmap.KOI8R, nil + case 21866: + return charmap.KOI8U, nil + + // Macintosh + case 10000: + return charmap.Macintosh, nil + case 10007: + return charmap.MacintoshCyrillic, nil + + // EBCDIC codepages + case 37: + return charmap.CodePage037, nil + case 1047: + return charmap.CodePage1047, nil + case 1140: + return charmap.CodePage1140, nil + + // Japanese + case 932: + // Shift JIS (Windows-31J) + return japanese.ShiftJIS, nil + case 20932: + // EUC-JP + return japanese.EUCJP, nil + case 50220, 50221, 50222: + // ISO-2022-JP + return japanese.ISO2022JP, nil + + // Korean + case 949: + // EUC-KR (Korean) + return korean.EUCKR, nil + case 51949: + // EUC-KR alternate + return korean.EUCKR, nil + + // Simplified Chinese + case 936: + // GBK (Simplified Chinese) + return simplifiedchinese.GBK, nil + case 54936: + // GB18030 + return simplifiedchinese.GB18030, nil + case 52936: + // HZ-GB2312 + return simplifiedchinese.HZGB2312, nil + + // Traditional Chinese + case 950: + // Big5 + return traditionalchinese.Big5, nil + + default: + return nil, fmt.Errorf("unsupported codepage %d", codepage) + } +} + +// CodePageInfo describes a supported codepage +type CodePageInfo struct { + CodePage int + Name string + Description string +} + +// SupportedCodePages returns a list of all supported codepages with descriptions +func SupportedCodePages() []CodePageInfo { + return []CodePageInfo{ + // Unicode + {65001, "UTF-8", "Unicode (UTF-8)"}, + {1200, "UTF-16LE", "Unicode (UTF-16 Little-Endian)"}, + {1201, "UTF-16BE", "Unicode (UTF-16 Big-Endian)"}, + + // OEM/DOS codepages + {437, "CP437", "OEM United States"}, + {850, "CP850", "OEM Multilingual Latin 1"}, + {852, "CP852", "OEM Latin 2"}, + {855, "CP855", "OEM Cyrillic"}, + {858, "CP858", "OEM Multilingual Latin 1 + Euro"}, + {860, 
"CP860", "OEM Portuguese"}, + {862, "CP862", "OEM Hebrew"}, + {863, "CP863", "OEM Canadian French"}, + {865, "CP865", "OEM Nordic"}, + {866, "CP866", "OEM Russian"}, + + // Windows codepages + {874, "Windows-874", "Thai"}, + {1250, "Windows-1250", "Central European"}, + {1251, "Windows-1251", "Cyrillic"}, + {1252, "Windows-1252", "Western European"}, + {1253, "Windows-1253", "Greek"}, + {1254, "Windows-1254", "Turkish"}, + {1255, "Windows-1255", "Hebrew"}, + {1256, "Windows-1256", "Arabic"}, + {1257, "Windows-1257", "Baltic"}, + {1258, "Windows-1258", "Vietnamese"}, + + // ISO-8859 codepages + {28591, "ISO-8859-1", "Latin 1 (Western European)"}, + {28592, "ISO-8859-2", "Latin 2 (Central European)"}, + {28593, "ISO-8859-3", "Latin 3 (South European)"}, + {28594, "ISO-8859-4", "Latin 4 (North European)"}, + {28595, "ISO-8859-5", "Cyrillic"}, + {28596, "ISO-8859-6", "Arabic"}, + {28597, "ISO-8859-7", "Greek"}, + {28598, "ISO-8859-8", "Hebrew"}, + {28599, "ISO-8859-9", "Turkish"}, + {28600, "ISO-8859-10", "Nordic"}, + {28603, "ISO-8859-13", "Baltic"}, + {28604, "ISO-8859-14", "Celtic"}, + {28605, "ISO-8859-15", "Latin 9 (Western European with Euro)"}, + {28606, "ISO-8859-16", "Latin 10 (South-Eastern European)"}, + + // Cyrillic + {20866, "KOI8-R", "Russian"}, + {21866, "KOI8-U", "Ukrainian"}, + + // Macintosh + {10000, "Macintosh", "Mac Roman"}, + {10007, "x-mac-cyrillic", "Mac Cyrillic"}, + + // EBCDIC + {37, "IBM037", "EBCDIC US-Canada"}, + {1047, "IBM1047", "EBCDIC Latin 1/Open System"}, + {1140, "IBM01140", "EBCDIC US-Canada with Euro"}, + + // Japanese + {932, "Shift_JIS", "Japanese (Shift-JIS)"}, + {20932, "EUC-JP", "Japanese (EUC)"}, + {50220, "ISO-2022-JP", "Japanese (JIS)"}, + {50221, "csISO2022JP", "Japanese (JIS-Allow 1 byte Kana)"}, + {50222, "ISO-2022-JP", "Japanese (JIS-Allow 1 byte Kana SO/SI)"}, + + // Korean + {949, "EUC-KR", "Korean"}, + {51949, "EUC-KR", "Korean (EUC)"}, + + // Simplified Chinese + {936, "GBK", "Chinese Simplified (GBK)"}, + {54936, 
"GB18030", "Chinese Simplified (GB18030)"}, + {52936, "HZ-GB-2312", "Chinese Simplified (HZ)"}, + + // Traditional Chinese + {950, "Big5", "Chinese Traditional (Big5)"}, + } +} diff --git a/pkg/sqlcmd/codepage_test.go b/pkg/sqlcmd/codepage_test.go new file mode 100644 index 00000000..23160c92 --- /dev/null +++ b/pkg/sqlcmd/codepage_test.go @@ -0,0 +1,229 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "fmt" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseCodePage(t *testing.T) { + tests := []struct { + name string + arg string + wantInput int + wantOutput int + wantErr bool + errContains string + }{ + { + name: "empty string", + arg: "", + wantInput: 0, + wantOutput: 0, + wantErr: false, + }, + { + name: "single codepage sets both", + arg: "65001", + wantInput: 65001, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input only", + arg: "i:1252", + wantInput: 1252, + wantOutput: 0, + wantErr: false, + }, + { + name: "output only", + arg: "o:65001", + wantInput: 0, + wantOutput: 65001, + wantErr: false, + }, + { + name: "input and output", + arg: "i:1252,o:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "output and input reversed", + arg: "o:65001,i:1252", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "uppercase prefix", + arg: "I:1252,O:65001", + wantInput: 1252, + wantOutput: 65001, + wantErr: false, + }, + { + name: "invalid codepage number", + arg: "abc", + wantErr: true, + errContains: "invalid codepage", + }, + { + name: "invalid input codepage", + arg: "i:abc", + wantErr: true, + errContains: "invalid input codepage", + }, + { + name: "invalid output codepage", + arg: "o:xyz", + wantErr: true, + errContains: "invalid output codepage", + }, + { + name: "unsupported codepage", + arg: "99999", + wantErr: true, + errContains: "unsupported codepage", + }, + { + name: "Japanese Shift JIS", + arg: "932", + 
wantInput: 932, + wantOutput: 932, + wantErr: false, + }, + { + name: "Chinese GBK", + arg: "936", + wantInput: 936, + wantOutput: 936, + wantErr: false, + }, + { + name: "Korean", + arg: "949", + wantInput: 949, + wantOutput: 949, + wantErr: false, + }, + { + name: "Traditional Chinese Big5", + arg: "950", + wantInput: 950, + wantOutput: 950, + wantErr: false, + }, + { + name: "EBCDIC", + arg: "37", + wantInput: 37, + wantOutput: 37, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + settings, err := ParseCodePage(tt.arg) + if tt.wantErr { + assert.Error(t, err) + if tt.errContains != "" { + assert.Contains(t, err.Error(), tt.errContains) + } + return + } + assert.NoError(t, err) + if tt.arg == "" { + assert.Nil(t, settings) + return + } + assert.NotNil(t, settings) + assert.Equal(t, tt.wantInput, settings.InputCodePage) + assert.Equal(t, tt.wantOutput, settings.OutputCodePage) + }) + } +} + +func TestGetEncoding(t *testing.T) { + tests := []struct { + codepage int + wantNil bool // UTF-8 returns nil encoding + wantErr bool + }{ + // Unicode + {65001, true, false}, // UTF-8 + {1200, false, false}, // UTF-16LE + {1201, false, false}, // UTF-16BE + + // OEM/DOS + {437, false, false}, + {850, false, false}, + {866, false, false}, + + // Windows + {874, false, false}, + {1250, false, false}, + {1251, false, false}, + {1252, false, false}, + {1253, false, false}, + {1254, false, false}, + {1255, false, false}, + {1256, false, false}, + {1257, false, false}, + {1258, false, false}, + + // ISO-8859 + {28591, false, false}, + {28592, false, false}, + {28605, false, false}, + + // Cyrillic + {20866, false, false}, + {21866, false, false}, + + // Macintosh + {10000, false, false}, + {10007, false, false}, + + // EBCDIC + {37, false, false}, + {1047, false, false}, + {1140, false, false}, + + // CJK + {932, false, false}, // Japanese Shift JIS + {20932, false, false}, // EUC-JP + {50220, false, false}, // ISO-2022-JP + {949, 
false, false}, // Korean EUC-KR + {936, false, false}, // Chinese GBK + {54936, false, false}, // GB18030 + {950, false, false}, // Big5 + + // Invalid + {99999, false, true}, + {12345, false, true}, + } + + for _, tt := range tests { + t.Run(fmt.Sprintf("cp_%d", tt.codepage), func(t *testing.T) { + enc, err := GetEncoding(tt.codepage) + if tt.wantErr { + assert.Error(t, err) + return + } + assert.NoError(t, err) + if tt.wantNil { + assert.Nil(t, enc, "UTF-8 should return nil encoding") + } else { + assert.NotNil(t, enc, "non-UTF-8 codepage should return encoding") + } + }) + } +} diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..9efbb5a9 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -326,6 +326,21 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) encoder := transform.NewWriter(o, win16le.NewEncoder()) s.SetOutput(encoder) + } else if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + // Use specified output codepage + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + _ = o.Close() + return err + } + if enc != nil { + // Transform from UTF-8 to specified encoding + encoder := transform.NewWriter(o, enc.NewEncoder()) + s.SetOutput(encoder) + } else { + // UTF-8, no transformation needed + s.SetOutput(o) + } } else { s.SetOutput(o) } @@ -352,7 +367,25 @@ func errorCommand(s *Sqlcmd, args []string, line uint) error { if err != nil { return InvalidFileError(err, args[0]) } - s.SetError(o) + // Apply output codepage if configured + if s.CodePage != nil && s.CodePage.OutputCodePage != 0 { + enc, err := GetEncoding(s.CodePage.OutputCodePage) + if err != nil { + if cerr := o.Close(); cerr != nil { + return fmt.Errorf("%w (and closing error file %s failed: %v)", err, filePath, cerr) + } + return err + } + if enc == nil { + // No transformation required (e.g., UTF-8), write directly + s.SetError(o) + } else { + encoder := 
transform.NewWriter(o, enc.NewEncoder()) + s.SetError(encoder) + } + } else { + s.SetError(o) + } } return nil } diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..a39b9b38 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -85,15 +85,23 @@ type sqlCmdFormatterType struct { maxColNameLen int colorizer color.Colorizer xml bool + regional *RegionalSettings } // NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBehavior) Formatter { + return NewSQLCmdDefaultFormatterWithRegional(removeTrailingSpaces, ccb, false) +} + +// NewSQLCmdDefaultFormatterWithRegional returns a Formatter with optional regional settings support +// When useRegionalSettings is true, numeric and date/time values are formatted according to the user's locale +func NewSQLCmdDefaultFormatterWithRegional(removeTrailingSpaces bool, ccb ControlCharacterBehavior, useRegionalSettings bool) Formatter { return &sqlCmdFormatterType{ removeTrailingSpaces: removeTrailingSpaces, format: "horizontal", colorizer: color.New(false), ccb: ccb, + regional: NewRegionalSettings(useRegionalSettings), } } @@ -478,11 +486,12 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { if *j == nil { row[n] = "NULL" } else { + typeName := f.columnDetails[n].col.DatabaseTypeName() switch x := (*j).(type) { case []byte: if isBinaryDataType(&f.columnDetails[n].col) { row[n] = decodeBinary(x) - } else if f.columnDetails[n].col.DatabaseTypeName() == "UNIQUEIDENTIFIER" { + } else if typeName == "UNIQUEIDENTIFIER" { // Unscramble the guid // see https://github.com/denisenkom/go-mssqldb/issues/56 x[0], x[1], x[2], x[3] = x[3], x[2], x[1], x[0] @@ -498,28 +507,55 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { row[n] = string(x) } case string: - row[n] = x + // Apply regional formatting for DECIMAL/NUMERIC/MONEY when represented 
as string + if f.regional.IsEnabled() { + switch typeName { + case "DECIMAL", "NUMERIC": + row[n] = f.regional.FormatNumber(x) + case "MONEY", "SMALLMONEY": + row[n] = f.regional.FormatMoney(x) + default: + row[n] = x + } + } else { + row[n] = x + } case time.Time: - // Go lacks any way to get the user's preferred time format or even the system default - switch f.columnDetails[n].col.DatabaseTypeName() { - case "DATE": - row[n] = x.Format("2006-01-02") - case "DATETIME": - row[n] = x.Format(dateTimeFormatString(3, false)) - case "DATETIME2": - row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, false)) - case "SMALLDATETIME": - row[n] = x.Format(dateTimeFormatString(0, false)) - case "DATETIMEOFFSET": - row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, true)) - case "TIME": - format := "15:04:05" - if f.columnDetails[n].scale > 0 { - format = fmt.Sprintf("%s.%0*d", format, f.columnDetails[n].scale, 0) + // Apply regional formatting when -R is enabled + if f.regional.IsEnabled() { + switch typeName { + case "DATE": + row[n] = f.regional.FormatDate(x) + case "DATETIME", "DATETIME2", "SMALLDATETIME": + row[n] = f.regional.FormatDateTime(x, f.columnDetails[n].scale, false) + case "DATETIMEOFFSET": + row[n] = f.regional.FormatDateTime(x, f.columnDetails[n].scale, true) + case "TIME": + row[n] = f.regional.FormatTime(x, f.columnDetails[n].scale) + default: + row[n] = x.Format(time.RFC3339) + } + } else { + switch typeName { + case "DATE": + row[n] = x.Format("2006-01-02") + case "DATETIME": + row[n] = x.Format(dateTimeFormatString(3, false)) + case "DATETIME2": + row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, false)) + case "SMALLDATETIME": + row[n] = x.Format(dateTimeFormatString(0, false)) + case "DATETIMEOFFSET": + row[n] = x.Format(dateTimeFormatString(f.columnDetails[n].scale, true)) + case "TIME": + format := "15:04:05" + if f.columnDetails[n].scale > 0 { + format = fmt.Sprintf("%s.%0*d", format, 
f.columnDetails[n].scale, 0) + } + row[n] = x.Format(format) + default: + row[n] = x.Format(time.RFC3339) } - row[n] = x.Format(format) - default: - row[n] = x.Format(time.RFC3339) } case fmt.Stringer: row[n] = x.String() @@ -531,9 +567,19 @@ func (f *sqlCmdFormatterType) scanRow(rows *sql.Rows) ([]string, error) { row[n] = "0" } default: - var err error - if row[n], err = fmt.Sprintf("%v", x), nil; err != nil { - return nil, err + val := fmt.Sprintf("%v", x) + // Apply regional formatting for numeric types + if f.regional.IsEnabled() { + switch typeName { + case "DECIMAL", "NUMERIC": + row[n] = f.regional.FormatNumber(val) + case "MONEY", "SMALLMONEY": + row[n] = f.regional.FormatMoney(val) + default: + row[n] = val + } + } else { + row[n] = val } } } diff --git a/pkg/sqlcmd/regional.go b/pkg/sqlcmd/regional.go new file mode 100644 index 00000000..29bed7ac --- /dev/null +++ b/pkg/sqlcmd/regional.go @@ -0,0 +1,278 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "strconv" + "strings" + "time" + + "golang.org/x/text/language" +) + +// RegionalSettings provides locale-aware formatting for output when -R is used +type RegionalSettings struct { + enabled bool + tag language.Tag + dateFmt string + timeFmt string +} + +// NewRegionalSettings creates a new RegionalSettings instance +// If enabled is false, all format methods return values unchanged +func NewRegionalSettings(enabled bool) *RegionalSettings { + r := &RegionalSettings{enabled: enabled} + if enabled { + r.tag = detectUserLocale() + r.dateFmt, r.timeFmt = getLocaleDateTimeFormats(r.tag) + } + return r +} + +// IsEnabled returns whether regional formatting is active +func (r *RegionalSettings) IsEnabled() bool { + return r.enabled +} + +// FormatNumber formats a numeric value with locale-specific thousand separators +// Used for DECIMAL and NUMERIC types. 
Formatting is done purely by string
+// manipulation to preserve all digits of high-precision values.
+//
+// value is expected in SQL text form: an optional leading "-", the integer
+// digits, and an optional "." followed by fraction digits. Empty strings and
+// "NULL" are returned unchanged, as is every value when regional formatting
+// is disabled.
+func (r *RegionalSettings) FormatNumber(value string) string {
+	if !r.enabled || value == "" || value == "NULL" {
+		return value
+	}
+
+	// Handle leading sign
+	negative := strings.HasPrefix(value, "-")
+	if negative {
+		value = value[1:]
+	}
+
+	// Split into integer and decimal parts using the SQL-style decimal point.
+	// We do not change any digits; we only insert locale-specific separators.
+	parts := strings.SplitN(value, ".", 2)
+
+	// Add thousand separators using locale convention (pure string manipulation)
+	formatted := addThousandSeparators(parts[0], r.tag)
+	if len(parts) > 1 {
+		formatted += getDecimalSeparator(r.tag) + parts[1]
+	}
+	if negative {
+		formatted = "-" + formatted
+	}
+	return formatted
+}
+
+// FormatMoney formats a MONEY or SMALLMONEY value with locale-specific
+// thousand and decimal separators and exactly four fraction digits.
+//
+// The value is handled as text (no float64 conversion) so large amounts keep
+// every digit; surplus fraction digits are truncated, not rounded. Empty
+// strings and "NULL" are returned unchanged, as is every value when regional
+// formatting is disabled.
+func (r *RegionalSettings) FormatMoney(value string) string {
+	if !r.enabled || value == "" || value == "NULL" {
+		return value
+	}
+
+	negative := strings.HasPrefix(value, "-")
+	cleanValue := value
+	if negative {
+		cleanValue = value[1:]
+	}
+
+	// Split into integer and fractional parts
+	parts := strings.SplitN(cleanValue, ".", 2)
+	intPart := parts[0]
+	fracPart := ""
+	if len(parts) > 1 {
+		fracPart = parts[1]
+	}
+
+	// Normalize the fraction to exactly 4 digits (pad with zeros or truncate)
+	switch {
+	case len(fracPart) < 4:
+		fracPart += strings.Repeat("0", 4-len(fracPart))
+	case len(fracPart) > 4:
+		fracPart = fracPart[:4]
+	}
+
+	formatted := addThousandSeparators(intPart, r.tag) + getDecimalSeparator(r.tag) + fracPart
+	if negative {
+		formatted = "-" + formatted
+	}
+	return formatted
+}
+
+// FormatDate formats a DATE value with the locale's date layout, or as
+// 2006-01-02 when regional formatting is disabled.
+func (r *RegionalSettings) FormatDate(t time.Time) string {
+	if !r.enabled {
+		return t.Format("2006-01-02")
+	}
+	return t.Format(r.dateFmt)
+}
+
+// FormatDateTime formats a date-and-time value as "<date> <time>[ <offset>]"
+// using the locale's date and time layouts.
+//
+// scale is the number of fractional-second digits to show (values above 9 are
+// treated as 9, the precision of time.Time); the digits are truncated, not
+// rounded, and follow the seconds field after the locale's decimal separator.
+// When addOffset is true the zone offset of t is appended as +HH:MM or -HH:MM.
+// With regional formatting disabled the result is
+// t.Format(dateTimeFormatString(scale, addOffset)).
+func (r *RegionalSettings) FormatDateTime(t time.Time, scale int, addOffset bool) string {
+	if !r.enabled {
+		return t.Format(dateTimeFormatString(scale, addOffset))
+	}
+
+	timePart := t.Format(r.timeFmt)
+	if scale > 0 {
+		// Clamp so 1000000000 / pow10(digits) can never become a division by zero.
+		digits := scale
+		if digits > 9 {
+			digits = 9
+		}
+		frac := t.Nanosecond() / (1000000000 / pow10(digits))
+		fraction := getDecimalSeparator(r.tag) + padLeftStr(strconv.Itoa(frac), digits, '0')
+
+		// The fraction belongs to the seconds field, so splice it in right
+		// after "05" instead of after a trailing AM/PM marker
+		// ("03:04:05.123 PM", not "03:04:05 PM.123").
+		if i := strings.LastIndex(r.timeFmt, "05"); i >= 0 {
+			cut := i + len("05")
+			timePart = t.Format(r.timeFmt[:cut]) + fraction + t.Format(r.timeFmt[cut:])
+		} else {
+			timePart += fraction
+		}
+	}
+	result := t.Format(r.dateFmt) + " " + timePart
+
+	if addOffset {
+		_, offset := t.Zone()
+		// Take the sign from the whole offset: deriving it from the hour
+		// component alone turns -00:30 into +00:30.
+		sign := "+"
+		if offset < 0 {
+			sign = "-"
+			offset = -offset
+		}
+		// formatOffset renders non-negative input as "+HH:MM"; drop its sign
+		// and use the one computed above.
+		result += " " + sign + formatOffset(offset/3600, (offset%3600)/60)[1:]
+	}
+	return result
+}
+
+// FormatTime formats a TIME value.
+//
+// scale is the number of fractional-second digits to show (0 for none); the
+// digits are truncated, not rounded. When regional formatting is disabled the
+// result is "15:04:05" with a "." before any fraction.
+func (r *RegionalSettings) FormatTime(t time.Time, scale int) string {
+	if !r.enabled {
+		format := "15:04:05"
+		if scale > 0 {
+			format = format + "." + strings.Repeat("0", scale)
+		}
+		return t.Format(format)
+	}
+
+	layout := r.timeFmt
+	if scale > 0 {
+		// Splice the fraction in right after the seconds field so that a
+		// trailing AM/PM marker stays last ("02:30:45.123 PM"). Appending it to
+		// the formatted string produced "02:30:45 PM.123" for 12-hour layouts.
+		fraction := getDecimalSeparator(r.tag) + strings.Repeat("0", scale)
+		if strings.Contains(layout, "05") {
+			layout = strings.Replace(layout, "05", "05"+fraction, 1)
+		} else {
+			// No seconds field to attach to: keep the fraction at the end.
+			layout += fraction
+		}
+	}
+	return t.Format(layout)
+}
+
+// Helper functions
+
+// pow10 returns 10 raised to the power n. It returns 1 when n <= 0.
+func pow10(n int) int {
+	result := 1
+	for i := 0; i < n; i++ {
+		result *= 10
+	}
+	return result
+}
+
+// padLeftStr prepends pad to s until len(s), measured in bytes, is at least
+// length. s is returned unchanged when it is already long enough.
+func padLeftStr(s string, length int, pad rune) string {
+	for len(s) < length {
+		s = string(pad) + s
+	}
+	return s
+}
+
+// formatOffset renders a zone offset as "+hh:mm" or "-hh:mm".
+// The sign is taken from hours alone and minutes is printed as given, so
+// callers pass a non-negative minutes value.
+func formatOffset(hours, minutes int) string {
+	sign := "+"
+	if hours < 0 {
+		sign = "-"
+		hours = -hours
+	}
+	return sign + padLeftStr(strconv.Itoa(hours), 2, '0') + ":" + padLeftStr(strconv.Itoa(minutes), 2, '0')
+}
+
+// usesCommaDecimal reports whether the base language of tag is formatted with
+// a decimal comma and a period as thousand separator.
+// "nb" and "nn" are listed next to "no" because tags such as nb-NO and nn-NO
+// carry those base languages rather than "no".
+func usesCommaDecimal(tag language.Tag) bool {
+	base, _ := tag.Base()
+	switch base.String() {
+	case "de", "fr", "es", "it", "pt", "nl", "pl", "cs", "sk", "hu", "ro", "bg", "hr", "sl", "sr", "tr", "el", "ru", "uk", "be", "fi", "sv", "no", "nb", "nn", "da", "is":
+		return true
+	}
+	return false
+}
+
+// getDecimalSeparator returns the decimal separator for the given locale:
+// "," for the languages accepted by usesCommaDecimal, "." for all others.
+func getDecimalSeparator(tag language.Tag) string {
+	if usesCommaDecimal(tag) {
+		return ","
+	}
+	return "."
+}
+
+// getThousandSeparator returns the thousand separator for the given locale:
+// "." for the languages accepted by usesCommaDecimal, "," for all others.
+// A period is used for every comma-decimal language, including those whose
+// conventional grouping character is a space.
+func getThousandSeparator(tag language.Tag) string {
+	if usesCommaDecimal(tag) {
+		return "."
+	}
+	return ","
+}
+
+// addThousandSeparators inserts the locale's thousand separator between each
+// group of three characters of s, counted from the right. s is expected to be
+// the unsigned integer part of a number; strings of three characters or fewer
+// are returned unchanged.
+func addThousandSeparators(s string, tag language.Tag) string {
+	sep := getThousandSeparator(tag)
+	if len(s) <= 3 {
+		return s
+	}
+
+	var result strings.Builder
+	// The leading group holds the 1-3 characters left over after grouping by three.
+	start := len(s) % 3
+	if start == 0 {
+		start = 3
+	}
+	result.WriteString(s[:start])
+	for i := start; i < len(s); i += 3 {
+		result.WriteString(sep)
+		result.WriteString(s[i : i+3])
+	}
+	return result.String()
+}
+
+// getLocaleDateTimeFormats returns the Go time layouts used for dates and
+// times in the given locale. Locales without a specific entry get the ISO
+// date layout "2006-01-02" and the 24-hour time layout "15:04:05".
+func getLocaleDateTimeFormats(tag language.Tag) (dateFmt, timeFmt string) {
+	// Default to ISO format
+	dateFmt = "2006-01-02"
+	timeFmt = "15:04:05"
+
+	base, _ := tag.Base()
+	region, _ := tag.Region()
+	lang, country := base.String(), region.String()
+
+	// Set date format based on locale
+	switch lang {
+	case "en":
+		if country == "US" {
+			dateFmt = "01/02/2006"
+		} else {
+			dateFmt = "02/01/2006"
+		}
+	case "de", "ru", "pl", "cs", "sk", "hu", "ro", "bg", "hr", "sl", "sr", "uk", "be":
+		dateFmt = "02.01.2006"
+	case "fr", "pt", "es", "it", "nl", "tr", "el":
+		dateFmt = "02/01/2006"
+	case "ja", "zh", "ko":
+		dateFmt = "2006/01/02"
+	case "fi", "sv", "no", "nb", "nn", "da", "is":
+		dateFmt = "2006-01-02"
+	}
+
+	// Only these English-speaking regions use a 12-hour clock; every other
+	// locale keeps the 24-hour default.
+	if lang == "en" && (country == "US" || country == "CA" || country == "AU") {
+		timeFmt = "03:04:05 PM"
+	}
+
+	return dateFmt, timeFmt
+}
diff --git a/pkg/sqlcmd/regional_darwin.go b/pkg/sqlcmd/regional_darwin.go
new file mode 100644
index 00000000..58e16be7
--- /dev/null
+++ b/pkg/sqlcmd/regional_darwin.go
@@ -0,0 +1,84 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+//go:build darwin
+
+package sqlcmd
+
+import (
+	"os"
+	"os/exec"
+	"strings"
+
+	"golang.org/x/text/language"
+)
+
+// detectUserLocale returns the user's locale on macOS.
+// The LC_ALL, LC_MESSAGES and LANG environment variables are consulted first,
+// in that order; if none yields a usable tag the AppleLocale preference is
+// read. language.English is returned when nothing can be parsed.
+func detectUserLocale() language.Tag {
+	// First try environment variables (same as Linux)
+	for _, envVar := range []string{"LC_ALL", "LC_MESSAGES", "LANG"} {
+		if locale := os.Getenv(envVar); locale != "" {
+			tag := parseUnixLocale(locale)
+			if tag != language.Und {
+				return tag
+			}
+		}
+	}
+
+	// Fall back to macOS defaults command. The value goes through
+	// parseUnixLocale so that any "@..." suffix is dropped before parsing.
+	// NOTE(review): AppleLocale is believed to carry suffixes such as
+	// "@currency=EUR" on some systems; confirm against macOS documentation.
+	if locale := getMacOSLocale(); locale != "" {
+		if tag := parseUnixLocale(locale); tag != language.Und {
+			return tag
+		}
+	}
+
+	return language.English
+}
+
+// getMacOSLocale returns the AppleLocale global preference with underscores
+// replaced by hyphens, or "" when the defaults command fails.
+func getMacOSLocale() string {
+	// Try to get the locale from defaults
+	cmd := exec.Command("defaults", "read", "-g", "AppleLocale")
+	output, err := cmd.Output()
+	if err != nil {
+		return ""
+	}
+	locale := strings.TrimSpace(string(output))
+	// Convert macOS format (en_US) to BCP 47 format (en-US)
+	return strings.ReplaceAll(locale, "_", "-")
+}
+
+// parseUnixLocale converts a Unix locale string to a language.Tag.
+// Examples: "en_US.UTF-8", "de_DE", "fr_FR.utf8", "C", "POSIX", "C.UTF-8".
+// The C/POSIX locale and the empty string map to language.English;
+// language.Und is returned when the string cannot be parsed.
+func parseUnixLocale(locale string) language.Tag {
+	// Remove the encoding suffix (e.g., ".UTF-8") and modifier (e.g., "@euro")
+	// first, so that "C.UTF-8" is recognized as the C locale below.
+	if idx := strings.IndexAny(locale, ".@"); idx != -1 {
+		locale = locale[:idx]
+	}
+
+	// Handle special cases
+	if locale == "C" || locale == "POSIX" || locale == "" {
+		return language.English
+	}
+
+	// Convert underscore to hyphen for BCP 47 format
+	locale = strings.ReplaceAll(locale, "_", "-")
+
+	if tag, err := language.Parse(locale); err == nil {
+		return tag
+	}
+
+	// Try with just the language part
+	if idx := strings.Index(locale, "-"); idx != -1 {
+		if tag, err := language.Parse(locale[:idx]); err == nil {
+			return tag
+		}
+	}
+
+	return language.Und
+}
diff --git a/pkg/sqlcmd/regional_linux.go b/pkg/sqlcmd/regional_linux.go
new file mode 100644
index 00000000..cb5a4350
--- /dev/null
+++ b/pkg/sqlcmd/regional_linux.go
@@ -0,0 +1,62 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+//go:build linux
+
+package sqlcmd
+
+import (
+	"os"
+	"strings"
+
+	"golang.org/x/text/language"
+)
+
+// detectUserLocale returns the user's locale from the LC_ALL, LC_MESSAGES and
+// LANG environment variables, in that order of precedence.
+// language.English is returned when none of them yields a usable tag.
+func detectUserLocale() language.Tag {
+	// Check standard locale environment variables in order of precedence
+	for _, envVar := range []string{"LC_ALL", "LC_MESSAGES", "LANG"} {
+		if locale := os.Getenv(envVar); locale != "" {
+			tag := parseUnixLocale(locale)
+			if tag != language.Und {
+				return tag
+			}
+		}
+	}
+	return language.English
+}
+
+// parseUnixLocale converts a Unix locale string to a language.Tag.
+// Examples: "en_US.UTF-8", "de_DE", "fr_FR.utf8", "C", "POSIX", "C.UTF-8".
+// The C/POSIX locale and the empty string map to language.English;
+// language.Und is returned when the string cannot be parsed.
+func parseUnixLocale(locale string) language.Tag {
+	// Remove the encoding suffix (e.g., ".UTF-8") and modifier (e.g., "@euro")
+	// first, so that "C.UTF-8" is recognized as the C locale below.
+	if idx := strings.IndexAny(locale, ".@"); idx != -1 {
+		locale = locale[:idx]
+	}
+
+	// Handle special cases
+	if locale == "C" || locale == "POSIX" || locale == "" {
+		return language.English
+	}
+
+	// Convert underscore to hyphen for BCP 47 format
+	locale = strings.ReplaceAll(locale, "_", "-")
+
+	if tag, err := language.Parse(locale); err == nil {
+		return tag
+	}
+
+	// Try with just the language part
+	if idx := strings.Index(locale, "-"); idx != -1 {
+		if tag, err := language.Parse(locale[:idx]); err == nil {
+			return tag
+		}
+	}
+
+	return language.Und
+}
diff --git a/pkg/sqlcmd/regional_other.go b/pkg/sqlcmd/regional_other.go
new file mode 100644
index 00000000..71195d32
--- /dev/null
+++ b/pkg/sqlcmd/regional_other.go
@@ -0,0 +1,64 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+//go:build !windows && !linux && !darwin
+
+package sqlcmd
+
+import (
+	"os"
+	"strings"
+
+	"golang.org/x/text/language"
+)
+
+// detectUserLocale returns the user's locale from environment variables.
+// This is a fallback implementation for platforms other than Windows, Linux, and Darwin.
+// It checks LC_ALL, LC_MESSAGES and LANG in that order and returns
+// language.English when none of them yields a usable tag.
+func detectUserLocale() language.Tag {
+	// Check standard locale environment variables in order of precedence
+	for _, envVar := range []string{"LC_ALL", "LC_MESSAGES", "LANG"} {
+		if locale := os.Getenv(envVar); locale != "" {
+			tag := parseUnixLocale(locale)
+			if tag != language.Und {
+				return tag
+			}
+		}
+	}
+	return language.English
+}
+
+// parseUnixLocale converts a Unix locale string to a language.Tag.
+// Examples: "en_US.UTF-8", "de_DE", "fr_FR.utf8", "C", "POSIX", "C.UTF-8".
+// The C/POSIX locale and the empty string map to language.English;
+// language.Und is returned when the string cannot be parsed.
+func parseUnixLocale(locale string) language.Tag {
+	// Remove the encoding suffix (e.g., ".UTF-8") and modifier (e.g., "@euro")
+	// first, so that "C.UTF-8" is recognized as the C locale below.
+	if idx := strings.IndexAny(locale, ".@"); idx != -1 {
+		locale = locale[:idx]
+	}
+
+	// Handle special cases
+	if locale == "C" || locale == "POSIX" || locale == "" {
+		return language.English
+	}
+
+	// Convert underscore to hyphen for BCP 47 format
+	locale = strings.ReplaceAll(locale, "_", "-")
+
+	if tag, err := language.Parse(locale); err == nil {
+		return tag
+	}
+
+	// Try with just the language part
+	if idx := strings.Index(locale, "-"); idx != -1 {
+		if tag, err := language.Parse(locale[:idx]); err == nil {
+			return tag
+		}
+	}
+
+	return language.Und
+}
diff --git a/pkg/sqlcmd/regional_test.go b/pkg/sqlcmd/regional_test.go
new file mode 100644
index 00000000..b073b1ef
--- /dev/null
+++ b/pkg/sqlcmd/regional_test.go
@@ -0,0 +1,218 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+package sqlcmd
+
+import (
+	"testing"
+	"time"
+
+	"github.com/stretchr/testify/assert"
+	"golang.org/x/text/language"
+)
+
+// TestRegionalSettings_Disabled checks that a disabled RegionalSettings leaves
+// numbers and money untouched and renders dates and times in the fixed
+// "2006-01-02" / "15:04:05" layouts.
+func TestRegionalSettings_Disabled(t *testing.T) {
+	r := NewRegionalSettings(false)
+	assert.False(t, r.IsEnabled())
+
+	// When disabled, all values should pass through unchanged
+	assert.Equal(t, "1234.56", r.FormatNumber("1234.56"))
+	assert.Equal(t, "1234.5600", r.FormatMoney("1234.5600"))
+
+	testTime := time.Date(2024, 1, 15, 14, 30, 45, 0, time.UTC)
+	assert.Equal(t, "2024-01-15", r.FormatDate(testTime))
+	assert.Equal(t, "14:30:45", r.FormatTime(testTime, 0))
+}
+
+// TestRegionalSettings_Enabled is a smoke test: the exact output depends on the
+// locale of the machine running the test, so only non-empty results are asserted.
+func TestRegionalSettings_Enabled(t *testing.T) {
+	r := NewRegionalSettings(true)
+	assert.True(t, r.IsEnabled())
+
+	// When enabled, values should be formatted according to locale
+	// The specific format depends on the system locale, so we just verify it works
+	number := r.FormatNumber("1234.56")
+	assert.NotEmpty(t, number)
+
+	money := r.FormatMoney("1234.5600")
+	assert.NotEmpty(t, money)
+}
+
+// TestRegionalSettings_NullHandling checks that "NULL" and the empty string are
+// returned unchanged by FormatNumber and FormatMoney when formatting is enabled.
+func TestRegionalSettings_NullHandling(t *testing.T) {
+	r := NewRegionalSettings(true)
+
+	// NULL values should pass through unchanged
+	assert.Equal(t, "NULL", r.FormatNumber("NULL"))
+	assert.Equal(t, "NULL", r.FormatMoney("NULL"))
+	assert.Equal(t, "", r.FormatNumber(""))
+	assert.Equal(t, "", r.FormatMoney(""))
+}
+
+// TestGetDecimalSeparator checks the decimal separator chosen for a sample of
+// locales, one subtest per locale.
+func TestGetDecimalSeparator(t *testing.T) {
+	tests := []struct {
+		locale   string
+		expected string
+	}{
+		{"en-US", "."},
+		{"en-GB", "."},
+		{"de-DE", ","},
+		{"fr-FR", ","},
+		{"es-ES", ","},
+		{"ja-JP", "."},
+		{"zh-CN", "."},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.locale, func(t *testing.T) {
+			tag := language.MustParse(tc.locale)
+			sep := getDecimalSeparator(tag)
+			assert.Equal(t, tc.expected, sep, "Decimal separator for %s", tc.locale)
+		})
+	}
+}
+
+// TestGetThousandSeparator checks the thousand separator chosen for a sample of
+// locales, one subtest per locale.
+func TestGetThousandSeparator(t *testing.T) {
+	tests := []struct {
+		locale   string
+		expected string
+	}{
+		{"en-US", ","},
+		{"en-GB", ","},
+		{"de-DE", "."},
+		{"fr-FR", "."},
+		{"ja-JP", ","},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.locale, func(t *testing.T) {
+			tag := language.MustParse(tc.locale)
+			sep := getThousandSeparator(tag)
+			assert.Equal(t, tc.expected, sep, "Thousand separator for %s", tc.locale)
+		})
+	}
+}
+
+// TestAddThousandSeparators checks digit grouping for inputs of one to ten
+// digits under en-US and de-DE.
+func TestAddThousandSeparators(t *testing.T) {
+	enUS := language.MustParse("en-US")
+	deDE := language.MustParse("de-DE")
+
+	tests := []struct {
+		input    string
+		locale   language.Tag
+		expected string
+	}{
+		{"1", enUS, "1"},
+		{"12", enUS, "12"},
+		{"123", enUS, "123"},
+		{"1234", enUS, "1,234"},
+		{"12345", enUS, "12,345"},
+		{"123456", enUS, "123,456"},
+		{"1234567", enUS, "1,234,567"},
+		{"1234567890", enUS, "1,234,567,890"},
+		{"1234", deDE, "1.234"},
+		{"1234567", deDE, "1.234.567"},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.input, func(t *testing.T) {
+			result := addThousandSeparators(tc.input, tc.locale)
+			assert.Equal(t, tc.expected, result)
+		})
+	}
+}
+
+// TestGetLocaleDateTimeFormats checks the date and time layouts returned for a
+// sample of locales; description is used as the assertion message.
+func TestGetLocaleDateTimeFormats(t *testing.T) {
+	tests := []struct {
+		locale      string
+		wantDate    string
+		wantTime    string
+		description string
+	}{
+		{"en-US", "01/02/2006", "03:04:05 PM", "US English uses M/D/Y and 12-hour time"},
+		{"en-GB", "02/01/2006", "15:04:05", "UK English uses D/M/Y and 24-hour time"},
+		{"de-DE", "02.01.2006", "15:04:05", "German uses D.M.Y format"},
+		{"ja-JP", "2006/01/02", "15:04:05", "Japanese uses Y/M/D format"},
+		{"fi-FI", "2006-01-02", "15:04:05", "Finnish uses ISO format"},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.locale, func(t *testing.T) {
+			tag := language.MustParse(tc.locale)
+			dateFmt, timeFmt := getLocaleDateTimeFormats(tag)
+			assert.Equal(t, tc.wantDate, dateFmt, tc.description)
+			assert.Equal(t, tc.wantTime, timeFmt, tc.description)
+		})
+	}
+}
+
+// TestFormatOffset checks the "+hh:mm" / "-hh:mm" rendering for positive,
+// negative and zero whole-hour inputs.
+func TestFormatOffset(t *testing.T) {
+	tests := []struct {
+		hours    int
+		minutes  int
+		expected string
+	}{
+		{0, 0, "+00:00"},
+		{5, 30, "+05:30"},
+		{-5, 0, "-05:00"},
+		{-8, 0, "-08:00"},
+		{12, 45, "+12:45"},
+		{-12, 0, "-12:00"},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.expected, func(t *testing.T) {
+			result := formatOffset(tc.hours, tc.minutes)
+			assert.Equal(t, tc.expected, result)
+		})
+	}
+}
+
+// TestPow10 checks pow10 for exponents 0 through 6.
+func TestPow10(t *testing.T) {
+	tests := []struct {
+		n        int
+		expected int
+	}{
+		{0, 1},
+		{1, 10},
+		{2, 100},
+		{3, 1000},
+		{6, 1000000},
+	}
+
+	for _, tc := range tests {
+		result := pow10(tc.n)
+		assert.Equal(t, tc.expected, result)
+	}
+}
+
+// TestPadLeftStr checks left-padding with '0' and ' ', including an input that
+// is already at the requested length.
+func TestPadLeftStr(t *testing.T) {
+	tests := []struct {
+		input    string
+		length   int
+		pad      rune
+		expected string
+	}{
+		{"5", 2, '0', "05"},
+		{"12", 2, '0', "12"},
+		{"1", 4, '0', "0001"},
+		// Two pad spaces: length 5 minus the three bytes of "abc".
+		{"abc", 5, ' ', "  abc"},
+	}
+
+	for _, tc := range tests {
+		result := padLeftStr(tc.input, tc.length, tc.pad)
+		assert.Equal(t, tc.expected, result)
+	}
+}
+
+// TestNewSQLCmdDefaultFormatterWithRegional checks that both formatter
+// constructors return a non-nil value with regional formatting on and off.
+func TestNewSQLCmdDefaultFormatterWithRegional(t *testing.T) {
+	// Test that the formatter is created correctly with regional settings
+	f := NewSQLCmdDefaultFormatterWithRegional(false, ControlIgnore, true)
+	assert.NotNil(t, f)
+
+	// Test without regional settings
+	f2 := NewSQLCmdDefaultFormatterWithRegional(false, ControlIgnore, false)
+	assert.NotNil(t, f2)
+
+	// Test backward compatibility - NewSQLCmdDefaultFormatter should work
+	f3 := NewSQLCmdDefaultFormatter(false, ControlIgnore)
+	assert.NotNil(t, f3)
+}
diff --git a/pkg/sqlcmd/regional_windows.go b/pkg/sqlcmd/regional_windows.go
new file mode 100644
index 00000000..846b0dd7
--- /dev/null
+++ b/pkg/sqlcmd/regional_windows.go
@@ -0,0 +1,136 @@
+// Copyright (c) Microsoft Corporation.
+// Licensed under the MIT license.
+
+//go:build windows
+
+package sqlcmd
+
+import (
+	"syscall"
+
+	"golang.org/x/text/language"
+)
+
+var (
+	kernel32               = syscall.NewLazyDLL("kernel32.dll")
+	procGetUserDefaultLCID = kernel32.NewProc("GetUserDefaultLCID")
+)
+
+// detectUserLocale returns the user's locale from Windows settings.
+// It calls GetUserDefaultLCID and maps the result through lcidToLanguageTag;
+// language.English is returned when the procedure cannot be resolved or the
+// mapped tag does not parse.
+func detectUserLocale() language.Tag {
+	// LazyProc.Call panics when the procedure cannot be resolved, so probe
+	// for it first and fall back to English instead.
+	if err := procGetUserDefaultLCID.Find(); err != nil {
+		return language.English
+	}
+
+	// Get user default locale
+	ret, _, _ := procGetUserDefaultLCID.Call()
+	if tag, err := language.Parse(lcidToLanguageTag(uint32(ret))); err == nil {
+		return tag
+	}
+	return language.English
+}
+
+// lcidToLanguageTag converts a Windows LCID to a BCP 47 language tag.
+// Only the low 16 bits (the language identifier) are examined, so an LCID
+// that also carries a sort-order identifier still resolves to its language.
+// LCIDs without an entry map to "en-US".
+func lcidToLanguageTag(lcid uint32) string {
+	// Common LCID mappings
+	// See: https://docs.microsoft.com/en-us/openspecs/windows_protocols/ms-lcid
+	switch lcid & 0xffff {
+	case 0x0409:
+		return "en-US"
+	case 0x0809:
+		return "en-GB"
+	case 0x0c09:
+		return "en-AU"
+	case 0x1009:
+		return "en-CA"
+	case 0x0407:
+		return "de-DE"
+	case 0x0807:
+		return "de-CH"
+	case 0x0c07:
+		return "de-AT"
+	case 0x040c:
+		return "fr-FR"
+	case 0x080c:
+		return "fr-BE"
+	case 0x0c0c:
+		return "fr-CA"
+	case 0x100c:
+		return "fr-CH"
+	case 0x0410:
+		return "it-IT"
+	case 0x0810:
+		return "it-CH"
+	case 0x0c0a:
+		return "es-ES"
+	case 0x080a:
+		return "es-MX"
+	case 0x2c0a:
+		return "es-AR"
+	case 0x0416:
+		return "pt-BR"
+	case 0x0816:
+		return "pt-PT"
+	case 0x0413:
+		return "nl-NL"
+	case 0x0813:
+		return "nl-BE"
+	case 0x0419:
+		return "ru-RU"
+	case 0x0415:
+		return "pl-PL"
+	case 0x0405:
+		return "cs-CZ"
+	case 0x041b:
+		return "sk-SK"
+	case 0x040e:
+		return "hu-HU"
+	case 0x0418:
+		return "ro-RO"
+	case 0x0402:
+		return "bg-BG"
+	case 0x041a:
+		return "hr-HR"
+	case 0x0424:
+		return "sl-SI"
+	// Serbian: 0x081a and 0x241a are the Latin-script identifiers, 0x0c1a and
+	// 0x281a the Cyrillic ones (0x081a/0x0c1a are the older "CS" region forms).
+	case 0x081a, 0x241a:
+		return "sr-Latn-RS"
+	case 0x0c1a, 0x281a:
+		return "sr-Cyrl-RS"
+	case 0x041f:
+		return "tr-TR"
+	case 0x0408:
+		return "el-GR"
+	case 0x0422:
+		return "uk-UA"
+	case 0x0423:
+		return "be-BY"
+	case 0x040b:
+		return "fi-FI"
+	case 0x041d:
+		return "sv-SE"
+	case 0x0414:
+		return "nb-NO"
+	case 0x0814:
+		return "nn-NO"
+	case 0x0406:
+		return "da-DK"
+	case 0x040f:
+		return "is-IS"
+	case 0x0411:
+		return "ja-JP"
+	case 0x0412:
+		return "ko-KR"
+	case 0x0804:
+		return "zh-CN"
+	case 0x0404:
+		return "zh-TW"
+	case 0x0c04:
+		return "zh-HK"
+	default:
+		// Default to US English
+		return "en-US"
+	}
+}
diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go
index 5e572a94..da41dd76 100644
--- a/pkg/sqlcmd/sqlcmd.go
+++ b/pkg/sqlcmd/sqlcmd.go
@@ -86,6 +86,8 @@ type Sqlcmd struct {
 	UnicodeOutputFile bool
 	// EchoInput tells the GO command to print the batch text before running the query
 	EchoInput bool
+	// CodePage specifies input/output file encoding
+	CodePage *CodePageSettings
 	colorizer color.Colorizer
 	termchan  chan os.Signal
 }
@@ -331,9 +333,26 @@ func (s *Sqlcmd) IncludeFile(path string, processAll bool) error {
 	}
 	defer f.Close()
 	b := s.batch.batchline
-	utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder())
-	unicodeReader := transform.NewReader(f, utf16bom)
-	scanner := bufio.NewReader(unicodeReader)
+
+	// Set up the reader with appropriate encoding
+	var reader io.Reader = f
+	if s.CodePage != nil && s.CodePage.InputCodePage != 0 {
+		// Use specified input codepage
+		enc, err := GetEncoding(s.CodePage.InputCodePage)
+		if err != nil {
+			return err
+		}
+		if enc != nil {
+			// Transform from specified encoding to UTF-8
+			reader = transform.NewReader(f, enc.NewDecoder())
+		}
+		// If enc is nil, it's UTF-8, no transformation needed
+	} else {
+		// Default: auto-detect BOM for UTF-16, fallback to UTF-8
+		utf16bom := unicode.BOMOverride(unicode.UTF8.NewDecoder())
+		reader = transform.NewReader(f, utf16bom)
+	}
+	scanner := bufio.NewReader(reader)
 	curLine := s.batch.read
 	echoFileLines := s.echoFileLines
 	ln := make([]byte, 0, 2*1024*1024)