diff --git a/CHANGELOG.md b/CHANGELOG.md new file mode 100644 index 00000000..060795d2 --- /dev/null +++ b/CHANGELOG.md @@ -0,0 +1,33 @@ +# Changelog + +All notable changes to this project will be documented in this file. + +## [Unreleased] + +### Added +- `:HELP` command to display available sqlcmd commands +- `-p` flag for printing performance statistics after each batch +- `-j` flag for printing raw error messages without formatting +- `:PERFTRACE` command to redirect timing output to file +- `:SERVERLIST` command to list available SQL Server instances on the network +- Multi-line `EXIT(query)` support in interactive mode - queries with unbalanced parentheses now prompt for continuation + +### Fixed +- Statistics format (`-p` flag) now matches ODBC sqlcmd output format +- Panic on empty args slice in command parser + +### Changed +- **Breaking for go-sqlcmd users**: `-u` (Unicode output) no longer writes a UTF-16LE BOM (Byte Order Mark) to output files. This change aligns go-sqlcmd with ODBC sqlcmd behavior, which never wrote a BOM. If your workflows depended on the BOM being present, you may need to adjust them. + +## Notes on ODBC sqlcmd Compatibility + +This release significantly improves compatibility with the original ODBC-based sqlcmd: + +| Feature | Previous go-sqlcmd | Now | ODBC sqlcmd | +|---------|-------------------|-----|-------------| +| `-u` output BOM | Wrote BOM | No BOM | No BOM ✓ | +| `-p` statistics format | Different format | Matches | Matches ✓ | +| `-r` without argument | Required argument | Defaults to 0 | Defaults to 0 ✓ | +| `EXIT(query)` multi-line | Not supported | Supported | Supported ✓ | +| `:HELP` command | Not available | Available | Available ✓ | +| `:SERVERLIST` command | Not available | Available | Available ✓ | diff --git a/README.md b/README.md index fe26e192..20174f0a 100644 --- a/README.md +++ b/README.md @@ -132,9 +132,7 @@ The following switches have different behavior in this version of `sqlcmd` compa - If both `-N` and `-C` are provided, sqlcmd will use their values for encryption negotiation. - To provide the value of the host name in the server certificate when using strict encryption, pass the host name with `-F`. Example: `-Ns -F myhost.domain.com` - More information about client/server encryption negotiation can be found at -- `-u` The generated Unicode output file will have the UTF16 Little-Endian Byte-order mark (BOM) written to it. - Some behaviors that were kept to maintain compatibility with `OSQL` may be changed, such as alignment of column headers for some data types. -- All commands must fit on one line, even `EXIT`. Interactive mode will not check for open parentheses or quotes for commands and prompt for successive lines. The ODBC sqlcmd allows the query run by `EXIT(query)` to span multiple lines. - `-i` doesn't handle a comma `,` in a file name correctly unless the file name argument is triple quoted. For example: `sqlcmd -i """select,100.sql"""` will try to open a file named `sql,100.sql` while `sqlcmd -i "select,100.sql"` will try to open two files `select` and `100.sql` - If using a single `-i` flag to pass multiple file names, there must be a space after the `-i`. Example: `-i file1.sql file2.sql` @@ -175,6 +173,13 @@ program_name sqlcmd net_transport Named pipe ``` +- The new `-p` (`--print-statistics`) flag prints performance statistics after each batch execution, including network packet size, transaction count, and clock time (total, average, and transactions per second). +- The new `-j` (`--raw-errors`) flag prints raw error messages without the standard "Msg #, Level, State, Server, Line" prefix formatting. +- The new `:HELP` interactive command displays a list of all available sqlcmd commands with descriptions. +- The new `:PERFTRACE |STDERR|STDOUT` interactive command redirects timing output to a file or stream. This is useful when using `-p` to separate statistics from query output. +- The new `:SERVERLIST` interactive command lists local and network SQL Server instances (same as `-L` flag but available during an interactive session). +- `EXIT(query)` now supports multi-line queries in interactive mode. When an unclosed parenthesis is detected, sqlcmd prompts for additional input until the query is complete. + ### Azure Active Directory Authentication `sqlcmd` supports a broader range of AAD authentication models (over the original ODBC based `sqlcmd`), based on the [azidentity package](https://pkg.go.dev/github.com/Azure/azure-sdk-for-go/sdk/azidentity). The implementation relies on an AAD Connector in the [driver](https://github.com/microsoft/go-mssqldb). diff --git a/cmd/sqlcmd/sqlcmd.go b/cmd/sqlcmd/sqlcmd.go index ea655b47..263e20bf 100644 --- a/cmd/sqlcmd/sqlcmd.go +++ b/cmd/sqlcmd/sqlcmd.go @@ -5,20 +5,16 @@ package sqlcmd import ( - "context" "errors" "fmt" - "net" "os" "regexp" "runtime/trace" "strconv" "strings" - "time" mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-mssqldb/azuread" - "github.com/microsoft/go-mssqldb/msdsn" "github.com/microsoft/go-sqlcmd/internal/localizer" "github.com/microsoft/go-sqlcmd/pkg/console" "github.com/microsoft/go-sqlcmd/pkg/sqlcmd" @@ -82,6 +78,8 @@ type SQLCmdArguments struct { ChangePassword string ChangePasswordAndExit string TraceFile string + PrintStatistics bool + RawErrors bool // Keep Help at the end of the list Help bool } @@ -236,7 +234,11 @@ func Execute(version string) { fmt.Println() fmt.Println(localizer.Sprintf("Servers:")) } - listLocalServers() + instances, _ := sqlcmd.ListServers(0) + servers := sqlcmd.FormatServerList(instances) + for _, s := range servers { + fmt.Println(" ", s) + } os.Exit(0) } if len(argss) > 0 { @@ -479,6 +481,8 @@ func setFlags(rootCmd *cobra.Command, args *SQLCmdArguments) { rootCmd.Flags().BoolVarP(&args.EnableColumnEncryption, "enable-column-encryption", "g", false, localizer.Sprintf("Enable column encryption")) rootCmd.Flags().StringVarP(&args.ChangePassword, "change-password", "z", "", localizer.Sprintf("New password")) rootCmd.Flags().StringVarP(&args.ChangePasswordAndExit, "change-password-exit", "Z", "", localizer.Sprintf("New password and exit")) + rootCmd.Flags().BoolVarP(&args.PrintStatistics, "print-statistics", "p", false, localizer.Sprintf("Print performance statistics for each batch")) + rootCmd.Flags().BoolVarP(&args.RawErrors, "raw-errors", "j", false, localizer.Sprintf("Print raw error messages without additional formatting")) } func setScriptVariable(v string) string { @@ -817,6 +821,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.Cmd.DisableSysCommands(args.errorOnBlockedCmd()) } s.EchoInput = args.EchoInput + s.PrintStatistics = args.PrintStatistics if args.BatchTerminator != "GO" { err = s.Cmd.SetBatchTerminator(args.BatchTerminator) if err != nil { @@ -828,7 +833,7 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { } s.Connect = &connectConfig - s.Format = sqlcmd.NewSQLCmdDefaultFormatter(args.TrimSpaces, args.getControlCharacterBehavior()) + s.Format = sqlcmd.NewSQLCmdDefaultFormatterWithOptions(args.TrimSpaces, args.getControlCharacterBehavior(), args.RawErrors) if args.OutputFile != "" { err = s.RunCommand(s.Cmd["OUT"], []string{args.OutputFile}) if err != nil { @@ -911,76 +916,3 @@ func run(vars *sqlcmd.Variables, args *SQLCmdArguments) (int, error) { s.SetError(nil) return s.Exitcode, err } - -func listLocalServers() { - bmsg := []byte{byte(msdsn.BrowserAllInstances)} - resp := make([]byte, 16*1024-1) - dialer := &net.Dialer{} - ctx, cancel := context.WithTimeout(context.Background(), time.Second*30) - defer cancel() - conn, err := dialer.DialContext(ctx, "udp", ":1434") - // silently ignore failures to connect, same as ODBC - if err != nil { - return - } - defer conn.Close() - dl, _ := ctx.Deadline() - _ = conn.SetDeadline(dl) - _, err = conn.Write(bmsg) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - read, err := conn.Read(resp) - if err != nil { - if !errors.Is(err, os.ErrDeadlineExceeded) { - fmt.Println(err) - } - return - } - - data := parseInstances(resp[:read]) - instances := make([]string, 0, len(data)) - for s := range data { - if s == "MSSQLSERVER" { - - instances = append(instances, "(local)", data[s]["ServerName"]) - } else { - instances = append(instances, fmt.Sprintf(`%s\%s`, data[s]["ServerName"], s)) - } - } - for _, s := range instances { - fmt.Println(" ", s) - } -} - -func parseInstances(msg []byte) msdsn.BrowserData { - results := msdsn.BrowserData{} - if len(msg) > 3 && msg[0] == 5 { - out_s := string(msg[3:]) - tokens := strings.Split(out_s, ";") - instdict := map[string]string{} - got_name := false - var name string - for _, token := range tokens { - if got_name { - instdict[name] = token - got_name = false - } else { - name = token - if len(name) == 0 { - if len(instdict) == 0 { - break - } - results[strings.ToUpper(instdict["InstanceName"])] = instdict - instdict = map[string]string{} - continue - } - got_name = true - } - } - } - return results -} diff --git a/cmd/sqlcmd/testdata/unicodeout.txt b/cmd/sqlcmd/testdata/unicodeout.txt index 1cb61880..60a0def5 100644 Binary files a/cmd/sqlcmd/testdata/unicodeout.txt and b/cmd/sqlcmd/testdata/unicodeout.txt differ diff --git a/cmd/sqlcmd/testdata/unicodeout_linux.txt b/cmd/sqlcmd/testdata/unicodeout_linux.txt index f949932a..64c0b393 100644 Binary files a/cmd/sqlcmd/testdata/unicodeout_linux.txt and b/cmd/sqlcmd/testdata/unicodeout_linux.txt differ diff --git a/pkg/sqlcmd/commands.go b/pkg/sqlcmd/commands.go index 66dd1dba..43366b97 100644 --- a/pkg/sqlcmd/commands.go +++ b/pkg/sqlcmd/commands.go @@ -11,6 +11,7 @@ import ( "sort" "strconv" "strings" + "time" "github.com/microsoft/go-sqlcmd/internal/color" "golang.org/x/text/encoding/unicode" @@ -113,6 +114,21 @@ func newCommands() Commands { action: xmlCommand, name: "XML", }, + "HELP": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:HELP(?:[ \t]+(.*$)|$)`), + action: helpCommand, + name: "HELP", + }, + "PERFTRACE": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:PERFTRACE(?:[ \t]+(.*$)|$)`), + action: perftraceCommand, + name: "PERFTRACE", + }, + "SERVERLIST": { + regex: regexp.MustCompile(`(?im)^[\t ]*?:SERVERLIST(?:[ \t]+(.*$)|$)`), + action: serverlistCommand, + name: "SERVERLIST", + }, } } @@ -212,6 +228,8 @@ func (c Commands) SetBatchTerminator(terminator string) error { // With no (), it just exits without running any query // With () it runs whatever batch is in the buffer then exits // With any text between () it runs the text as a query then exits +// In interactive mode, EXIT(query) can span multiple lines - sqlcmd will +// continue prompting for input until the closing ) is found. func exitCommand(s *Sqlcmd, args []string, line uint) error { if len(args) == 0 { return ErrExitRequested @@ -220,7 +238,26 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { if params == "" { return ErrExitRequested } - if !strings.HasPrefix(params, "(") || !strings.HasSuffix(params, ")") { + + // Check if we have an opening paren + if !strings.HasPrefix(params, "(") { + return InvalidCommandError("EXIT", line) + } + + // If we don't have a closing paren, try to read continuation lines (interactive mode only) + if !isExitParenBalanced(params) { + if s.lineIo == nil { + // Not in interactive mode, can't read more lines + return InvalidCommandError("EXIT", line) + } + var err error + params, err = readExitContinuation(s, params) + if err != nil { + return err + } + } + + if !strings.HasSuffix(params, ")") { return InvalidCommandError("EXIT", line) } // First we save the current batch @@ -249,11 +286,58 @@ func exitCommand(s *Sqlcmd, args []string, line uint) error { if len(query1) > 0 || len(query2) > 0 { query := query1 + SqlcmdEol + query2 + startTime := time.Now() s.Exitcode, _ = s.runQuery(query) + elapsed := time.Since(startTime) + s.PrintPerformanceStatistics(elapsed.Milliseconds(), 1) } return ErrExitRequested } +// isExitParenBalanced checks if the parentheses in an EXIT command argument are balanced. +// It tracks quotes to avoid counting parens inside string literals. +func isExitParenBalanced(s string) bool { + depth := 0 + var quote rune = 0 + for _, c := range s { + switch { + case quote != 0: + // Inside a quoted string + if c == quote { + quote = 0 + } + case c == '\'' || c == '"': + quote = c + case c == '[': + quote = ']' // SQL Server bracket quoting + case c == '(': + depth++ + case c == ')': + depth-- + } + } + return depth == 0 +} + +// readExitContinuation reads additional lines from the console until the EXIT +// parentheses are balanced. This enables multi-line EXIT(query) in interactive mode. +func readExitContinuation(s *Sqlcmd, params string) (string, error) { + var builder strings.Builder + builder.WriteString(params) + + for !isExitParenBalanced(builder.String()) { + // Show continuation prompt + s.lineIo.SetPrompt(" -> ") + line, err := s.lineIo.Readline() + if err != nil { + return "", err + } + builder.WriteString(SqlcmdEol) + builder.WriteString(line) + } + return builder.String(), nil +} + // quitCommand immediately exits the program without running any more batches func quitCommand(s *Sqlcmd, args []string, line uint) error { if args != nil && strings.TrimSpace(args[0]) != "" { @@ -290,12 +374,15 @@ func goCommand(s *Sqlcmd, args []string, line uint) error { return nil } query = s.getRunnableQuery(query) + startTime := time.Now() for i := 0; i < n; i++ { if retcode, err := s.runQuery(query); err != nil { s.Exitcode = retcode return err } } + elapsed := time.Since(startTime) + s.PrintPerformanceStatistics(elapsed.Milliseconds(), n) s.batch.Reset(nil) return nil } @@ -321,9 +408,8 @@ func outCommand(s *Sqlcmd, args []string, line uint) error { return InvalidFileError(err, args[0]) } if s.UnicodeOutputFile { - // ODBC sqlcmd doesn't write a BOM but we will. - // Maybe the endian-ness should be configurable. - win16le := unicode.UTF16(unicode.LittleEndian, unicode.UseBOM) + // Match ODBC sqlcmd behavior: UTF-16LE without BOM + win16le := unicode.UTF16(unicode.LittleEndian, unicode.IgnoreBOM) encoder := transform.NewWriter(o, win16le.NewEncoder()) s.SetOutput(encoder) } else { @@ -357,6 +443,52 @@ func errorCommand(s *Sqlcmd, args []string, line uint) error { return nil } +// perftraceCommand changes the performance trace writer to use a file +func perftraceCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) == 0 || args[0] == "" { + return InvalidCommandError(":PERFTRACE", line) + } + filePath, err := resolveArgumentVariables(s, []rune(args[0]), true) + if err != nil { + return err + } + switch { + case strings.EqualFold(filePath, "stderr"): + s.SetPerftrace(os.Stderr) + case strings.EqualFold(filePath, "stdout"): + s.SetPerftrace(os.Stdout) + default: + o, err := os.OpenFile(filePath, os.O_TRUNC|os.O_CREATE|os.O_WRONLY, 0o644) + if err != nil { + return InvalidFileError(err, args[0]) + } + s.SetPerftrace(o) + } + return nil +} + +// serverlistCommand lists SQL Server instances on the network +func serverlistCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError(":SERVERLIST", line) + } + + output := s.GetOutput() + fmt.Fprintf(output, "Servers:%s", SqlcmdEol) + + instances, err := ListServers(0) + if err != nil { + // Silently ignore errors (same as ODBC sqlcmd) + return nil + } + + servers := FormatServerList(instances) + for _, srv := range servers { + fmt.Fprintf(output, " %s%s", srv, SqlcmdEol) + } + return nil +} + func readFileCommand(s *Sqlcmd, args []string, line uint) error { if args == nil || len(args) != 1 { return InvalidCommandError(":R", line) @@ -413,6 +545,45 @@ func listVarCommand(s *Sqlcmd, args []string, line uint) error { return nil } +// helpCommand displays the list of available sqlcmd commands +func helpCommand(s *Sqlcmd, args []string, line uint) error { + if len(args) > 0 && strings.TrimSpace(args[0]) != "" { + return InvalidCommandError(":HELP", line) + } + + helpText := []struct { + cmd string + desc string + }{ + {"ED", "Edit the statement cache using the default editor"}, + {"!!", "Execute operating system command"}, + {":CONNECT server[\\instance] [-l timeout] [-U user [-P password]]", "Connect to a server"}, + {":ERROR |STDERR|STDOUT", "Redirect error output to file"}, + {":EXIT", "Exit sqlcmd"}, + {":EXIT()", "Execute query, exit returning no value"}, + {":EXIT()", "Execute query and exit returning numeric result"}, + {"GO []", "Execute batch [n times]"}, + {":HELP", "Show this list of commands"}, + {":LIST", "Print statement cache contents"}, + {":LISTVAR", "List scripting variables"}, + {":ON ERROR [EXIT|IGNORE]", "Action on error"}, + {":OUT |STDERR|STDOUT", "Redirect output to file"}, + {":PERFTRACE |STDERR|STDOUT", "Redirect timing output to file"}, + {":QUIT", "Exit sqlcmd immediately"}, + {":SERVERLIST", "List local and network SQL Server instances"}, + {":R ", "Read input from file"}, + {":RESET", "Clear statement cache"}, + {":SETVAR ", "Set scripting variable"}, + {":XML [ON|OFF]", "Enable or disable XML output mode"}, + } + + output := s.GetOutput() + for _, h := range helpText { + fmt.Fprintf(output, "%-60s %s%s", h.cmd, h.desc, SqlcmdEol) + } + return nil +} + // resetCommand resets the statement cache func resetCommand(s *Sqlcmd, args []string, line uint) error { if s.batch != nil { diff --git a/pkg/sqlcmd/commands_test.go b/pkg/sqlcmd/commands_test.go index 6197aa3f..a1e1589d 100644 --- a/pkg/sqlcmd/commands_test.go +++ b/pkg/sqlcmd/commands_test.go @@ -6,6 +6,7 @@ package sqlcmd import ( "bytes" "fmt" + "io" "os" "strings" "testing" @@ -161,6 +162,87 @@ func TestVarCommands(t *testing.T) { } +func TestHelpCommand(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + + err := helpCommand(s, []string{""}, 1) + assert.NoError(t, err, "helpCommand") + s.SetOutput(nil) + + o := buf.buf.String() + t.Logf("Help output:\n'%s'", o) + + // Verify key commands are listed + assert.Contains(t, o, ":HELP", "Help output should contain :HELP") + assert.Contains(t, o, ":CONNECT", "Help output should contain :CONNECT") + assert.Contains(t, o, ":SETVAR", "Help output should contain :SETVAR") + assert.Contains(t, o, "GO", "Help output should contain GO") + assert.Contains(t, o, ":EXIT", "Help output should contain :EXIT") + assert.Contains(t, o, ":QUIT", "Help output should contain :QUIT") + + // Verify help command rejects arguments + err = helpCommand(s, []string{"invalid"}, 1) + assert.Error(t, err, "helpCommand should reject arguments") +} + +func TestPerftraceCommand(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + + // Test setting perftrace to stdout + err := perftraceCommand(s, []string{"stdout"}, 1) + assert.NoError(t, err, "perftraceCommand with stdout") + + // Test setting perftrace to stderr + err = perftraceCommand(s, []string{"stderr"}, 1) + assert.NoError(t, err, "perftraceCommand with stderr") + + // Test setting perftrace to a file + tmpFile, err := os.CreateTemp("", "perftrace_test_*.txt") + assert.NoError(t, err, "CreateTemp") + tmpFileName := tmpFile.Name() + tmpFile.Close() + defer os.Remove(tmpFileName) + + err = perftraceCommand(s, []string{tmpFileName}, 1) + assert.NoError(t, err, "perftraceCommand with file") + + // Verify we can write to the perftrace output + _, err = fmt.Fprintf(s.GetPerftrace(), "test output") + assert.NoError(t, err, "Write to perftrace") + + // Test error case - empty argument + err = perftraceCommand(s, []string{""}, 1) + assert.Error(t, err, "perftraceCommand should reject empty argument") + + // Test error case - no argument + err = perftraceCommand(s, nil, 1) + assert.Error(t, err, "perftraceCommand should reject nil argument") +} + +func TestServerlistCommand(t *testing.T) { + vars := InitializeVariables(false) + s := New(nil, "", vars) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + + // Test serverlist command (will output "Servers:" header even with no servers found) + err := serverlistCommand(s, []string{""}, 1) + assert.NoError(t, err, "serverlistCommand") + s.SetOutput(nil) + + o := buf.buf.String() + // Should at minimum contain the "Servers:" header + assert.Contains(t, o, "Servers:", "Output should contain Servers: header") + + // Verify serverlist command rejects arguments + err = serverlistCommand(s, []string{"invalid"}, 1) + assert.Error(t, err, "serverlistCommand should reject arguments") +} + // memoryBuffer has both Write and Close methods for use as io.WriteCloser type memoryBuffer struct { buf *bytes.Buffer @@ -458,3 +540,56 @@ func TestExitCommandAppendsParameterToCurrentBatch(t *testing.T) { } } + +func TestIsExitParenBalanced(t *testing.T) { + tests := []struct { + input string + balanced bool + }{ + {"()", true}, + {"(select 1)", true}, + {"(select 1", false}, + {"(select (1 + 2))", true}, + {"(select ')')", true}, // paren inside string + {"(select \"(\")", true}, // paren inside double-quoted string + {"(select [col)])", true}, // paren inside bracket-quoted identifier + {"(select 1) extra", true}, // balanced even with trailing text + {"((nested))", true}, + {"((nested)", false}, + } + for _, tt := range tests { + t.Run(tt.input, func(t *testing.T) { + result := isExitParenBalanced(tt.input) + assert.Equal(t, tt.balanced, result, "isExitParenBalanced(%q)", tt.input) + }) + } +} + +func TestExitMultiLine(t *testing.T) { + // Test that multi-line EXIT works in interactive mode + s, buf := setupSqlCmdWithMemoryOutput(t) + defer buf.Close() + + // Simulate multi-line input: "EXIT(SELECT 1" then " + 2)" + lines := []string{" + 2)"} + lineIndex := 0 + + s.lineIo = &testConsole{ + OnReadLine: func() (string, error) { + if lineIndex >= len(lines) { + return "", io.EOF + } + line := lines[lineIndex] + lineIndex++ + return line, nil + }, + OnPasswordPrompt: func(prompt string) ([]byte, error) { + return nil, nil + }, + } + + // Directly call exitCommand with unclosed paren + err := exitCommand(s, []string{"(SELECT 1"}, 1) + assert.Equal(t, ErrExitRequested, err, "exitCommand should return ErrExitRequested") + assert.Equal(t, 3, s.Exitcode, "EXIT should set exit code to query result (1+2=3)") +} diff --git a/pkg/sqlcmd/format.go b/pkg/sqlcmd/format.go index 55bd2e25..8288948d 100644 --- a/pkg/sqlcmd/format.go +++ b/pkg/sqlcmd/format.go @@ -85,6 +85,7 @@ type sqlCmdFormatterType struct { maxColNameLen int colorizer color.Colorizer xml bool + rawErrors bool } // NewSQLCmdDefaultFormatter returns a Formatter that mimics the original ODBC-based sqlcmd formatter @@ -97,6 +98,17 @@ func NewSQLCmdDefaultFormatter(removeTrailingSpaces bool, ccb ControlCharacterBe } } +// NewSQLCmdDefaultFormatterWithOptions returns a Formatter with additional options +func NewSQLCmdDefaultFormatterWithOptions(removeTrailingSpaces bool, ccb ControlCharacterBehavior, rawErrors bool) Formatter { + return &sqlCmdFormatterType{ + removeTrailingSpaces: removeTrailingSpaces, + format: "horizontal", + colorizer: color.New(false), + ccb: ccb, + rawErrors: rawErrors, + } +} + // Adds the given string to the current line, wrapping it based on the screen width setting func (f *sqlCmdFormatterType) writeOut(s string, t color.TextType) { w := f.vars.ScreenWidth() @@ -223,10 +235,12 @@ func (f *sqlCmdFormatterType) AddError(err error) { switch e := (err).(type) { case mssql.Error: if print = f.vars.ErrorLevel() <= 0 || e.Class >= uint8(f.vars.ErrorLevel()); print { - if len(e.ProcName) > 0 { - b.WriteString(localizer.Sprintf("Msg %#v, Level %d, State %d, Server %s, Procedure %s, Line %#v%s", e.Number, e.Class, e.State, e.ServerName, e.ProcName, e.LineNo, SqlcmdEol)) - } else { - b.WriteString(localizer.Sprintf("Msg %#v, Level %d, State %d, Server %s, Line %#v%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol)) + if !f.rawErrors { + if len(e.ProcName) > 0 { + b.WriteString(localizer.Sprintf("Msg %#v, Level %d, State %d, Server %s, Procedure %s, Line %#v%s", e.Number, e.Class, e.State, e.ServerName, e.ProcName, e.LineNo, SqlcmdEol)) + } else { + b.WriteString(localizer.Sprintf("Msg %#v, Level %d, State %d, Server %s, Line %#v%s", e.Number, e.Class, e.State, e.ServerName, e.LineNo, SqlcmdEol)) + } } msg = strings.TrimPrefix(msg, "mssql: ") } diff --git a/pkg/sqlcmd/format_test.go b/pkg/sqlcmd/format_test.go index 0f304632..fa92373e 100644 --- a/pkg/sqlcmd/format_test.go +++ b/pkg/sqlcmd/format_test.go @@ -4,10 +4,12 @@ package sqlcmd import ( + "bytes" "context" "strings" "testing" + mssql "github.com/microsoft/go-mssqldb" "github.com/microsoft/go-sqlcmd/internal/color" "github.com/stretchr/testify/assert" ) @@ -158,3 +160,61 @@ func TestFormatterXmlMode(t *testing.T) { assert.NoError(t, err, "runSqlCmd returned error") assert.Equal(t, ``+SqlcmdEol, buf.buf.String()) } + +func TestNewSQLCmdDefaultFormatterWithOptions(t *testing.T) { + // Test that the formatter can be created with raw errors option + f := NewSQLCmdDefaultFormatterWithOptions(true, ControlIgnore, true) + assert.NotNil(t, f, "Formatter should not be nil") + + // Verify it's the correct type with rawErrors enabled + formatter, ok := f.(*sqlCmdFormatterType) + assert.True(t, ok, "Should be sqlCmdFormatterType") + assert.True(t, formatter.rawErrors, "rawErrors should be true") + assert.True(t, formatter.removeTrailingSpaces, "removeTrailingSpaces should be true") + + // Test with rawErrors disabled + f2 := NewSQLCmdDefaultFormatterWithOptions(false, ControlReplace, false) + formatter2, _ := f2.(*sqlCmdFormatterType) + assert.False(t, formatter2.rawErrors, "rawErrors should be false") + assert.False(t, formatter2.removeTrailingSpaces, "removeTrailingSpaces should be false") +} + +func TestRawErrorsOutput(t *testing.T) { + // Test that raw errors mode (-j flag) omits the "Msg #, Level, State, Server, Line" prefix + v := InitializeVariables(true) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + + // Test with rawErrors = false (default behavior) + formatter := NewSQLCmdDefaultFormatterWithOptions(true, ControlIgnore, false) + formatter.BeginBatch("", v, buf, buf) + + // Simulate an SQL error + mssqlErr := mssql.Error{ + Number: 208, + Class: 16, + State: 1, + ServerName: "TESTSERVER", + Message: "Invalid object name 'nonexistent'.", + } + formatter.AddError(mssqlErr) + + output := buf.buf.String() + assert.Contains(t, output, "Msg 208", "Default mode should include Msg prefix") + assert.Contains(t, output, "Level 16", "Default mode should include Level") + assert.Contains(t, output, "State 1", "Default mode should include State") + assert.Contains(t, output, "TESTSERVER", "Default mode should include Server") + assert.Contains(t, output, "Invalid object name", "Should include error message") + + // Test with rawErrors = true (-j flag) + buf.buf.Reset() + rawFormatter := NewSQLCmdDefaultFormatterWithOptions(true, ControlIgnore, true) + rawFormatter.BeginBatch("", v, buf, buf) + + rawFormatter.AddError(mssqlErr) + + rawOutput := buf.buf.String() + assert.NotContains(t, rawOutput, "Msg 208", "Raw mode should NOT include Msg prefix") + assert.NotContains(t, rawOutput, "Level 16", "Raw mode should NOT include Level") + assert.NotContains(t, rawOutput, "State 1", "Raw mode should NOT include State") + assert.Contains(t, rawOutput, "Invalid object name", "Raw mode should still include error message") +} diff --git a/pkg/sqlcmd/server_discovery.go b/pkg/sqlcmd/server_discovery.go new file mode 100644 index 00000000..7ba861be --- /dev/null +++ b/pkg/sqlcmd/server_discovery.go @@ -0,0 +1,131 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "context" + "errors" + "net" + "os" + "strings" + "time" + + "github.com/microsoft/go-mssqldb/msdsn" +) + +const ( + // DefaultBrowserTimeout is the default timeout for SQL Server Browser queries + DefaultBrowserTimeout = 30 * time.Second +) + +// ServerInstance represents a discovered SQL Server instance +type ServerInstance struct { + ServerName string + InstanceName string + IsClustered string + Version string + Port string +} + +// ListServers discovers SQL Server instances on the network using the SQL Server Browser service. +// It sends a UDP broadcast to port 1434 and parses the response. +// Returns a slice of ServerInstance and any error encountered. +func ListServers(timeout time.Duration) ([]ServerInstance, error) { + if timeout == 0 { + timeout = DefaultBrowserTimeout + } + + bmsg := []byte{byte(msdsn.BrowserAllInstances)} + resp := make([]byte, 16*1024-1) + + dialer := &net.Dialer{} + ctx, cancel := context.WithTimeout(context.Background(), timeout) + defer cancel() + + conn, err := dialer.DialContext(ctx, "udp", ":1434") + if err != nil { + return nil, err + } + defer conn.Close() + + dl, _ := ctx.Deadline() + _ = conn.SetDeadline(dl) + + _, err = conn.Write(bmsg) + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + return []ServerInstance{}, nil + } + return nil, err + } + + read, err := conn.Read(resp) + if err != nil { + if errors.Is(err, os.ErrDeadlineExceeded) { + return []ServerInstance{}, nil + } + return nil, err + } + + data := parseInstanceData(resp[:read]) + instances := make([]ServerInstance, 0, len(data)) + + for instName, props := range data { + inst := ServerInstance{ + ServerName: props["ServerName"], + InstanceName: instName, + IsClustered: props["IsClustered"], + Version: props["Version"], + Port: props["tcp"], + } + instances = append(instances, inst) + } + + return instances, nil +} + +// parseInstanceData parses the SQL Server Browser response into a map of instance data +func parseInstanceData(msg []byte) msdsn.BrowserData { + results := msdsn.BrowserData{} + if len(msg) > 3 && msg[0] == 5 { + outS := string(msg[3:]) + tokens := strings.Split(outS, ";") + instdict := map[string]string{} + gotName := false + var name string + for _, token := range tokens { + if gotName { + instdict[name] = token + gotName = false + } else { + name = token + if len(name) == 0 { + if len(instdict) == 0 { + break + } + results[strings.ToUpper(instdict["InstanceName"])] = instdict + instdict = map[string]string{} + continue + } + gotName = true + } + } + } + return results +} + +// FormatServerList formats the list of server instances for display +func FormatServerList(instances []ServerInstance) []string { + result := make([]string, 0, len(instances)*2) + for _, inst := range instances { + if inst.InstanceName == "MSSQLSERVER" { + // Default instance - show both (local) and server name (same as ODBC sqlcmd) + result = append(result, "(local)", inst.ServerName) + } else { + // Named instance + result = append(result, inst.ServerName+"\\"+inst.InstanceName) + } + } + return result +} diff --git a/pkg/sqlcmd/server_discovery_test.go b/pkg/sqlcmd/server_discovery_test.go new file mode 100644 index 00000000..c9ce2a6b --- /dev/null +++ b/pkg/sqlcmd/server_discovery_test.go @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT license. + +package sqlcmd + +import ( + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestParseInstanceData(t *testing.T) { + tests := []struct { + name string + input []byte + expected map[string]map[string]string + }{ + { + name: "single default instance", + // Format: 0x05 (response type), 2 bytes length, then key;value;key;value;; + input: []byte{0x05, 0x00, 0x00, + 'S', 'e', 'r', 'v', 'e', 'r', 'N', 'a', 'm', 'e', ';', 'M', 'Y', 'S', 'E', 'R', 'V', 'E', 'R', ';', + 'I', 'n', 's', 't', 'a', 'n', 'c', 'e', 'N', 'a', 'm', 'e', ';', 'M', 'S', 'S', 'Q', 'L', 'S', 'E', 'R', 'V', 'E', 'R', ';', + 'I', 's', 'C', 'l', 'u', 's', 't', 'e', 'r', 'e', 'd', ';', 'N', 'o', ';', + 'V', 'e', 'r', 's', 'i', 'o', 'n', ';', '1', '5', '.', '0', '.', '2', '0', '0', '0', '.', '5', ';', + 't', 'c', 'p', ';', '1', '4', '3', '3', ';', + ';'}, + expected: map[string]map[string]string{ + "MSSQLSERVER": { + "ServerName": "MYSERVER", + "InstanceName": "MSSQLSERVER", + "IsClustered": "No", + "Version": "15.0.2000.5", + "tcp": "1433", + }, + }, + }, + { + name: "empty response - too short", + input: []byte{0x05, 0x00}, + expected: map[string]map[string]string{}, + }, + { + name: "empty response - wrong type", + input: []byte{0x04, 0x00, 0x00, 'a', ';', 'b', ';', ';'}, + expected: map[string]map[string]string{}, + }, + { + name: "nil input", + input: nil, + expected: map[string]map[string]string{}, + }, + { + name: "empty input", + input: []byte{}, + expected: map[string]map[string]string{}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := parseInstanceData(tc.input) + assert.Equal(t, len(tc.expected), len(result), "Number of instances should match") + for instName, expectedProps := range tc.expected { + actualProps, ok := result[instName] + assert.True(t, ok, "Instance %s should exist", instName) + for key, expectedValue := range expectedProps { + assert.Equal(t, expectedValue, actualProps[key], "Property %s should match", key) + } + } + }) + } +} + +func TestFormatServerList(t *testing.T) { + tests := []struct { + name string + instances []ServerInstance + expected []string + }{ + { + name: "empty list", + instances: []ServerInstance{}, + expected: []string{}, + }, + { + name: "default instance only", + instances: []ServerInstance{ + {ServerName: "MYSERVER", InstanceName: "MSSQLSERVER", IsClustered: "No", Version: "15.0", Port: "1433"}, + }, + expected: []string{"(local)", "MYSERVER"}, + }, + { + name: "named instance only", + instances: []ServerInstance{ + {ServerName: "MYSERVER", InstanceName: "SQL2019", IsClustered: "No", Version: "15.0", Port: "1434"}, + }, + expected: []string{"MYSERVER\\SQL2019"}, + }, + { + name: "multiple instances mixed", + instances: []ServerInstance{ + {ServerName: "SERVER1", InstanceName: "MSSQLSERVER", IsClustered: "No", Version: "15.0", Port: "1433"}, + {ServerName: "SERVER1", InstanceName: "DEV", IsClustered: "No", Version: "14.0", Port: "1435"}, + {ServerName: "SERVER2", InstanceName: "PROD", IsClustered: "Yes", Version: "15.0", Port: "1436"}, + }, + expected: []string{"(local)", "SERVER1", "SERVER1\\DEV", "SERVER2\\PROD"}, + }, + { + name: "named instance with different case preserved", + instances: []ServerInstance{ + {ServerName: "MyServer", InstanceName: "SqlExpress", IsClustered: "No", Version: "15.0", Port: "1433"}, + }, + expected: []string{"MyServer\\SqlExpress"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + result := FormatServerList(tc.instances) + assert.Equal(t, tc.expected, result) + }) + } +} + +func TestFormatServerListPreservesOrder(t *testing.T) { + // Verify that the order of instances is preserved in output + instances := []ServerInstance{ + {ServerName: "ALPHA", InstanceName: "MSSQLSERVER"}, + {ServerName: "BETA", InstanceName: "TEST"}, + } + + result := FormatServerList(instances) + + // Default instance should produce (local) then ServerName + assert.Equal(t, "(local)", result[0]) + assert.Equal(t, "ALPHA", result[1]) + // Named instance should be ServerName\InstanceName + assert.Equal(t, "BETA\\TEST", result[2]) +} diff --git a/pkg/sqlcmd/sqlcmd.go b/pkg/sqlcmd/sqlcmd.go index 5e572a94..d429d8a1 100644 --- a/pkg/sqlcmd/sqlcmd.go +++ b/pkg/sqlcmd/sqlcmd.go @@ -86,6 +86,10 @@ type Sqlcmd struct { UnicodeOutputFile bool // EchoInput tells the GO command to print the batch text before running the query EchoInput bool + // PrintStatistics enables printing of timing statistics after each batch + PrintStatistics bool + // perftrace is the writer for performance trace output (set by :perftrace command) + perftrace io.WriteCloser colorizer color.Colorizer termchan chan os.Signal } @@ -236,6 +240,60 @@ func (s *Sqlcmd) SetError(e io.WriteCloser) { s.err = e } +// GetPerftrace returns the io.Writer to use for performance trace output +func (s *Sqlcmd) GetPerftrace() io.Writer { + if s.perftrace == nil { + return s.GetOutput() + } + return s.perftrace +} + +// SetPerftrace sets the io.WriteCloser to use for performance trace output +func (s *Sqlcmd) SetPerftrace(p io.WriteCloser) { + if s.perftrace != nil && s.perftrace != os.Stderr && s.perftrace != os.Stdout { + s.perftrace.Close() + } + s.perftrace = p +} + +// PrintPerformanceStatistics prints the performance statistics to the perftrace output. +// This matches the ODBC sqlcmd format: network packet size, transaction count, and clock time. +// The function returns early without printing if PrintStatistics is false or numXacts <= 0. +// Output format: "Network packet size (bytes): N\nX xact[s]:\nClock Time (ms.): total T avg A (R xacts per sec.)" +func (s *Sqlcmd) PrintPerformanceStatistics(totalTimeMs int64, numXacts int) { + if !s.PrintStatistics || numXacts <= 0 { + return + } + + perfout := s.GetPerftrace() + + // DefaultPacketSize is the default network packet size in bytes + const DefaultPacketSize = 4096 + + // Get packet size from the connection settings + packetSize := s.Connect.PacketSize + if packetSize == 0 { + packetSize = DefaultPacketSize + } + + // Ensure minimum time of 1ms for calculations + if totalTimeMs < 1 { + totalTimeMs = 1 + } + + avgTimeMs := float64(totalTimeMs) / float64(numXacts) + if avgTimeMs < 1 { + avgTimeMs = 1 + } + xactsPerSec := 1000.0 / avgTimeMs + + // Format matches ODBC sqlcmd MSG_PERF_STATS format: + // "\r\nNetwork packet size (bytes): %s\r\n%d xact[s]:\r\nClock Time (ms.): total %7ld avg %s (%s xacts per sec.)\r\n" + fmt.Fprintf(perfout, "%sNetwork packet size (bytes): %d%s", SqlcmdEol, packetSize, SqlcmdEol) + fmt.Fprintf(perfout, "%d xact[s]:%s", numXacts, SqlcmdEol) + fmt.Fprintf(perfout, "Clock Time (ms.): total %7d avg %.2f (%.2f xacts per sec.)%s", totalTimeMs, avgTimeMs, xactsPerSec, SqlcmdEol) +} + // WriteError writes the error on specified stream func (s *Sqlcmd) WriteError(stream io.Writer, err error) { if serr, ok := err.(SqlcmdError); ok { diff --git a/pkg/sqlcmd/sqlcmd_test.go b/pkg/sqlcmd/sqlcmd_test.go index dfe97d1a..8a049a40 100644 --- a/pkg/sqlcmd/sqlcmd_test.go +++ b/pkg/sqlcmd/sqlcmd_test.go @@ -705,3 +705,92 @@ func TestSqlcmdPrefersSharedMemoryProtocol(t *testing.T) { assert.EqualValuesf(t, "np", msdsn.ProtocolParsers[3].Protocol(), "np should be fourth protocol") } + +func TestPrintPerformanceStatistics(t *testing.T) { + v := InitializeVariables(true) + s := New(nil, "", v) + s.Connect = &ConnectSettings{PacketSize: 4096} + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + s.PrintStatistics = true + + // Test printing statistics + s.PrintPerformanceStatistics(1000, 2) + output := buf.buf.String() + + // Verify the output matches ODBC sqlcmd format + assert.Contains(t, output, "Network packet size (bytes): 4096", "Output should contain packet size") + assert.Contains(t, output, "2 xact[s]:", "Output should contain transaction count") + assert.Contains(t, output, "Clock Time (ms.): total", "Output should contain clock time header") + assert.Contains(t, output, "1000", "Output should contain total time") + assert.Contains(t, output, "500.00", "Output should contain avg time (1000/2)") + assert.Contains(t, output, "xacts per sec", "Output should contain xacts per sec") +} + +func TestPrintPerformanceStatisticsDisabled(t *testing.T) { + v := InitializeVariables(true) + s := New(nil, "", v) + buf := &memoryBuffer{buf: new(bytes.Buffer)} + s.SetOutput(buf) + s.PrintStatistics = false + + // Test that statistics are not printed when disabled + s.PrintPerformanceStatistics(1000, 2) + assert.Empty(t, buf.buf.String(), "No output when PrintStatistics is false") +} + +// TestPrintStatisticsIntegration tests the -p flag with real query execution +func TestPrintStatisticsIntegration(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + s.PrintStatistics = true + + // Execute a real query via GO command + err := runSqlCmd(t, s, []string{"SELECT 1 AS test_column", "GO"}) + assert.NoError(t, err, "runSqlCmd should succeed") + + output := buf.buf.String() + t.Logf("Output:\n%s", output) + + // Verify real statistics were printed after query execution + assert.Contains(t, output, "Network packet size (bytes):", "Should print network packet size") + assert.Contains(t, output, "xact[s]:", "Should print transaction count") + assert.Contains(t, output, "Clock Time (ms.):", "Should print clock time") + assert.Contains(t, output, "xacts per sec", "Should print transactions per second") + + // Verify the query was actually executed (result value is present) + assert.Contains(t, output, "1", "Query result should be in output") + assert.Contains(t, output, "(1 row affected)", "Should show row count") +} + +// TestRawErrorsIntegration tests the -j flag with a real database error +func TestRawErrorsIntegration(t *testing.T) { + s, buf := setupSqlCmdWithMemoryOutput(t) + + // Use the default formatter (NOT raw errors) + s.Format = NewSQLCmdDefaultFormatterWithOptions(true, ControlIgnore, false) + s.Format.BeginBatch("", s.vars, buf, buf) + + // Execute a query that will cause an error + buf.buf.Reset() + err := runSqlCmd(t, s, []string{"SELECT * FROM nonexistent_table_xyz_12345", "GO"}) + // Error is expected to be handled within sqlcmd, not returned + _ = err + + output := buf.buf.String() + // Default mode should include the Msg prefix + assert.Contains(t, output, "Msg", "Default mode should include Msg prefix") + assert.Contains(t, output, "Invalid object name", "Should contain error message") + + // Now test with raw errors (-j flag) + buf.buf.Reset() + s.Format = NewSQLCmdDefaultFormatterWithOptions(true, ControlIgnore, true) + s.Format.BeginBatch("", s.vars, buf, buf) + + err = runSqlCmd(t, s, []string{"SELECT * FROM another_nonexistent_table_abc_67890", "GO"}) + _ = err + + rawOutput := buf.buf.String() + // Raw mode should NOT include the Msg prefix line + assert.NotContains(t, rawOutput, "Msg 208", "Raw mode should NOT include Msg prefix") + assert.Contains(t, rawOutput, "Invalid object name", "Raw mode should still show error message") +}